├── .devcontainer ├── Dockerfile-client ├── Dockerfile-cuda ├── Dockerfile-prod-cpp ├── Dockerfile-prod-vllm ├── Dockerfile-vllm ├── client │ └── devcontainer.json ├── common.dockerfile ├── cuda-settings.dockerfile ├── cuda │ └── devcontainer.json ├── vllm-requirements.txt └── vllm │ └── devcontainer.json ├── .gitattributes ├── .github └── workflows │ ├── aicirt-release.yml │ ├── aicirt.yml │ ├── codeql.yml │ ├── links.yml │ └── rllm-cuda.yml ├── .gitignore ├── .gitmodules ├── .vscode ├── launch.json ├── settings.json └── tasks.json ├── CODE_OF_CONDUCT.md ├── Cargo.lock ├── Cargo.toml ├── LICENSE ├── NOTICE.md ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── TRANSPARENCY.md ├── aici.sh ├── aicirt ├── Cargo.toml ├── README.md ├── src │ ├── api.rs │ ├── bench.rs │ ├── futexshm.rs │ ├── hostimpl.rs │ ├── lib.rs │ ├── macos.rs │ ├── main.rs │ ├── moduleinstance.rs │ ├── msgchannel.rs │ ├── semaphore.rs │ ├── shm.rs │ └── worker.rs └── vllm.md ├── controllers ├── aici_abi │ ├── .cargo │ │ └── config.toml │ ├── Cargo.toml │ ├── README.md │ ├── grammars │ │ ├── c.y │ │ ├── json0.guidance │ │ └── sample.c │ ├── implementation.md │ └── src │ │ ├── cfg.rs │ │ ├── dlex.rs │ │ ├── host.rs │ │ ├── lex.rs │ │ ├── lib.rs │ │ ├── rx.rs │ │ ├── substring.rs │ │ └── yesno.rs ├── aici_native │ ├── Cargo.toml │ ├── README.md │ └── src │ │ ├── bintokens.rs │ │ ├── lib.rs │ │ ├── log.rs │ │ └── variables.rs ├── declctrl │ ├── .cargo │ │ └── config.toml │ ├── Cargo.toml │ ├── README.md │ ├── arg.json │ ├── arg2.json │ ├── genarg.py │ ├── native.sh │ ├── size.js │ ├── src │ │ └── declctrl.rs │ └── wasm.sh ├── jsctrl │ ├── .cargo │ │ └── config.toml │ ├── Cargo.toml │ ├── README.md │ ├── build.rs │ ├── gen-dts.mjs │ ├── samples │ │ ├── aici-types.d.ts │ │ ├── count-yellow.js │ │ ├── hello.js │ │ ├── hellots.ts │ │ ├── mapping.js │ │ ├── schema.js │ │ ├── test.ts │ │ └── tsconfig.json │ ├── src │ │ └── jsctrl.rs │ └── ts │ │ ├── aici.ts │ │ ├── native.d.ts │ │ └── tsconfig.json ├── llguidance_ctrl │ ├── .cargo │ │ └── config.toml │ ├── Cargo.toml │ ├── go.sh │ ├── run_g.py │ ├── src │ │ └── gctrl.rs │ └── text_req.py ├── pyctrl │ ├── .cargo │ │ └── config.toml │ ├── Cargo.toml │ ├── Lib │ │ ├── LICENSE │ │ ├── _collections_abc.py │ │ ├── _py_abc.py │ │ ├── _weakrefset.py │ │ ├── abc.py │ │ ├── collections │ │ │ ├── __init__.py │ │ │ ├── _defaultdict.py │ │ │ └── abc.py │ │ ├── contextlib.py │ │ ├── copy.sh │ │ ├── copyreg.py │ │ ├── enum.py │ │ ├── functools.py │ │ ├── genericpath.py │ │ ├── keyword.py │ │ ├── operator.py │ │ ├── os.py │ │ ├── posixpath.py │ │ ├── re.py │ │ ├── reprlib.py │ │ ├── sre_compile.py │ │ ├── sre_constants.py │ │ ├── sre_parse.py │ │ ├── stat.py │ │ ├── types.py │ │ └── typing.py │ ├── README.md │ ├── build.rs │ ├── driver.py │ ├── samples │ │ ├── cfg.py │ │ ├── forkbomb.py │ │ ├── idents.py │ │ ├── phi.py │ │ ├── substr.py │ │ ├── test.py │ │ ├── tla.py │ │ ├── warsaw.py │ │ └── yesno.py │ ├── src │ │ └── pyctrl.rs │ └── wasm.sh └── uppercase │ ├── .cargo │ └── config.toml │ ├── Cargo.toml │ ├── README.md │ └── src │ └── main.rs ├── docs ├── FAQ.md ├── REST.md ├── aicirt-proto.md └── proxy.md ├── py ├── promptlib │ ├── README.md │ ├── __init__.py │ ├── notebooks │ │ ├── basics_tutorial.ipynb │ │ └── information_flow_examples.ipynb │ └── promptlib │ │ ├── __init__.py │ │ ├── aici.py │ │ ├── gen.py │ │ ├── prompt.py │ │ └── vars.py ├── pyaici │ ├── __init__.py │ ├── _vllm_protocol.py │ ├── _vllm_runner.py │ ├── _vllm_sampling_ctrl.py │ ├── ast_.py │ ├── cli.py │ ├── comms.py │ ├── jssrc.py 
│ ├── rest.py │ ├── server.py │ ├── server_native.py │ ├── util.py │ ├── vllm.py │ └── vllm_server.py ├── setup.py └── tests │ ├── conftest.py │ ├── declctrl_test.py │ └── test-prompt.txt ├── pytest.ini ├── rllm ├── llama-cpp-low │ ├── Cargo.toml │ ├── build.rs │ └── src │ │ ├── lib.rs │ │ └── main.rs ├── rllm-base │ ├── Cargo.toml │ └── src │ │ ├── config.rs │ │ ├── engine.rs │ │ ├── exec.rs │ │ ├── expected.rs │ │ ├── iface.rs │ │ ├── lib.rs │ │ ├── logits.rs │ │ ├── scheduler.rs │ │ ├── seq.rs │ │ ├── server │ │ ├── api.rs │ │ ├── completion.rs │ │ ├── mod.rs │ │ └── openai │ │ │ ├── LICENSE │ │ │ ├── mod.rs │ │ │ ├── requests.rs │ │ │ └── responses.rs │ │ └── util.rs ├── rllm-cuda │ ├── Cargo.toml │ ├── README.md │ ├── expected │ │ ├── codellama │ │ │ ├── args.txt │ │ │ ├── cats.safetensors │ │ │ ├── lighthouse.safetensors │ │ │ └── primes.safetensors │ │ ├── codellama34 │ │ │ ├── args.txt │ │ │ ├── cats.safetensors │ │ │ ├── lighthouse.safetensors │ │ │ └── primes.safetensors │ │ ├── go.sh │ │ ├── llama │ │ │ ├── args.txt │ │ │ ├── cats.safetensors │ │ │ ├── lighthouse.safetensors │ │ │ └── primes.safetensors │ │ ├── orca │ │ │ ├── args.txt │ │ │ ├── cats.safetensors │ │ │ ├── lighthouse.safetensors │ │ │ └── primes.safetensors │ │ ├── phi-1_5 │ │ │ ├── args.txt │ │ │ ├── cats.safetensors │ │ │ ├── lighthouse.safetensors │ │ │ └── primes.safetensors │ │ └── phi-2 │ │ │ ├── args.txt │ │ │ ├── cats.safetensors │ │ │ ├── lighthouse.safetensors │ │ │ └── primes.safetensors │ ├── scripts │ │ ├── cmp2.sh │ │ ├── convert.py │ │ ├── tensorcmp.py │ │ └── testgen.py │ ├── server.sh │ ├── src │ │ ├── llm │ │ │ ├── config.rs │ │ │ ├── kernels.rs │ │ │ ├── llama.rs │ │ │ ├── loader.rs │ │ │ ├── mod.rs │ │ │ ├── paged │ │ │ │ ├── batch_info.rs │ │ │ │ ├── blocks.rs │ │ │ │ ├── cache_engine.rs │ │ │ │ ├── cuda_stub.rs │ │ │ │ └── mod.rs │ │ │ ├── phi.rs │ │ │ ├── refkernels.rs │ │ │ ├── tmodel.rs │ │ │ └── util.rs │ │ └── rllm-cuda.rs │ └── test.sh ├── rllm-llamacpp │ ├── Cargo.toml │ ├── README.md │ ├── server.sh │ └── src │ │ ├── llamacpp │ │ ├── blocks.rs │ │ ├── loader.rs │ │ ├── mod.rs │ │ ├── seqid.rs │ │ └── tmodel.rs │ │ └── rllm-llamacpp.rs └── tch-cuda │ ├── Cargo.toml │ ├── README.md │ ├── build.rs │ ├── convhd.js │ ├── kernels │ ├── cuda.cpp │ ├── flash_attn │ │ ├── AUTHORS │ │ ├── LICENSE │ │ ├── block_info.h │ │ ├── flash.h │ │ ├── flash_api.cpp │ │ ├── flash_fwd_kernel.h │ │ ├── flash_fwd_launch_template.h │ │ ├── flash_fwd_split_hdim128_bf16_sm80.cu │ │ ├── flash_fwd_split_hdim128_fp16_sm80.cu │ │ ├── flash_fwd_split_hdim160_bf16_sm80.cu │ │ ├── flash_fwd_split_hdim160_fp16_sm80.cu │ │ ├── flash_fwd_split_hdim192_bf16_sm80.cu │ │ ├── flash_fwd_split_hdim192_fp16_sm80.cu │ │ ├── flash_fwd_split_hdim224_bf16_sm80.cu │ │ ├── flash_fwd_split_hdim224_fp16_sm80.cu │ │ ├── flash_fwd_split_hdim256_bf16_sm80.cu │ │ ├── flash_fwd_split_hdim256_fp16_sm80.cu │ │ ├── flash_fwd_split_hdim32_bf16_sm80.cu │ │ ├── flash_fwd_split_hdim32_fp16_sm80.cu │ │ ├── flash_fwd_split_hdim64_bf16_sm80.cu │ │ ├── flash_fwd_split_hdim64_fp16_sm80.cu │ │ ├── flash_fwd_split_hdim96_bf16_sm80.cu │ │ ├── flash_fwd_split_hdim96_fp16_sm80.cu │ │ ├── kernel_traits.h │ │ ├── philox.cuh │ │ ├── softmax.h │ │ ├── static_switch.h │ │ └── utils.h │ ├── vllm │ │ ├── LICENSE │ │ ├── activation_kernels.cu │ │ ├── attention │ │ │ ├── attention_dtypes.h │ │ │ ├── attention_generic.cuh │ │ │ ├── attention_kernels.cu │ │ │ ├── attention_utils.cuh │ │ │ ├── dtype_bfloat16.cuh │ │ │ ├── dtype_float16.cuh │ │ │ └── dtype_float32.cuh │ 
│ ├── cache.h │ │ ├── cache_kernels.cu │ │ ├── cuda_compat.h │ │ ├── cuda_utils.h │ │ ├── cuda_utils_kernels.cu │ │ ├── dispatch_utils.h │ │ ├── layernorm_kernels.cu │ │ ├── ops.h │ │ ├── pos_encoding_kernels.cu │ │ ├── pybind.cpp │ │ ├── quantization │ │ │ ├── awq │ │ │ │ ├── dequantize.cuh │ │ │ │ └── gemm_kernels.cu │ │ │ └── squeezellm │ │ │ │ └── quant_cuda_kernel.cu │ │ └── reduction_utils.cuh │ └── vllm_bindings.cpp │ ├── src │ ├── event.rs │ ├── lib.rs │ └── stream.rs │ └── tests │ └── flash_attn_tests.rs ├── rustfmt.toml └── scripts ├── aici.sh ├── bench-comms.sh ├── bench-earley.sh ├── bench-server.sh ├── bump.sh ├── checkdeps.js ├── checklinks.js ├── checklinks.sh ├── docker-build.sh ├── docker-cpp-run.sh ├── docker-run.sh ├── docker-vllm-build.sh ├── hf.sh ├── host.sh ├── kill-rt.sh ├── kill-server.sh ├── py ├── bench_server.py ├── hgwells.txt ├── run_hf.py ├── run_vllm.py └── vllm_server.py ├── random ├── disasm.js ├── parse-tokenizer-automaton.js └── tokenizer-stats.js ├── release.sh ├── sample-uppercase.sh ├── sample-yesno.sh ├── tag-ctrls.sh ├── test-all.sh ├── test-guidance.sh ├── test-infer1.sh ├── test-jsctrl.sh ├── test-llg1.sh ├── test-parallel.sh ├── test-pyctrl.sh ├── tokenizer-stats.js ├── upload-all.sh ├── vllm-init.sh └── vllm-server.sh /.devcontainer/Dockerfile-client: -------------------------------------------------------------------------------- 1 | # syntax = edrevo/dockerfile-plus 2 | # ^^^ this line enables the INCLUDE+ directive 3 | 4 | FROM mcr.microsoft.com/devcontainers/cpp:ubuntu-22.04 5 | 6 | INCLUDE+ common.dockerfile 7 | -------------------------------------------------------------------------------- /.devcontainer/Dockerfile-cuda: -------------------------------------------------------------------------------- 1 | # syntax = edrevo/dockerfile-plus 2 | # ^^^ this line enables the INCLUDE+ directive 3 | 4 | FROM nvcr.io/nvidia/pytorch:23.09-py3 5 | 6 | INCLUDE+ cuda-settings.dockerfile 7 | INCLUDE+ common.dockerfile 8 | 9 | RUN pip install torch==2.1.0 nvidia-cuda-runtime 10 | # the .so file seems to be missing 11 | RUN ln -s /usr/local/lib/python3.10/dist-packages/nvidia/cuda_runtime/lib/libcudart.so{.12,} 12 | 13 | # perf tool 14 | RUN apt-get install -y linux-tools-`uname -r` 15 | 16 | RUN source /usr/local/nvm/nvm.sh && npm install -g yarn 17 | 18 | # we mostly need guidance deps 19 | RUN pip install guidance 20 | -------------------------------------------------------------------------------- /.devcontainer/Dockerfile-prod-cpp: -------------------------------------------------------------------------------- 1 | # docker container with aicirt and CPU-only rllm (llama.cpp) 2 | # TAG: aici/rllm-llamacpp 3 | 4 | FROM bitnami/minideb:bookworm 5 | 6 | RUN apt-get update && apt-get install -y libssl3 && apt-get clean 7 | 8 | # install aicirt and rllm 9 | COPY target/dist/aicirt/aicirt /usr/bin/aicirt 10 | COPY target/dist/rllm-llamacpp/rllm-llamacpp /usr/bin/rllm-llamacpp 11 | 12 | RUN mkdir /workspace 13 | 14 | # copy the controllers 15 | WORKDIR /workspace 16 | 17 | # RUN mkdir wasm 18 | # COPY target/dist/aici_*.wasm wasm/ 19 | # # "upload" and tag the controllers 20 | # RUN aicirt --module wasm/aici_guidance_ctrl.wasm --tag guidance 21 | # RUN aicirt --module wasm/aici_pyctrl.wasm --tag pyctrl --gh-module gh:microsoft/aici/pyctrl 22 | # RUN aicirt --module wasm/aici_jsctrl.wasm --tag jsctrl --gh-module gh:microsoft/aici/jsctrl 23 | 24 | ENV RUST_LOG info,tokenizers=error 25 | 26 | ENTRYPOINT ["rllm-llamacpp", "--aicirt=/usr/bin/aicirt"] 27 | 
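A hypothetical usage sketch for the image above — not a script from this repo (`scripts/docker-build.sh` and `scripts/docker-cpp-run.sh` are the maintained equivalents). It assumes `./scripts/release.sh` has already populated `target/dist`, that the build runs from the repository root, and that the server listens on port 4242 as used elsewhere in this repo; the model URL is the one from `.vscode/launch.json`.

```sh
./scripts/release.sh   # builds aicirt and rllm-llamacpp into target/dist
docker build -f .devcontainer/Dockerfile-prod-cpp -t aici/rllm-llamacpp .
docker run --rm -p 4242:4242 aici/rllm-llamacpp \
  --model=https://huggingface.co/TheBloke/phi-2-GGUF/blob/main/phi-2.Q8_0.gguf
```

Arguments after the image name are appended to the `ENTRYPOINT`, so `--model=...` is passed directly to `rllm-llamacpp`.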
-------------------------------------------------------------------------------- /.devcontainer/Dockerfile-prod-vllm: -------------------------------------------------------------------------------- 1 | FROM rust:1.75.0-bookworm AS aicirt 2 | 3 | WORKDIR /workspace 4 | 5 | RUN rustup target add wasm32-wasi 6 | RUN curl -fsSL https://deb.nodesource.com/setup_18.x | bash - 7 | RUN apt-get install -y nodejs 8 | 9 | COPY controllers controllers 10 | COPY aicirt aicirt 11 | COPY scripts scripts 12 | COPY py/pyaici py/pyaici 13 | COPY Cargo.toml Cargo.lock /workspace/ 14 | 15 | # make sure we rebuild these 16 | RUN rm -rf controllers/jsctrl/samples/dist controllers/jsctrl/ts/dist 17 | 18 | RUN grep -v rllm Cargo.toml > Cargo.toml.tmp && mv Cargo.toml.tmp Cargo.toml 19 | 20 | RUN --mount=type=cache,target=/usr/local/cargo/git \ 21 | --mount=type=cache,target=/usr/local/cargo/registry \ 22 | --mount=type=cache,target=/workspace/target \ 23 | cargo fetch 24 | 25 | ARG tag=latest 26 | ENV BUILD_TAG=$tag 27 | RUN --mount=type=cache,target=/usr/local/cargo/git \ 28 | --mount=type=cache,target=/usr/local/cargo/registry \ 29 | --mount=type=cache,target=/workspace/target \ 30 | SKIP_LLAMA_CPP=1 \ 31 | ./scripts/release.sh && cp -r /workspace/target/dist /workspace/ 32 | 33 | 34 | FROM vllm/vllm-openai as vllm-base 35 | 36 | # install pyaici pre-requisites 37 | RUN pip install posix_ipc ujson 38 | 39 | # install pyaici 40 | RUN mkdir /tmp/pyaici 41 | COPY py/setup.py /tmp/pyaici/ 42 | COPY py/pyaici /tmp/pyaici/pyaici 43 | RUN cd /tmp/pyaici && pip install . && rm -rf /tmp/pyaici 44 | 45 | # patch the vllm python files 46 | RUN --mount=source=py/vllm/vllm,target=/tmp/vllm \ 47 | (cd /tmp/vllm && find . -name '*.py' -print0 | tar -cf - --null -T -) | \ 48 | tar -C /usr/local/lib/python3.10/dist-packages/vllm -xf - 49 | 50 | # copy the controllers and aicirt 51 | WORKDIR /vllm-workspace 52 | RUN mkdir wasm 53 | 54 | RUN --mount=from=aicirt,source=/workspace/dist,target=/tmp/dist \ 55 | cp /tmp/dist/aicirt/aicirt /usr/bin/aicirt && \ 56 | cp /tmp/dist/aici_*.wasm wasm/ 57 | 58 | RUN ls -l wasm/ 59 | 60 | ENV RUST_LOG info,tokenizers=error 61 | 62 | ENTRYPOINT ["python3", "-m", "pyaici.vllm_server", "--enforce-eager", "--use-v2-block-manager", "--enable-chunked-prefill", "--aici-rt=/usr/bin/aicirt", "-A--restricted", "-A--wasm-timer-resolution-us=10"] 63 | 64 | FROM vllm-base AS vllm-guidance 65 | 66 | RUN aicirt --module wasm/aici_llguidance_ctrl.wasm --tag llguidance 67 | 68 | FROM vllm-base as vllm-general 69 | 70 | RUN aicirt --module wasm/aici_llguidance_ctrl.wasm --tag llguidance 71 | RUN aicirt --module wasm/aici_pyctrl.wasm --tag pyctrl --gh-module gh:microsoft/aici/pyctrl 72 | RUN aicirt --module wasm/aici_jsctrl.wasm --tag jsctrl --gh-module gh:microsoft/aici/jsctrl 73 | -------------------------------------------------------------------------------- /.devcontainer/Dockerfile-vllm: -------------------------------------------------------------------------------- 1 | # syntax = edrevo/dockerfile-plus 2 | # ^^^ this line enables the INCLUDE+ directive 3 | 4 | FROM nvcr.io/nvidia/pytorch:23.10-py3 5 | 6 | INCLUDE+ cuda-settings.dockerfile 7 | INCLUDE+ common.dockerfile 8 | 9 | ENV NVTE_FRAMEWORK=pytorch 10 | 11 | COPY vllm-requirements.txt /tmp/requirements.txt 12 | RUN pip install -r /tmp/requirements.txt 13 | 14 | # Uninstall the transformer engine that comes with the base image. 15 | # Otherwise it will cause error when importing vLLM (LLAVA models). 
16 | RUN pip uninstall -y transformer_engine 17 | 18 | # crashes docker? 19 | # RUN pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers 20 | 21 | # takes forever! 22 | # RUN pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable 23 | # RUN pip install typing_extensions==4.5.0 24 | # RUN pip install -U flash-attn 25 | 26 | # RUN pip install torch==2.1.0 nvidia-cuda-runtime 27 | # the .so file seems to be missing 28 | RUN ln -s /usr/local/lib/python3.10/dist-packages/nvidia/cuda_runtime/lib/libcudart.so{.12,} 29 | 30 | # perf tool 31 | RUN apt-get install -y linux-tools-`uname -r` 32 | 33 | RUN source /usr/local/nvm/nvm.sh && npm install -g yarn 34 | -------------------------------------------------------------------------------- /.devcontainer/client/devcontainer.json: -------------------------------------------------------------------------------- 1 | // For format details, see https://aka.ms/devcontainer.json 2 | { 3 | "name": "AICI Client-side", 4 | "build": { 5 | "dockerfile": "../Dockerfile-client", 6 | "context": ".." 7 | }, 8 | "customizations": { 9 | "vscode": { 10 | "extensions": [ 11 | "ms-python.python", 12 | "ms-python.black-formatter", 13 | "1YiB.rust-bundle", 14 | "dtsvet.vscode-wasm", 15 | "ms-vscode.cpptools", 16 | "esbenp.prettier-vscode", 17 | "streetsidesoftware.code-spell-checker", 18 | "GitHub.copilot" 19 | ] 20 | } 21 | }, 22 | "remoteUser": "root", 23 | "containerUser": "root", 24 | "mounts": [ 25 | "source=profile,target=/root,type=volume", 26 | "target=/root/.vscode-server,type=volume" 27 | ] 28 | } -------------------------------------------------------------------------------- /.devcontainer/common.dockerfile: -------------------------------------------------------------------------------- 1 | # makes it easier to diagnose ccache issues 2 | ENV CCACHE_DEBUG="1" 3 | 4 | # need git 2.41 for GCM/Github EMU account switching 5 | # https://askubuntu.com/questions/568591/how-do-i-install-the-latest-version-of-git-with-apt 6 | RUN apt-get update && apt-get install -y software-properties-common 7 | RUN apt-add-repository ppa:git-core/ppa 8 | RUN apt-get update && apt-get install -y git 9 | 10 | RUN DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ 11 | build-essential ca-certificates ccache \ 12 | cmake curl libjpeg-dev libpng-dev \ 13 | strace linux-tools-common linux-tools-generic \ 14 | llvm-dev libclang-dev clang ccache apache2-utils git-lfs \ 15 | screen bsdmainutils pip python3-dev python-is-python3 \ 16 | nodejs npm pkg-config 17 | 18 | RUN pip install pytest pytest-forked ujson posix_ipc numpy requests 19 | 20 | # RUN curl -L https://github.com/WebAssembly/binaryen/releases/download/version_116/binaryen-version_116-x86_64-linux.tar.gz \ 21 | # | tar zxf - --strip-components=1 -C /usr/local 22 | 23 | RUN cd /tmp && \ 24 | curl -L https://github.com/WebAssembly/wabt/releases/download/1.0.33/wabt-1.0.33.tar.xz | tar Jxf - && \ 25 | cd wabt-1.0.33 && make gcc-release && cp -v bin/wasm-* /usr/bin && cd .. && rm -rf wabt-1.0.33 26 | 27 | ENV RUSTUP_HOME=/usr/local/rustup \ 28 | CARGO_HOME=/usr/local/cargo \ 29 | PATH=/usr/local/cargo/bin:$PATH \ 30 | RUST_VERSION=1.75.0 31 | 32 | RUN curl https://sh.rustup.rs -sSf | sh -s -- \ 33 | -y --no-modify-path --profile minimal --default-toolchain $RUST_VERSION 34 | RUN rustup target add wasm32-wasi 35 | RUN rustup component add rustfmt 36 | 37 | # run as root please; note that settings in devcontainer.json are also needed... 
38 | USER root 39 | -------------------------------------------------------------------------------- /.devcontainer/cuda-settings.dockerfile: -------------------------------------------------------------------------------- 1 | # A100: 2 | ENV TORCH_CUDA_ARCH_LIST="8.0" 3 | ENV CUDA_COMPUTE_CAP="80" 4 | 5 | # candle is slow without caching; pytorch image has this on by default 6 | ENV CUDA_CACHE_DISABLE="" 7 | 8 | ENV LIBTORCH_USE_PYTORCH="1" 9 | ENV LIBTORCH_BYPASS_VERSION_CHECK="1" 10 | -------------------------------------------------------------------------------- /.devcontainer/cuda/devcontainer.json: -------------------------------------------------------------------------------- 1 | // For format details, see https://aka.ms/devcontainer.json 2 | { 3 | "name": "AICI with CUDA", 4 | "build": { 5 | "dockerfile": "../Dockerfile-cuda", 6 | "context": ".." 7 | }, 8 | "runArgs": [ 9 | "--privileged", 10 | "--gpus", 11 | "all", 12 | "--shm-size=8g" 13 | ], 14 | "mounts": [ 15 | "source=profile,target=/root,type=volume", 16 | "target=/root/.vscode-server,type=volume" 17 | ], 18 | "customizations": { 19 | "vscode": { 20 | "extensions": [ 21 | "ms-python.python", 22 | "ms-python.black-formatter", 23 | "1YiB.rust-bundle", 24 | "dtsvet.vscode-wasm", 25 | "ms-vscode.cpptools", 26 | "esbenp.prettier-vscode", 27 | "streetsidesoftware.code-spell-checker", 28 | "GitHub.copilot" 29 | ] 30 | } 31 | }, 32 | "forwardPorts": [ 33 | 4242 34 | ] 35 | } -------------------------------------------------------------------------------- /.devcontainer/vllm-requirements.txt: -------------------------------------------------------------------------------- 1 | # vllm: requirements.txt 2 | ninja # For faster builds. 3 | psutil 4 | ray >= 2.9 5 | sentencepiece # Required for LLaMA tokenizer. 6 | numpy 7 | torch == 2.1.2 8 | transformers >= 4.37.0 # Required for Qwen2 9 | xformers == 0.0.23.post1 # Required for CUDA 12.1. 10 | fastapi 11 | uvicorn[standard] 12 | pydantic >= 2.0 # Required for OpenAI server. 13 | aioprometheus[starlette] 14 | pynvml == 11.5.0 15 | triton >= 2.1.0 16 | cupy-cuda12x == 12.3.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead. 17 | 18 | # vllm: requirements-dev.txt 19 | # formatting 20 | yapf==0.32.0 21 | toml==0.10.2 22 | ruff==0.1.5 23 | 24 | # type checking 25 | mypy==0.991 26 | types-PyYAML 27 | types-requests 28 | types-setuptools 29 | 30 | # testing 31 | pytest 32 | pytest-forked 33 | pytest-asyncio 34 | httpx 35 | einops # required for MPT 36 | flash_attn # required for HuggingFace's llama implementation 37 | openai 38 | requests 39 | # ray - XXX 40 | 41 | # vllm: requirements-build.txt 42 | # Should be mirrored in pyproject.toml 43 | ninja 44 | packaging 45 | setuptools>=49.4.0 46 | # torch==2.1.2 - XXX 47 | wheel 48 | 49 | # non-vllm: 50 | ujson 51 | posix_ipc 52 | accelerate 53 | fschat 54 | -------------------------------------------------------------------------------- /.devcontainer/vllm/devcontainer.json: -------------------------------------------------------------------------------- 1 | // For format details, see https://aka.ms/devcontainer.json 2 | { 3 | "name": "AICI with CUDA and vLLM (experimental)", 4 | "build": { 5 | "dockerfile": "../Dockerfile-vllm", 6 | "context": ".." 
7 | }, 8 | "runArgs": [ 9 | "--privileged", 10 | "--gpus", 11 | "all", 12 | "--shm-size=8g" 13 | ], 14 | "mounts": [ 15 | "source=profile,target=/root,type=volume", 16 | "target=/root/.vscode-server,type=volume" 17 | ], 18 | "customizations": { 19 | "vscode": { 20 | "extensions": [ 21 | "ms-python.python", 22 | "ms-python.black-formatter", 23 | "eeyore.yapf", 24 | "1YiB.rust-bundle", 25 | "dtsvet.vscode-wasm", 26 | "ms-vscode.cpptools", 27 | "esbenp.prettier-vscode", 28 | "streetsidesoftware.code-spell-checker", 29 | "GitHub.copilot" 30 | ] 31 | } 32 | }, 33 | "forwardPorts": [ 34 | 4242 35 | ] 36 | } -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | rllm/tch-cuda/kernels/vllm/** linguist-vendored 2 | rllm/tch-cuda/kernels/flash_attn/** linguist-vendored 3 | controllers/pyctrl/Lib/** linguist-vendored 4 | -------------------------------------------------------------------------------- /.github/workflows/aicirt-release.yml: -------------------------------------------------------------------------------- 1 | name: AICIrt release 2 | 3 | on: 4 | push: 5 | tags: 6 | - "v*.*.*" 7 | 8 | env: 9 | CARGO_TERM_COLOR: always 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | 15 | permissions: 16 | contents: write 17 | 18 | steps: 19 | - uses: actions/checkout@v3 20 | with: 21 | submodules: true 22 | - run: rustup target add wasm32-wasi 23 | - uses: hendrikmuhs/ccache-action@v1.2 24 | - uses: Swatinem/rust-cache@v2 25 | with: 26 | cache-on-failure: true 27 | - name: Release script 28 | run: ./scripts/release.sh --xz 29 | - name: Release 30 | uses: softprops/action-gh-release@v1 31 | if: startsWith(github.ref, 'refs/tags/') 32 | with: 33 | body_path: target/dist/README.md 34 | files: | 35 | target/dist/*.tar.gz 36 | target/dist/*.tar.xz 37 | target/dist/*.wasm 38 | -------------------------------------------------------------------------------- /.github/workflows/aicirt.yml: -------------------------------------------------------------------------------- 1 | name: AICIrt 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v3 19 | with: 20 | submodules: true 21 | - run: rustup target add wasm32-wasi 22 | - uses: hendrikmuhs/ccache-action@v1.2 23 | - uses: Swatinem/rust-cache@v2 24 | with: 25 | cache-on-failure: true 26 | - name: Build aici_abi 27 | run: cargo build --verbose --release 28 | working-directory: controllers/aici_abi 29 | - name: Build uppercase 30 | run: cargo build --verbose --release 31 | working-directory: controllers/uppercase 32 | - name: Build pyctrl 33 | run: cargo build --verbose --release 34 | working-directory: controllers/pyctrl 35 | - name: Build jsctrl 36 | run: cargo build --verbose --release 37 | working-directory: controllers/jsctrl 38 | - name: Build declctrl 39 | run: cargo build --verbose --release 40 | working-directory: controllers/declctrl 41 | - name: Build aicirt 42 | run: cargo build --verbose --release 43 | working-directory: aicirt 44 | - name: Build rllm-llamacpp 45 | run: cargo build --verbose --release --no-default-features 46 | working-directory: rllm/rllm-llamacpp 47 | - name: Release script 48 | run: ./scripts/release.sh --xz 49 | - name: Artifact upload 50 | uses: actions/upload-artifact@v4 51 | with: 52 | name: aicirt-xz 53 | 
path: target/dist/*.tar.xz 54 | -------------------------------------------------------------------------------- /.github/workflows/links.yml: -------------------------------------------------------------------------------- 1 | name: Markdown link check 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v3 16 | - run: ./scripts/checklinks.sh 17 | -------------------------------------------------------------------------------- /.github/workflows/rllm-cuda.yml: -------------------------------------------------------------------------------- 1 | name: rLLM with CUDA 2 | 3 | on: 4 | push: 5 | branches: [ "disabled-main" ] 6 | pull_request: 7 | branches: [ "disabled-main" ] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | TORCH_CUDA_ARCH_LIST: 8.0 12 | CUDA_COMPUTE_CAP: 80 13 | LIBTORCH_USE_PYTORCH: 1 14 | LIBTORCH_BYPASS_VERSION_CHECK: 1 15 | 16 | jobs: 17 | build: 18 | 19 | runs-on: ubuntu-latest 20 | 21 | steps: 22 | - uses: actions/checkout@v3 23 | with: 24 | submodules: true 25 | - run: sudo apt-get install ccache 26 | 27 | - run: sudo df -h 28 | - run: sudo rm -rf /usr/share/dotnet /opt/ghc /usr/local/share/boost 29 | - run: sudo df -h 30 | 31 | - run: pip install torch==2.1.0 32 | - run: sudo df -h 33 | 34 | - uses: Jimver/cuda-toolkit@v0.2.13 35 | id: cuda-toolkit 36 | with: 37 | cuda: '12.3.2' 38 | 39 | - run: echo "Installed cuda version is ${{ steps.cuda-toolkit.outputs.cuda }}" 40 | - run: echo "Cuda install location ${{ steps.cuda-toolkit.outputs.CUDA_PATH }}" 41 | 42 | - run: nvcc -V 43 | 44 | - name: Build rLLM 45 | run: cargo build --verbose --release 46 | working-directory: rllm/rllm-cuda 47 | 48 | - run: strip target/release/rllm-cuda 49 | - name: Artifact upload 50 | uses: actions/upload-artifact@v4 51 | with: 52 | name: rllm-cuda 53 | path: target/release/rllm-cuda 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | .env 4 | .venv 5 | built 6 | tmp 7 | cache 8 | target 9 | tokenizer.json 10 | .hypothesis 11 | perf.data 12 | tokenizer.bin 13 | log.txt 14 | rllm/rllm-cuda/*.safetensor 15 | rllm/rllm-cuda/report*.sqlite 16 | *.ncu-rep 17 | *.nsys-rep 18 | .DS_Store 19 | build/ 20 | dist/ 21 | *.egg-info 22 | .vscode/c_cpp_properties.json 23 | logs 24 | controllers/RustPython/ 25 | *.cpython*.so 26 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "vllm"] 2 | path = py/vllm 3 | url = https://github.com/mmoskal/vllm 4 | branch = hooks 5 | [submodule "tch-cuda/cutlass"] 6 | path = rllm/tch-cuda/cutlass 7 | url = https://github.com/NVIDIA/cutlass.git 8 | [submodule "llama-cpp-low/llama.cpp"] 9 | path = rllm/llama-cpp-low/llama.cpp 10 | url = https://github.com/ggerganov/llama.cpp 11 | [submodule "py/guidance"] 12 | path = py/guidance 13 | url = https://github.com/hudson-ai/guidance 14 | branch = lazy_grammars 15 | [submodule "controllers/toktrie"] 16 | path = controllers/toktrie 17 | url = https://github.com/microsoft/toktrie 18 | [submodule "controllers/derivre"] 19 | path = controllers/derivre 20 | url = https://github.com/microsoft/derivre 21 | [submodule "controllers/llguidance"] 22 | path = controllers/llguidance 23 | url = 
https://github.com/microsoft/llguidance 24 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "vLLM server", 9 | "type": "debugpy", 10 | "request": "launch", 11 | "module": "pyaici.vllm_server", 12 | "env": { 13 | "RUST_LOG": "info,tokenizers=error,aicirt=info", 14 | "RUST_BACKTRACE": "1", 15 | "PYTHONPATH": "${workspaceFolder}/py:${workspaceFolder}/py/vllm" 16 | }, 17 | "args": [ 18 | "--enforce-eager", 19 | "--use-v2-block-manager", 20 | "--enable-chunked-prefill", 21 | "--served-model-name=model", 22 | "--aici-rt", 23 | "./target/release/aicirt", 24 | "-A--wasm-timer-resolution-us=10", 25 | "--model", 26 | "microsoft/Phi-3-mini-128k-instruct", 27 | "--trust-remote-code", 28 | "--port", 29 | "4242", 30 | "--host", 31 | "127.0.0.1", 32 | "--trust-remote-code" 33 | ] 34 | }, 35 | { 36 | "type": "lldb", 37 | "request": "launch", 38 | "name": "rllm-llamacpp phi", 39 | "cwd": "rllm/rllm-llamacpp", 40 | "preLaunchTask": "rllm-llamacpp: build", 41 | "program": "${workspaceFolder}/target/debug/rllm-llamacpp", 42 | "env": { 43 | "RUST_LOG": "info,tokenizers=error,rllm=trace,aicirt=info,llama_cpp_low=trace" 44 | }, 45 | "args": [ 46 | "--verbose", 47 | "--aicirt=${workspaceFolder}/target/release/aicirt", 48 | "--model=https://huggingface.co/TheBloke/phi-2-GGUF/blob/main/phi-2.Q8_0.gguf", 49 | "--gpu-layers=100" 50 | ] 51 | } 52 | ] 53 | } -------------------------------------------------------------------------------- /.vscode/tasks.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "2.0.0", 3 | "tasks": [ 4 | { 5 | "type": "cargo", 6 | "command": "build", 7 | "options": { 8 | "cwd": "${workspaceFolder}/rllm/rllm-llamacpp" 9 | }, 10 | "problemMatcher": [ 11 | "$rustc" 12 | ], 13 | "group": "build", 14 | "label": "rllm-llamacpp: build" 15 | } 16 | ] 17 | } -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | exclude = ["tch-rs"] 3 | members = [ 4 | "aicirt", 5 | "controllers/toktrie/core", 6 | "controllers/toktrie/hf_tokenizers", 7 | "controllers/aici_abi", 8 | "controllers/aici_native", 9 | "controllers/declctrl", 10 | "controllers/pyctrl", 11 | "controllers/jsctrl", 12 | "controllers/llguidance/rust", 13 | "controllers/llguidance/parser", 14 | "controllers/llguidance_ctrl", 15 | "controllers/uppercase", 16 | "controllers/derivre", 17 | "rllm/rllm-base", 18 | "rllm/rllm-cuda", 19 | "rllm/rllm-llamacpp", 20 | "rllm/tch-cuda", 21 | "rllm/llama-cpp-low", 22 | ] 23 | resolver = "2" 24 | 25 | [profile.release] 26 | debug = 1 27 | 28 | [patch.'https://github.com/microsoft/toktrie'] 29 | toktrie = { path = "controllers/toktrie/core" } 30 | 31 | [patch.'https://github.com/microsoft/derivre'] 32 | derivre = { path = "controllers/derivre" } 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 
8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # Support 2 | 3 | ## How to file issues and get help 4 | 5 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 6 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 7 | feature request as a new Issue. 8 | 9 | For other help and questions about using this project, please use GitHub Discussions. 10 | 11 | ## Microsoft Support Policy 12 | 13 | Support for this the AI Controller Interface (AICI) is limited to the resources listed above. 
14 | -------------------------------------------------------------------------------- /aici.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ "X$AICI_API_BASE" = "X" ] ; then 4 | export AICI_API_BASE="http://127.0.0.1:4242/v1/" 5 | fi 6 | 7 | PYTHONPATH=`dirname $0`/py \ 8 | python3 -m pyaici.cli "$@" 9 | -------------------------------------------------------------------------------- /aicirt/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "aicirt" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | aici_abi = { path = "../controllers/aici_abi" } 10 | aici_native = { path = "../controllers/aici_native" } 11 | anyhow = "1.0.75" 12 | base64 = "0.21.4" 13 | clap = { version = "4.4.4", features = ["derive"] } 14 | hex = "0.4.3" 15 | libc = "0.2.148" 16 | log = "0.4.20" 17 | rayon = "1.7.0" 18 | serde = { version = "1.0.192", features = ["derive"] } 19 | serde_json = { version = "1.0.108", features = ["preserve_order"] } 20 | sha2 = "0.10.7" 21 | wasmtime = { version = "16.0.0", default-features = false, features = ["cranelift", "parallel-compilation", "pooling-allocator"] } 22 | tokenizers = { version = "0.15.0", features = ["http"] } 23 | thread-priority = "0.15.1" 24 | cap = "0.1.2" 25 | bincode = "1.3.3" 26 | uuid = { version = "1.6.1", features = ["v4"] } 27 | regex = "1.10.3" 28 | ureq = "2.9.5" 29 | 30 | [target.'cfg(target_os = "linux")'.dependencies] 31 | linux-futex = "0.2.0" 32 | 33 | [target.'cfg(target_os = "macos")'.dependencies] 34 | ulock-sys = "0.1.0" 35 | 36 | -------------------------------------------------------------------------------- /aicirt/README.md: -------------------------------------------------------------------------------- 1 | # AICI Runtime (aicirt) 2 | 3 | Multi-threaded runner for AICI Controllers, built on top of [Wasmtime](https://wasmtime.dev/). 4 | 5 | ```mermaid 6 | graph TD 7 | User1 <-- HTTP --> LLM 8 | User2 <-- HTTP --> LLM 9 | UserN <-- HTTP --> LLM["LLM Server
(batching)"] 10 | LLM <-- CUDA/pytorch --> GPU 11 | LLM <-- POSIX SHM --> aicirt[AICI-runtime] 12 | aicirt <-- Sockets+SHM --> Worker1[Worker1
Running Wasm] 13 | aicirt <-- Sockets+SHM --> Worker2[Worker2
Running Wasm] 14 | aicirt <-- Sockets+SHM --> WorkerM[WorkerM
Running Wasm] 15 | ``` 16 | 17 | ```mermaid 18 | sequenceDiagram 19 | actor User 20 | participant GPU 21 | participant LLM 22 | participant aicirt as AICI-runtime 23 | LLM -->> GPU: Model 24 | User -->> LLM: Request (Prompt + Wasm) 25 | LLM -->>+ aicirt: Prompt + Wasm 26 | aicirt -->>- LLM: logit bias 1 27 | LLM -->>+ GPU: Prompt 28 | LLM -->> GPU: logit bias 1 29 | GPU -->> LLM: token 1 30 | LLM -->>+ aicirt: token 1 31 | LLM -->> User: token 1 32 | aicirt -->>- LLM: logit bias 2 33 | LLM -->> GPU: logit bias 2 34 | GPU -->>- LLM: token 2 35 | LLM -->> User: token 2 36 | ``` 37 | 38 | Below is process structure. 39 | 40 | - dotted arrow from A to B indicates that A sends requests to B (and gets responses) 41 | - solid arrow from A to B indicates that A spawns (forks) B 42 | - `spawner` is a special process, forked from `aicirt` at the beginning; 43 | for every user requests it spawns a process for top-level controller and a `common state` process 44 | for handling shared state between 45 | all controller instances for that request (they can talk to the `common state` process) 46 | - the top-level constraint can spawn more constraints, which can spawn yet more; 47 | `aicirt` has a direct connection to all these constraints though 48 | 49 | ```mermaid 50 | graph TD 51 | LLM ---> aicirt[AICI-runtime] 52 | LLM -..-> aicirt 53 | aicirt -..-> spawner 54 | aicirt -..-> A0((A0)) 55 | aicirt -..-> A1((A1)) 56 | aicirt -..-> A2((A2)) 57 | aicirt -..-> A3((A3)) 58 | aicirt -..-> A4((A4)) 59 | aicirt ---> spawner 60 | spawner --> A0 61 | spawner --> CommsA[Common state for A] 62 | subgraph User Request A 63 | A0 --> A1 64 | A0 --> A2 65 | A2 --> A3 66 | A2 --> A4 67 | A0 -..-> CommsA 68 | A1 -..-> CommsA 69 | A2 -..-> CommsA 70 | A3 -..-> CommsA 71 | A4 -..-> CommsA 72 | end 73 | aicirt -..-> B0((B0)) 74 | aicirt -..-> B1((B1)) 75 | spawner --> B0 76 | spawner --> CommsB[Common state for B] 77 | subgraph User Request B 78 | B0 -..-> CommsB 79 | B1 -..-> CommsB 80 | B0 --> B1 81 | end 82 | ``` -------------------------------------------------------------------------------- /aicirt/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod api; 2 | mod bench; 3 | pub mod futexshm; 4 | pub mod msgchannel; 5 | pub mod semaphore; 6 | pub mod shm; 7 | 8 | pub use aici_native::*; 9 | 10 | #[cfg(target_os = "macos")] 11 | mod macos; 12 | 13 | pub use bench::*; 14 | use thread_priority::{ 15 | set_thread_priority_and_policy, thread_native_id, RealtimeThreadSchedulePolicy, ThreadPriority, 16 | ThreadSchedulePolicy, 17 | }; 18 | 19 | pub fn get_unix_time() -> u64 { 20 | std::time::SystemTime::now() 21 | .duration_since(std::time::UNIX_EPOCH) 22 | .unwrap() 23 | .as_secs() 24 | } 25 | 26 | /// An error thrown from the WASM runtime or otherwise originating from user error 27 | /// - should not generate additional stacktraces from where it's caught. 
28 | #[derive(Debug)] 29 | pub struct UserError { 30 | pub msg: String, 31 | } 32 | 33 | impl UserError { 34 | pub fn new(msg: String) -> Self { 35 | Self { msg } 36 | } 37 | 38 | pub fn anyhow(msg: String) -> anyhow::Error { 39 | anyhow::anyhow!(Self::new(msg)) 40 | } 41 | 42 | pub fn is_self(e: &anyhow::Error) -> bool { 43 | e.downcast_ref::().is_some() 44 | } 45 | 46 | pub fn maybe_stacktrace(e: &anyhow::Error) -> String { 47 | if let Some(e) = e.downcast_ref::() { 48 | format!("{}", e) 49 | } else { 50 | format!("{:?}", e) 51 | } 52 | } 53 | } 54 | 55 | impl std::fmt::Display for UserError { 56 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 57 | write!(f, "{}", self.msg) 58 | } 59 | } 60 | 61 | impl std::error::Error for UserError {} 62 | 63 | #[macro_export] 64 | macro_rules! user_error { 65 | ($($tt:tt)*) => { 66 | $crate::UserError::anyhow(format!($($tt)*)) 67 | }; 68 | } 69 | 70 | #[macro_export] 71 | macro_rules! bail_user { 72 | ($($tt:tt)*) => { 73 | return Err($crate::UserError::anyhow(format!($($tt)*))) 74 | }; 75 | } 76 | 77 | #[macro_export] 78 | macro_rules! ensure_user { 79 | ($cond:expr, $($tt:tt)*) => { 80 | if !$cond { 81 | return Err($crate::UserError::anyhow(format!($($tt)*))) 82 | } 83 | }; 84 | } 85 | 86 | pub fn is_hex_string(s: &str) -> bool { 87 | s.chars().all(|c| c.is_digit(16)) 88 | } 89 | 90 | pub fn valid_module_or_tag(s: &str) -> bool { 91 | valid_module_id(s) || valid_tagname(s) 92 | } 93 | 94 | pub fn valid_module_id(s: &str) -> bool { 95 | s.len() == 64 && is_hex_string(s) 96 | } 97 | 98 | pub fn valid_tagname(s: &str) -> bool { 99 | match s.chars().next() { 100 | Some(c) if c.is_alphabetic() => { 101 | !valid_module_id(s) 102 | && s.chars().all(|c| { 103 | c == '_' || c == '-' || c == '.' || c.is_digit(10) || c.is_alphabetic() 104 | }) 105 | } 106 | _ => false, 107 | } 108 | } 109 | 110 | fn set_priority(pri: ThreadPriority) { 111 | // this fails on WSL 112 | let _ = set_thread_priority_and_policy( 113 | thread_native_id(), 114 | pri, 115 | ThreadSchedulePolicy::Realtime(RealtimeThreadSchedulePolicy::Fifo), 116 | ); 117 | } 118 | 119 | pub fn set_max_priority() { 120 | set_priority(ThreadPriority::Max); 121 | } 122 | 123 | pub fn set_min_priority() { 124 | set_priority(ThreadPriority::Min); 125 | } 126 | -------------------------------------------------------------------------------- /aicirt/src/macos.rs: -------------------------------------------------------------------------------- 1 | use anyhow::{anyhow, Result}; 2 | use std::{sync::atomic::AtomicU32, time::Duration}; 3 | use ulock_sys::{ 4 | __ulock_wait, __ulock_wake, darwin19::UL_COMPARE_AND_WAIT_SHARED, ULF_NO_ERRNO, ULF_WAKE_ALL, 5 | }; 6 | 7 | pub trait AsFutex { 8 | fn as_futex(&self) -> &Futex; 9 | } 10 | 11 | impl AsFutex for AtomicU32 { 12 | #[must_use] 13 | #[inline] 14 | fn as_futex(&self) -> &Futex { 15 | unsafe { std::mem::transmute(self) } 16 | } 17 | } 18 | 19 | #[repr(transparent)] 20 | pub struct Futex { 21 | pub value: AtomicU32, 22 | } 23 | 24 | impl Futex { 25 | fn wait_core(&self, expected_value: u32, micros: u32) -> Result<()> { 26 | let r = unsafe { 27 | __ulock_wait( 28 | UL_COMPARE_AND_WAIT_SHARED | ULF_NO_ERRNO, 29 | self.value.as_ptr() as *mut libc::c_void, 30 | expected_value as u64, 31 | micros, 32 | ) 33 | }; 34 | 35 | if r >= 0 { 36 | Ok(()) 37 | } else { 38 | // TODO: can copy errors from https://github.com/ziglang/zig/blob/9e684e8d1af39904055abe64a9afda69a3d44a59/lib/std/Thread/Futex.zig#L192 39 | Err(anyhow!("__ulock_wait failed: {}", r)) 
40 | } 41 | } 42 | 43 | /// Wait until this futex is awoken by a `wake` call. 44 | /// The thread will only be sent to sleep if the futex's value matches the 45 | /// expected value. 46 | pub fn wait(&self, expected_value: u32) -> Result<()> { 47 | self.wait_core(expected_value, 0) 48 | } 49 | 50 | /// Wait until this futex is awoken by a `wake` call, or until the timeout expires. 51 | /// The thread will only be sent to sleep if the futex's value matches the 52 | /// expected value. 53 | pub fn wait_for(&self, expected_value: u32, timeout: Duration) -> Result<()> { 54 | if timeout >= Duration::from_micros(u32::MAX as u64) { 55 | self.wait_core(expected_value, 0) 56 | } else { 57 | self.wait_core(expected_value, timeout.as_micros() as u32) 58 | } 59 | } 60 | 61 | /// Wake up `n` waiters. 62 | pub fn wake(&self, _n: i32) -> i32 { 63 | loop { 64 | let r = unsafe { 65 | __ulock_wake( 66 | UL_COMPARE_AND_WAIT_SHARED | ULF_NO_ERRNO | ULF_WAKE_ALL, 67 | self.value.as_ptr() as *mut libc::c_void, 68 | 0, 69 | ) 70 | }; 71 | if r == -libc::ENOENT { 72 | return 0; 73 | } 74 | if r >= 0 { 75 | return 1; 76 | } 77 | } 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /aicirt/src/msgchannel.rs: -------------------------------------------------------------------------------- 1 | use crate::{semaphore::Semaphore, shm::{Shm, Unlink}}; 2 | use anyhow::Result; 3 | use std::time::Duration; 4 | 5 | pub struct MessageChannel { 6 | shm: Shm, 7 | write_sem: Semaphore, 8 | read_sem: Semaphore, 9 | } 10 | 11 | unsafe impl Send for MessageChannel {} 12 | 13 | impl MessageChannel { 14 | pub fn shm_name(name: &str) -> String { 15 | format!("{0}-shm", name) 16 | } 17 | 18 | pub fn new_cmd(name: &str, size: usize) -> Result { 19 | Self::new_ex(name, size, true) 20 | } 21 | 22 | pub fn new(name: &str, size: usize) -> Result { 23 | Self::new_ex(name, size, false) 24 | } 25 | 26 | fn new_ex(name: &str, size: usize, unlink: bool) -> Result { 27 | log::debug!("msg ch: {} size={}k", name, size / 1024); 28 | 29 | let shm = Shm::new( 30 | &Self::shm_name(name), 31 | size, 32 | if unlink { Unlink::Pre } else { Unlink::None }, 33 | )?; 34 | let write_sem = Semaphore::new(&format!("{0}-wr", name), 1, unlink)?; 35 | let read_sem = Semaphore::new(&format!("{0}-rd", name), 0, unlink)?; 36 | 37 | Ok(Self { 38 | shm, 39 | write_sem, 40 | read_sem, 41 | }) 42 | } 43 | 44 | pub fn send(&self, msg: &[u8]) -> Result<()> { 45 | self.shm.fits_msg(msg)?; 46 | self.write_sem.wait()?; 47 | self.shm.write_msg(msg).unwrap(); 48 | self.read_sem.post()?; 49 | Ok(()) 50 | } 51 | 52 | #[allow(dead_code)] 53 | pub fn busy_reset(&self) { 54 | unsafe { std::ptr::write_volatile(self.shm.ptr_at(0) as *mut u32, 0) }; 55 | } 56 | 57 | #[allow(dead_code)] 58 | pub fn busy_send(&self, msg: &[u8]) -> Result<()> { 59 | self.shm.fits_msg(msg)?; 60 | loop { 61 | let len = self.shm.read_len()?; 62 | if len != 0 { 63 | std::hint::spin_loop(); 64 | continue; 65 | } 66 | return Ok(self.shm.write_msg(msg).unwrap()); 67 | } 68 | } 69 | 70 | #[allow(dead_code)] 71 | pub fn busy_recv(&self) -> Result> { 72 | loop { 73 | let len = self.shm.read_len()?; 74 | if len == 0 { 75 | std::hint::spin_loop(); 76 | continue; 77 | } 78 | let res = self.shm.read_msg(); 79 | return res; 80 | } 81 | } 82 | 83 | pub fn recv(&self, busy_wait_duration: &Duration) -> Result> { 84 | self.read_sem.busy_wait(busy_wait_duration)?; 85 | let res = self.shm.read_msg(); 86 | self.write_sem.post()?; 87 | res 88 | } 89 | } 90 | 
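A minimal usage sketch for the channel above — hypothetical, not taken from the repository. The channel name `"aici0-cmd"`, the 1 MiB size, and the 200 ms busy-wait budget are arbitrary choices, and it assumes `aicirt` is linked as a library crate and that `recv` returns the message bytes as a `Vec<u8>`.

```rust
use std::time::Duration;

use aicirt::msgchannel::MessageChannel;
use anyhow::Result;

fn ping_example() -> Result<()> {
    // "Server" end: creates the shared memory and semaphores, unlinking any
    // stale objects left from a previous run (new_cmd uses Unlink::Pre).
    let server = MessageChannel::new_cmd("aici0-cmd", 1 << 20)?;
    // "Client" end: attaches to the same POSIX objects by name.
    let client = MessageChannel::new("aici0-cmd", 1 << 20)?;

    // send() waits on the write semaphore, copies the message into shared
    // memory, then posts the read semaphore.
    client.send(br#"{"op":"ping"}"#)?;

    // recv() spins on the read semaphore for the given duration before
    // falling back to a blocking wait, then re-posts the write semaphore
    // so the next sender can proceed.
    let msg = server.recv(&Duration::from_millis(200))?;
    assert_eq!(msg, br#"{"op":"ping"}"#.to_vec());
    Ok(())
}
```

The semaphore pair serializes a single in-flight message, so a request/response protocol would typically use one channel per direction.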
-------------------------------------------------------------------------------- /aicirt/src/semaphore.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use std::{ 3 | ffi::CString, 4 | io, 5 | time::{Duration, Instant}, 6 | }; 7 | 8 | pub struct Semaphore { 9 | sem: *mut libc::sem_t, 10 | } 11 | 12 | impl Semaphore { 13 | fn last_error() -> Result { 14 | Err(io::Error::last_os_error().into()) 15 | } 16 | 17 | pub fn new(name: &str, initial_value: u32, unlink: bool) -> Result { 18 | log::trace!("sem_open: {}", name); 19 | let c_name = CString::new(name).unwrap(); 20 | if unlink { 21 | unsafe { 22 | libc::sem_unlink(c_name.as_ptr()); 23 | }; 24 | } 25 | let sem = unsafe { libc::sem_open(c_name.as_ptr(), libc::O_CREAT, 0o666, initial_value) }; 26 | 27 | if sem.is_null() { 28 | return Self::last_error(); 29 | } 30 | 31 | Ok(Self { sem }) 32 | } 33 | 34 | pub fn wait(&self) -> Result<()> { 35 | let ret = unsafe { libc::sem_wait(self.sem) }; 36 | if ret < 0 { 37 | return Self::last_error(); 38 | } 39 | Ok(()) 40 | } 41 | 42 | pub fn busy_wait(&self, wait_duration: &Duration) -> Result<()> { 43 | let deadline = Instant::now() + *wait_duration; 44 | loop { 45 | let ret = unsafe { libc::sem_trywait(self.sem) }; 46 | if ret < 0 { 47 | #[cfg(target_os = "linux")] 48 | let last_error = unsafe { *libc::__errno_location() }; 49 | #[cfg(not(target_os = "linux"))] 50 | let last_error = unsafe { *libc::__error() }; 51 | if last_error == libc::EAGAIN { 52 | if Instant::now() > deadline { 53 | return self.wait(); 54 | } else { 55 | // std::hint::spin_loop(); 56 | continue; 57 | } 58 | } else { 59 | return Self::last_error(); 60 | } 61 | } else { 62 | return Ok(()); 63 | } 64 | } 65 | } 66 | 67 | pub fn post(&self) -> Result<()> { 68 | let ret = unsafe { libc::sem_post(self.sem) }; 69 | if ret < 0 { 70 | return Self::last_error(); 71 | } 72 | Ok(()) 73 | } 74 | } 75 | 76 | impl Drop for Semaphore { 77 | fn drop(&mut self) { 78 | unsafe { 79 | libc::sem_close(self.sem); 80 | } 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /aicirt/vllm.md: -------------------------------------------------------------------------------- 1 | # Notes on integration with vLLM 2 | 3 | Following are callbacks in the vLLM flow: 4 | 5 | ```python 6 | LLMEngine.step() 7 | Scheduler.schedule() 8 | scheduler_outputs = Scheduler._schedule() 9 | SamplingParams.initiate_step(scheduler, llm_engine.counter, scheduler_outputs) 10 | return SequenceGroupMetadata(scheduler_outputs) 11 | samples = LLMEngine._run_workers("execute_model") 12 | Worker.execute_model 13 | Worker._prepare_inputs 14 | # ... 15 | SamplingParams.recv_attention_mask() 16 | # ... 17 | self.model() 18 | # ... 19 | Sampler.forward() 20 | logits = ... 21 | SamplingParams.apply_dynamic_logit_bias(logits) 22 | return _sample(logits) : SequenceOutputs 23 | return LLMEngine._process_model_outputs(samples) 24 | LLMEngine._process_sequence_group_samples() 25 | SamplingParams.append_ff_tokens(seq_group) 26 | # free and fork sequences as needed 27 | SamplingParams.finish_sampling() 28 | json_output = ... 
29 | return json_output 30 | ``` 31 | 32 | Thoughts: 33 | - expose Scheduler._schedule() and call it from LLMEngine; move initiate_step to LLMEngine 34 | - return logits from Sampler.forward() and call _sample() from LLMEngine; move apply_dynamic_logit_bias to LLMEngine 35 | - pass attn_mask to execute model from LLMEngine 36 | 37 | - vllm forks sequences in _process_sequence_group_samples(); this means fork processing in AICI is done 38 | in pre_process(), not process(), so it blocks; in full AICI env you would only fork from AICI module not 39 | n= parameter -------------------------------------------------------------------------------- /controllers/aici_abi/.cargo/config.toml: -------------------------------------------------------------------------------- 1 | [build] 2 | target = "wasm32-wasi" 3 | 4 | [profile.dev] 5 | strip = "debuginfo" 6 | 7 | [profile.release] 8 | strip = "debuginfo" 9 | -------------------------------------------------------------------------------- /controllers/aici_abi/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "aici_abi" 3 | version = "0.1.0" 4 | edition = "2021" 5 | rust-version = "1.75.0" 6 | 7 | [lib] 8 | name = "aici_abi" 9 | 10 | [dependencies] 11 | toktrie = { path = "../toktrie/core" } 12 | serde = { version = "1.0.192", features = ["derive"] } 13 | serde_json = "1.0.108" 14 | anyhow = "1.0.75" 15 | regex-automata = { version = "0.4.6", default-features = false, features = ["std", "dfa", "syntax", "perf", "meta"], optional = true } 16 | cfgrammar = { version = "0.13.3", optional = true } 17 | lrtable = { version = "0.13.3", optional = true } 18 | vob = { version = "3.0.3", optional = true } 19 | rustc-hash = { version = "2.0.0", optional = true } 20 | bytemuck = "1.16.0" 21 | bytemuck_derive = "1.6.0" 22 | 23 | [features] 24 | default = ["cfg", "rx"] 25 | cfg = ["dep:cfgrammar", "dep:lrtable", "dep:vob", "dep:rustc-hash"] 26 | rx = ["dep:regex-automata"] 27 | 28 | [[bin]] 29 | name = "yesno" 30 | path = "src/yesno.rs" 31 | -------------------------------------------------------------------------------- /controllers/aici_abi/grammars/json0.guidance: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/controllers/aici_abi/grammars/json0.guidance -------------------------------------------------------------------------------- /controllers/aici_abi/implementation.md: -------------------------------------------------------------------------------- 1 | # Implementation notes 2 | 3 | ## LR(1) parsing 4 | 5 | The LR(1) parsing consists of DFA-based lexer and the actual LR(1) parser. 6 | DFA has a single number as the state, while the state of the LR(1) is a stack of numbers. 7 | The LR(1) action is determined based on the next token from the lexer and the top of the stack. 8 | 9 | The `Recognizer` interface also has a concept of stack, however every entry on that 10 | stack contains a DFA state and an LR(1) stack. 11 | 12 | Most of the time (~98.5% for the C grammar), pushing a byte involves only updating the DFA state, 13 | while the LR(1) stack is copied unchanged (the memory is shared). 14 | 15 | 16 | ### Early error detection 17 | 18 | Consider the following invalid C program: 19 | 20 | ```c 21 | int 123456; 22 | ``` 23 | 24 | The lexer would produce `int` keyword, whitespace, `123456` constant and `;` keyword. 
25 | The parser would reject `123456`, however only after all six characters of it have been read. 26 | This is too late for the LLM. 27 | 28 | To detect such errors early, we compute a set of reachable tokens for each DFA state. 29 | For example, consider a DFA that recognizes `int`, `if`, `ID` (`/[a-z][a-z0-9]*/`) and `INTLIT` (`/[0-9]+/`). 30 | The initial DFA state has a full set of tokens, while a state after `'i'` 31 | has only `int`, `if`, and `ID`, 32 | and a state after `'1'` includes only `INTLIT`. 33 | In the picture below, each state is labelled by its reachable set, 34 | and the token for which it is a match (if any) is postfixed with `*`. We only use lower-case letters and digits for simplicity. 35 | 36 | ```mermaid 37 | graph LR 38 | 0["{int,if,ID,INTLIT}"] -- "[i]" --> i(("{int,if,ID*}")) 39 | 0 -- "[a-z] - [i]" --> id(("{ID*}")) 40 | 0 -- "[0-9]" --> const(("{INTLIT*}")) 41 | const -- "[0-9]" --> const 42 | const -- "[a-z]" --> bot["{}"] 43 | i -- "[a-z0-9] - [nf]" --> id 44 | id -- "[a-z0-9]" --> id 45 | i -- "[n]" --> in(("{int,ID*}")) 46 | in -- "[t]" --> int(("{int*,ID}")) 47 | in -- "[a-z0-9] - [t]" --> id 48 | int -- "[a-z0-9]" --> id 49 | i -- "[f]" --> if(("{if*,ID}")) 50 | if -- "[a-z0-9]" --> id 51 | ``` 52 | 53 | For each LR(1) automaton state we compute a set of viable tokens, i.e., ones that do 54 | not immediately lead to an error. 55 | 56 | While parsing input, if the intersection of viable and reachable tokens is empty, we report an error. 57 | 58 | In the example above, the viable tokens after `int` do not include `INTLIT`, 59 | and thus the parser fails immediately at `1`. 60 | 61 | -------------------------------------------------------------------------------- /controllers/aici_abi/src/yesno.rs: -------------------------------------------------------------------------------- 1 | use aici_abi::{host_trie, tokenize, toktrie::TokTrie, AiciCtrl, MidProcessArg, MidProcessResult, TokenId}; 2 | 3 | pub struct Runner { 4 | toktrie: TokTrie, 5 | tokens: Vec, 6 | yes: TokenId, 7 | no: TokenId, 8 | } 9 | 10 | impl Runner { 11 | pub fn new() -> Self { 12 | let yes = tokenize("Yes")[0]; 13 | let no = tokenize("No")[0]; 14 | // ignore user-passed arg 15 | Runner { 16 | toktrie: host_trie(), 17 | tokens: Vec::new(), 18 | yes, 19 | no, 20 | } 21 | } 22 | } 23 | 24 | impl AiciCtrl for Runner { 25 | fn mid_process(&mut self, arg: MidProcessArg) -> MidProcessResult { 26 | arg.save_tokens(&mut self.tokens); 27 | if self.tokens.len() >= 1 { 28 | // we only want the first token 29 | MidProcessResult::stop() 30 | } else { 31 | let mut set = self.toktrie.alloc_token_set(); 32 | set.allow_token(self.yes); 33 | set.allow_token(self.no); 34 | MidProcessResult::sample(set) 35 | } 36 | } 37 | } 38 | 39 | fn main() { 40 | // test code here? 
41 | } 42 | 43 | aici_abi::aici_expose_all!(Runner, Runner::new()); 44 | -------------------------------------------------------------------------------- /controllers/aici_native/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "aici_native" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [lib] 7 | name = "aici_native" 8 | 9 | [dependencies] 10 | aici_abi = { path = "../aici_abi" } 11 | toktrie_hf_tokenizers = { path = "../toktrie/hf_tokenizers" } 12 | serde = { version = "1.0.192", features = ["derive"] } 13 | serde_json = "1.0.108" 14 | anyhow = "1.0.75" 15 | rustc-hash = "2.0.0" 16 | tokenizers = { version = "0.15.0", features = ["http"] } 17 | log = "0.4.21" 18 | flexi_logger = "0.28.0" 19 | -------------------------------------------------------------------------------- /controllers/aici_native/README.md: -------------------------------------------------------------------------------- 1 | # AICI native 2 | 3 | Utilities for building native (non-Wasm) AICI Controllers. 4 | -------------------------------------------------------------------------------- /controllers/aici_native/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod bintokens; 2 | mod log; 3 | pub mod variables; 4 | 5 | pub use log::*; 6 | 7 | pub use rustc_hash::FxHashMap as HashMap; 8 | pub use rustc_hash::FxHashSet as HashSet; 9 | -------------------------------------------------------------------------------- /controllers/aici_native/src/log.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Write; 2 | 3 | use anyhow::Result; 4 | use flexi_logger::style; 5 | use flexi_logger::{DeferredNow, Logger, WriteMode}; 6 | use log::Record; 7 | 8 | pub enum LogMode { 9 | Normal, 10 | Test, 11 | Daemon, 12 | } 13 | 14 | struct LimitedWrite { 15 | limit: usize, 16 | dst: Vec, 17 | } 18 | 19 | impl Write for LimitedWrite { 20 | fn write_str(&mut self, s: &str) -> std::fmt::Result { 21 | if self.dst.len() > self.limit { 22 | return Err(std::fmt::Error); 23 | } 24 | if self.dst.len() + s.len() < self.limit { 25 | self.dst.extend_from_slice(s.as_bytes()); 26 | Ok(()) 27 | } else { 28 | let remaining = self.limit - self.dst.len(); 29 | self.dst.extend_from_slice(&s.as_bytes()[..remaining]); 30 | self.dst.extend_from_slice(b" (...)"); 31 | Err(std::fmt::Error) 32 | } 33 | } 34 | } 35 | 36 | fn args_to_str(limit: usize, args: &std::fmt::Arguments) -> String { 37 | // let capacity = args.estimated_capacity(); 38 | let mut output = LimitedWrite { 39 | limit, 40 | dst: Vec::with_capacity(128), 41 | }; 42 | if output.write_fmt(*args).is_err() { 43 | assert!(output.dst.len() > limit); 44 | } 45 | match String::from_utf8(output.dst) { 46 | Ok(s) => s, 47 | Err(err) => String::from_utf8_lossy(err.as_bytes()).to_string(), 48 | } 49 | } 50 | 51 | fn truncated_format( 52 | w: &mut dyn std::io::Write, 53 | _now: &mut DeferredNow, 54 | record: &Record, 55 | ) -> Result<(), std::io::Error> { 56 | let level = record.level(); 57 | write!( 58 | w, 59 | "{} [{}] {}", 60 | style(level).paint(level.to_string()), 61 | record.module_path().unwrap_or(""), 62 | style(level).paint(args_to_str(1000, record.args())) 63 | ) 64 | } 65 | 66 | fn daemon_format( 67 | w: &mut dyn std::io::Write, 68 | now: &mut DeferredNow, 69 | record: &Record, 70 | ) -> Result<(), std::io::Error> { 71 | write!( 72 | w, 73 | "{} {} [{}] {}", 74 | now.format("%Y-%m-%d %H:%M:%S%.3f"), 75 | record.level(), 76 | 
record.module_path().unwrap_or(""), 77 | args_to_str(5000, record.args()) 78 | ) 79 | } 80 | 81 | pub fn init_log(mode: LogMode) -> Result<()> { 82 | let logger = match mode { 83 | LogMode::Normal => Logger::try_with_env_or_str("info")? 84 | .format(truncated_format) 85 | .log_to_stdout(), 86 | LogMode::Test => { 87 | Logger::try_with_env_or_str("debug")?.write_mode(WriteMode::SupportCapture) 88 | } 89 | LogMode::Daemon => Logger::try_with_env_or_str("info")? 90 | .format(daemon_format) 91 | .log_to_stdout(), 92 | }; 93 | 94 | logger.start()?; 95 | Ok(()) 96 | } 97 | 98 | pub fn setup_log() { 99 | init_log(LogMode::Normal).expect("Failed to initialize log") 100 | } 101 | -------------------------------------------------------------------------------- /controllers/aici_native/src/variables.rs: -------------------------------------------------------------------------------- 1 | use aici_abi::{StorageCmd, StorageOp, StorageResp}; 2 | use rustc_hash::FxHashMap; 3 | 4 | #[derive(Default)] 5 | pub struct Variables { 6 | pub variables: FxHashMap)>, 7 | } 8 | 9 | impl Variables { 10 | pub fn process_cmd(&mut self, cmd: StorageCmd) -> StorageResp { 11 | match cmd { 12 | StorageCmd::ReadVar { name } => match self.variables.get(&name).map(|x| x.clone()) { 13 | None => StorageResp::VariableMissing {}, 14 | Some((version, value)) => StorageResp::ReadVar { value, version }, 15 | }, 16 | StorageCmd::WriteVar { 17 | name, 18 | value, 19 | when_version_is, 20 | op, 21 | } => { 22 | let curr = self.variables.get(&name).map(|x| x.clone()); 23 | match curr { 24 | Some((prev_version, prev_val)) => match when_version_is { 25 | Some(v) if v != prev_version => StorageResp::ReadVar { 26 | version: prev_version, 27 | value: prev_val, 28 | }, 29 | _ => { 30 | let value = match op { 31 | StorageOp::Append => { 32 | let mut v = prev_val.clone(); 33 | v.extend(value); 34 | v 35 | } 36 | StorageOp::Set => value, 37 | }; 38 | let version = prev_version + 1; 39 | self.variables.insert(name, (version, value)); 40 | StorageResp::WriteVar { version } 41 | } 42 | }, 43 | 44 | None => match when_version_is { 45 | None => { 46 | self.variables.insert(name, (1, value)); 47 | StorageResp::WriteVar { version: 1 } 48 | } 49 | Some(_) => StorageResp::VariableMissing {}, 50 | }, 51 | } 52 | } 53 | } 54 | } 55 | } 56 | 57 | -------------------------------------------------------------------------------- /controllers/declctrl/.cargo/config.toml: -------------------------------------------------------------------------------- 1 | [build] 2 | target = "wasm32-wasi" 3 | 4 | [profile.dev] 5 | strip = "debuginfo" 6 | 7 | [profile.release] 8 | strip = "debuginfo" 9 | -------------------------------------------------------------------------------- /controllers/declctrl/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "aici_declctrl" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | aici_abi = { path = "../aici_abi" } 8 | regex-automata = { version = "0.3.8", default-features = false, features = ["std", "dfa", "syntax", "perf", "meta"] } 9 | serde = { version = "1.0.192", features = ["derive"] } 10 | serde_json = "1.0.108" 11 | anyhow = "1.0.75" 12 | 13 | [[bin]] 14 | name = "aici_declctrl" 15 | path = "src/declctrl.rs" 16 | -------------------------------------------------------------------------------- /controllers/declctrl/README.md: -------------------------------------------------------------------------------- 1 | # DeclCtrl 2 | 3 | The 
[DeclCtrl](src/declctrl.rs) exposes similar constraints 4 | to [PyCtrl](../pyctrl), but the glueing is done via a JSON AST (Abstract Syntax Tree) and thus is 5 | more restrictive. 6 | 7 | There is no reason to use it as is, but it can be used as a base for other controller. 8 | 9 | -------------------------------------------------------------------------------- /controllers/declctrl/arg2.json: -------------------------------------------------------------------------------- 1 | { 2 | "steps": [ 3 | { 4 | "Fixed": { 5 | "text": { 6 | "String": { 7 | "str": "The word 'hello'" 8 | } 9 | } 10 | } 11 | }, 12 | { 13 | "Fixed": { 14 | "text": { 15 | "String": { 16 | "str": " in French is translated as" 17 | } 18 | }, 19 | "label": "lang" 20 | } 21 | }, 22 | { 23 | "Gen": { 24 | "rx": " '[^']*'", 25 | "inner": [], 26 | "max_tokens": 15, 27 | "stmts": [ 28 | { 29 | "Set": { 30 | "var": "french", 31 | "expr": { 32 | "Current": {} 33 | } 34 | } 35 | } 36 | ] 37 | } 38 | }, 39 | { 40 | "Fixed": { 41 | "text": { 42 | "String": { 43 | "str": " or" 44 | } 45 | }, 46 | "following": "lang" 47 | } 48 | }, 49 | { 50 | "Gen": { 51 | "rx": " '[^']*'", 52 | "inner": [], 53 | "max_tokens": 15, 54 | "stmts": [ 55 | { 56 | "Set": { 57 | "var": "blah", 58 | "expr": { 59 | "Current": {} 60 | } 61 | } 62 | } 63 | ] 64 | } 65 | }, 66 | { 67 | "Fixed": { 68 | "text": { 69 | "Concat": { 70 | "parts": [ 71 | { 72 | "String": { 73 | "str": "\nResults: " 74 | } 75 | }, 76 | { 77 | "Var": { 78 | "var": "french" 79 | } 80 | }, 81 | { 82 | "String": { 83 | "str": " " 84 | } 85 | }, 86 | { 87 | "Var": { 88 | "var": "blah" 89 | } 90 | } 91 | ] 92 | } 93 | } 94 | } 95 | } 96 | ] 97 | } -------------------------------------------------------------------------------- /controllers/declctrl/genarg.py: -------------------------------------------------------------------------------- 1 | # run as: 2 | # PYTHONPATH=.. 
python genarg.py 3 | import ujson 4 | import pyaici._ast as ast 5 | 6 | def main(): 7 | aa = { 8 | "steps": [ast.fixed("Here's some JSON about J.R.Hacker from Seattle:\n")] 9 | + ast.json_to_steps( 10 | { 11 | "name": "", 12 | "valid": True, 13 | "description": "", 14 | "type": "foo|bar|baz|something|else", 15 | "address": {"street": "", "city": "", "state": "[A-Z][A-Z]"}, 16 | "age": 1, 17 | "fraction": 1.5, 18 | } 19 | ) 20 | } 21 | 22 | aa = { 23 | "steps": [ 24 | ast.fixed("The word 'hello'"), 25 | ast.label("lang", ast.fixed(" in French is translated as")), 26 | ast.gen(rx=r" '[^']*'", max_tokens=15, set_var="french"), 27 | ast.fixed(" or", following="lang"), 28 | ast.gen(rx=r" '[^']*'", max_tokens=15, set_var="blah"), 29 | ast.fixed("\nResults: {{french}} {{blah}}", expand_vars=True), 30 | ] 31 | } 32 | 33 | ast.clear_none(aa) 34 | print(ujson.dumps(aa)) 35 | 36 | main() -------------------------------------------------------------------------------- /controllers/declctrl/native.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | T=llama 4 | T=gpt4 5 | 6 | set -x 7 | set -e 8 | if test -f tokenizer.bin ; then 9 | echo "Skipping tokenizer" 10 | else 11 | (cd ../aicirt && cargo run --release -- --tokenizer $T --save-tokenizer ../declctrl/tokenizer.bin) 12 | fi 13 | cargo build --release 14 | if [ `uname` = Linux ] ; then 15 | perf stat ./target/release/declctrl 16 | else 17 | ./target/release/declctrl 18 | fi 19 | -------------------------------------------------------------------------------- /controllers/declctrl/size.js: -------------------------------------------------------------------------------- 1 | const fs = require("fs"); 2 | 3 | // run "cmd" and capture the output 4 | function run(cmd) { 5 | return new Promise((resolve, reject) => { 6 | require("child_process").exec(cmd, (error, stdout, stderr) => { 7 | if (error) { 8 | reject(error); 9 | } else { 10 | resolve(stdout.trim()); 11 | } 12 | }); 13 | }); 14 | } 15 | 16 | function fmt(perc) { 17 | return perc.toFixed(2).padStart(6) + "%"; 18 | } 19 | 20 | function gethd(o) { 21 | return ( 22 | fmt(o.retained_size_percent) + 23 | " " + 24 | fmt(o.shallow_size_percent) + 25 | " " + 26 | o.name 27 | ); 28 | } 29 | 30 | function cchildren(o) { 31 | const r = {}; 32 | (o.children ?? 
[]).forEach((c) => { 33 | r[gethd(c)] = cchildren(c); 34 | }); 35 | return r; 36 | } 37 | 38 | async function main() { 39 | const o = JSON.parse( 40 | await run("twiggy dominators target/strip.wasm -f json") 41 | ); 42 | o.root = cchildren({ name: "ROOT", children: o.items }); 43 | delete o.items; 44 | fs.writeFileSync( 45 | "target/dominators.json", 46 | JSON.stringify(o, null, 2) 47 | ); 48 | } 49 | main(); 50 | -------------------------------------------------------------------------------- /controllers/declctrl/wasm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -x 4 | set -e 5 | cargo build --release 6 | BIN=$(cd ../target; pwd) 7 | cp $BIN/wasm32-wasi/release/aici_declctrl.wasm $BIN/opt.wasm 8 | ls -l $BIN/opt.wasm 9 | if [ "X$1" = "Xbuild" ] ; then 10 | exit 11 | fi 12 | if [ "X$1" = "Xsize" ] ; then 13 | node size.js 14 | fx $BIN/dominators.json 15 | exit 16 | fi 17 | 18 | (cd ../aicirt; cargo build --release) 19 | 20 | mkdir -p tmp 21 | if [ "X$1" = "Xcache" ] ; then 22 | $BIN/release/aicirt --module $BIN/opt.wasm | tee tmp/runlog.txt 23 | exit 24 | fi 25 | 26 | PERF= 27 | if [ `uname` = Linux ] ; then 28 | PERF="perf stat" 29 | fi 30 | RUST_LOG=info $PERF $BIN/release/aicirt --tokenizer gpt4 --module $BIN/opt.wasm 31 | RUST_LOG=info $PERF $BIN/release/aicirt \ 32 | --tokenizer gpt4 --module $BIN/opt.wasm --run | tee tmp/runlog.txt 33 | ls -l $BIN/opt.wasm 34 | -------------------------------------------------------------------------------- /controllers/jsctrl/.cargo/config.toml: -------------------------------------------------------------------------------- 1 | [build] 2 | target = "wasm32-wasi" 3 | 4 | [profile.dev] 5 | strip = "debuginfo" 6 | 7 | [profile.release] 8 | strip = "debuginfo" 9 | -------------------------------------------------------------------------------- /controllers/jsctrl/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "aici_jsctrl" 3 | version = "0.1.0" 4 | edition = "2021" 5 | build = "build.rs" 6 | 7 | [dependencies] 8 | aici_abi = { path = "../aici_abi" } 9 | serde = { version = "1.0.192", features = ["derive"] } 10 | serde_json = "1.0.108" 11 | anyhow = "1.0.75" 12 | lazy_static = "1.4.0" 13 | rquickjs = { git = "https://github.com/DelSkayn/rquickjs", rev = "343b21b742d3bb052710dc53144b79dc61bb592d", features = ["array-buffer", "macro"] } 14 | 15 | [[bin]] 16 | name = "aici_jsctrl" 17 | path = "src/jsctrl.rs" 18 | 19 | [build-dependencies] 20 | glob = "0.3.1" 21 | -------------------------------------------------------------------------------- /controllers/jsctrl/README.md: -------------------------------------------------------------------------------- 1 | # JsCtrl 2 | 3 | This crate implements AI Controller Interface by embedding 4 | [QuickJS](https://bellard.org/quickjs/) (JavaScript (ES2023) interpreter) 5 | via [rquickjs](https://github.com/DelSkayn/rquickjs) 6 | in a Wasm module together with native 7 | primitives for specific kinds of output constraints: 8 | fixed token output, regexps, LR(1) grammars, substring constrains etc. 9 | JavaScript code is typically only used lightly, for gluing the primitives together, 10 | and thus is not performance critical. 11 | 12 | There are [some sample scripts](./samples/) available. 13 | The scripts use the [aici module](./samples/aici-types.d.ts) to communicate with the AICI runtime 14 | and use the native constraints. 
15 | 16 | This is quite similar to [PyCtrl](../pyctrl/README.md) but with JavaScript instead of Python. 17 | It is also smaller, at 1.3MiB without regex and CFG, 1.8MiB with regex, and 3.3MiB with regex and CFG. 18 | For comparison, pyctrl is 14MiB. 19 | Also, the [PyCtrl samples](../pyctrl/samples/) translate 1:1 to JsCtrl. 20 | 21 | ## Usage 22 | 23 | To run a JsCtrl sample use: 24 | 25 | ```bash 26 | ../../aici.sh run samples/hello.js 27 | ``` 28 | 29 | If you write your sample in TypeScript, compile it first with `tsc -p samples`. 30 | 31 | If you want to build the interpreter yourself, use: 32 | 33 | ```bash 34 | ../../aici.sh run --build . samples/hello.js 35 | ``` 36 | 37 | You will see the console output of the program. 38 | 39 | -------------------------------------------------------------------------------- /controllers/jsctrl/build.rs: -------------------------------------------------------------------------------- 1 | use std::process::Command; 2 | 3 | fn run(cmd: &mut Command, msg: &str) { 4 | let status = cmd.status().expect(&format!("failed to execute: {msg}")); 5 | if !status.success() { 6 | panic!("process exited with status: {}; {}", status, msg); 7 | } 8 | } 9 | 10 | fn rerun_if_glob(pattern: &str) { 11 | for entry in glob::glob(pattern).expect(pattern).flatten() { 12 | let display = entry.display(); 13 | println!("cargo:rerun-if-changed={display}"); 14 | } 15 | } 16 | 17 | fn main() { 18 | rerun_if_glob("ts/*.ts"); 19 | rerun_if_glob("ts/*.json"); 20 | rerun_if_glob("gen-dts.mjs"); 21 | 22 | if Command::new("tsc").arg("--version").status().is_err() { 23 | // println!("cargo:warning=typescript not found, installing..."); 24 | run( 25 | Command::new("npm") 26 | .arg("install") 27 | .arg("-g") 28 | .arg("typescript"), 29 | "npm install failed", 30 | ); 31 | } 32 | 33 | run(Command::new("tsc").arg("-p").arg("ts"), "build failed"); 34 | run(Command::new("node").arg("gen-dts.mjs"), "gen-dts failed"); 35 | } 36 | -------------------------------------------------------------------------------- /controllers/jsctrl/gen-dts.mjs: -------------------------------------------------------------------------------- 1 | import * as fs from 'fs'; 2 | 3 | function gen() { 4 | const ts = "./ts/" 5 | const native = fs.readFileSync(ts + '/native.d.ts', 'utf8') 6 | let aici = fs.readFileSync(ts + '/dist/aici.d.ts', 'utf8') 7 | aici = aici.replace(/ pre + aici + '"""') 16 | 17 | const tsconfig = fs.readFileSync("./samples/tsconfig.json", "utf-8") 18 | jssrc = jssrc.replace(/(tsconfig_json = r""")[^]*?"""/g, (_, pre) => pre + tsconfig + '"""') 19 | 20 | const hello = fs.readFileSync("./samples/hello.js", "utf-8") 21 | jssrc = jssrc.replace(/(hello_js = r""")[^]*?"""/g, (_, pre) => pre + hello + '"""') 22 | 23 | fs.writeFileSync("../../py/pyaici/jssrc.py", jssrc) 24 | } 25 | 26 | gen() -------------------------------------------------------------------------------- /controllers/jsctrl/samples/count-yellow.js: -------------------------------------------------------------------------------- 1 | import { Label, setLogLevel } from "aici" 2 | 3 | const colors = [ 4 | // "red", 5 | // "green", 6 | "blue", 7 | // "violet", 8 | // "white", 9 | // "black", 10 | ] 11 | 12 | function gencolors(k) { 13 | let s = [] 14 | 15 | for (let i = 0; i < k; ++i) { 16 | s.push(colors[Math.floor(Math.random() * colors.length)]) 17 | } 18 | 19 | return s.join(", ") 20 | } 21 | 22 | async function countYellow() { 23 | setLogLevel(10) 24 | const q = `Does the color yellow appear in the following list of colors?` 25 | await 
$`<|user|>\n${q}\n` 26 | const l = new Label() 27 | for (let i = 0; i < 100; ++i) { 28 | const hasYellow = Math.random() < 0.5 29 | const A = (10 + Math.random() * 100) | 0 30 | const B = (10 + Math.random() * 130) | 0 31 | const text = gencolors(A) + (hasYellow ? ", yellow, " : ", blue, ") + gencolors(B) 32 | await l.fixedAfter(`${text}<|end|>\n<|assistant|>`) 33 | const r = await gen({ maxTokens: 10, regex: /(Yes|No)/ }) 34 | console.log(q) 35 | console.log(text) 36 | console.log(hasYellow ? "Yes" : "No", r) 37 | assert(r === "Yes" || r === "No") 38 | assert(r === (hasYellow ? "Yes" : "No")) 39 | await $`\n` 40 | } 41 | } 42 | 43 | start(countYellow) 44 | -------------------------------------------------------------------------------- /controllers/jsctrl/samples/hello.js: -------------------------------------------------------------------------------- 1 | async function main() { 2 | await $`Ultimate answer is to the life, universe and everything is ` 3 | await gen({ regex: /\d\d/ }) 4 | } 5 | 6 | start(main) 7 | -------------------------------------------------------------------------------- /controllers/jsctrl/samples/hellots.ts: -------------------------------------------------------------------------------- 1 | async function main() { 2 | await $`Hello` 3 | await gen({ regex: / [A-Z]+/ }) 4 | } 5 | 6 | start(main) 7 | -------------------------------------------------------------------------------- /controllers/jsctrl/samples/mapping.js: -------------------------------------------------------------------------------- 1 | import { Label, getTokens, setLogLevel } from "aici" 2 | 3 | function randomInt(min, max) { 4 | return Math.floor(Math.random() * (max - min + 1) + min) 5 | } 6 | 7 | async function countYellow() { 8 | setLogLevel(0) 9 | const q = `Tell me the value of x42?` 10 | await $`<|user|>\n${q}\n` 11 | const l = new Label() 12 | let numok = 0 13 | for (let i = 0; i < 20; ++i) { 14 | let text = "" 15 | let x42 = randomInt(10, 99) 16 | for (let i = 10; i < 300; ++i) { 17 | text += `The value of x${i} is ${i == 42 ? x42 : randomInt(10, 99)}.\n` 18 | } 19 | await l.fixedAfter(`${text}\nTell me x42.<|end|>\n<|assistant|>The value of x42 is `) 20 | const r = await gen({ maxTokens: 10, regex: /\d\d/ }) 21 | // console.log(q) 22 | // console.log(text) 23 | console.log(getTokens().length, x42, r, r === x42.toString()) 24 | if (r === x42.toString()) { 25 | numok++ 26 | } 27 | // assert(r === x42.toString()) 28 | await $`\n` 29 | } 30 | console.log("numok", numok) 31 | } 32 | 33 | start(countYellow) 34 | -------------------------------------------------------------------------------- /controllers/jsctrl/samples/schema.js: -------------------------------------------------------------------------------- 1 | // Doesn't seem to work too well... 
2 | 3 | import { Label } from "aici" 4 | 5 | async function jsonString() { 6 | await gen({ 7 | maxTokens: 50, 8 | regex: /"(\\(["\\\/bfnrt]|u[a-fA-F0-9]{4})|[^"\\\x00-\x1F\x7F]+)+"/ 9 | }) 10 | } 11 | 12 | async function jsonInt() { 13 | await gen({ regex: /\d+/ }) 14 | } 15 | 16 | /** 17 | * @param {string} name 18 | */ 19 | async function jsonField(name) { 20 | await $` "${name}": ` 21 | } 22 | 23 | async function cityList() { 24 | const start = new Label() 25 | await $`[` 26 | const maxNodes = 3; 27 | for (let i = 0; i < maxNodes; i++) { 28 | await $`{\n` 29 | await jsonField("name") 30 | await jsonString() 31 | await $`,\n` 32 | await jsonField("population") 33 | await jsonInt() 34 | await $`,\n` 35 | await jsonField("url") 36 | await jsonString() 37 | await $`\n` 38 | const nextChar = await gen({ options: ['},\n', ']'] }) 39 | if (nextChar === ']') { 40 | break 41 | } 42 | } 43 | console.log(start.textSince()) 44 | } 45 | 46 | async function main() { 47 | await $`Here is JSON objects for five European cities:\n` 48 | await cityList() 49 | } 50 | 51 | start(main) 52 | -------------------------------------------------------------------------------- /controllers/jsctrl/samples/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | /* Visit https://aka.ms/tsconfig to read more about this file */ 4 | "target": "ES2020", 5 | "lib": [ 6 | "ES2020" 7 | ], 8 | "moduleDetection": "force", 9 | "module": "ES2020", 10 | "allowJs": true, 11 | "checkJs": true, 12 | "strict": true, 13 | "noImplicitThis": true, 14 | "noImplicitReturns": true, 15 | "outDir": "./dist", 16 | "skipDefaultLibCheck": true, 17 | } 18 | } -------------------------------------------------------------------------------- /controllers/llguidance_ctrl/.cargo/config.toml: -------------------------------------------------------------------------------- 1 | [build] 2 | target = "wasm32-wasi" 3 | 4 | [profile.dev] 5 | strip = "debuginfo" 6 | 7 | [profile.release] 8 | strip = "debuginfo" 9 | -------------------------------------------------------------------------------- /controllers/llguidance_ctrl/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "aici_llguidance_ctrl" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | aici_abi = { path = "../aici_abi" } 8 | llguidance_parser = { path = "../llguidance/parser" } 9 | serde = { version = "1.0.192", features = ["derive"] } 10 | serde_json = "1.0.108" 11 | anyhow = "1.0.75" 12 | 13 | [[bin]] 14 | name = "aici_llguidance_ctrl" 15 | path = "src/gctrl.rs" 16 | 17 | [features] 18 | default = [] 19 | logging = ["llguidance_parser/logging"] -------------------------------------------------------------------------------- /controllers/llguidance_ctrl/go.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | F="$1" 4 | if [ -z "$F" ] ; then 5 | F=run_g.py 6 | fi 7 | 8 | set -x 9 | cd `dirname $0` 10 | HERE=`pwd` 11 | PYTHONPATH=$HERE/../../py:$HERE/../../py/guidance \ 12 | python3 $HERE/$F 13 | 14 | -------------------------------------------------------------------------------- /controllers/llguidance_ctrl/src/gctrl.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use aici_abi::{ 4 | arg_bytes, get_config, 5 | toktrie::{InferenceCapabilities, StepArg}, 6 | AiciCtrl, InitPromptArg, InitPromptResult, MidProcessArg, 
MidProcessResult, 7 | }; 8 | use serde::{Deserialize, Serialize}; 9 | 10 | use llguidance_parser::{api::TopLevelGrammar, output::Reporter, Logger, TokenParser}; 11 | 12 | const INFO: bool = true; 13 | 14 | macro_rules! infoln { 15 | ($($arg:tt)*) => { 16 | if INFO { 17 | println!($($arg)*); 18 | } 19 | }; 20 | } 21 | 22 | pub struct Runner { 23 | tok_parser: TokenParser, 24 | reporter: Reporter, 25 | } 26 | 27 | #[derive(Serialize, Deserialize)] 28 | struct RunnerArg { 29 | grammar: TopLevelGrammar, 30 | } 31 | 32 | impl Runner { 33 | pub fn new() -> Self { 34 | infoln!("building runner..."); 35 | let arg: RunnerArg = serde_json::from_slice(&arg_bytes()).expect("invalid JSON arg"); 36 | let log_level = 2; 37 | let inf = InferenceCapabilities { 38 | backtrack: get_config("backtrack") != 0, 39 | ff_tokens: get_config("ff_tokens") != 0, 40 | conditional_ff_tokens: get_config("ff_tokens") != 0, 41 | fork: false, 42 | }; 43 | let tok_parser = TokenParser::from_llguidance_json( 44 | Arc::new(aici_abi::WasmTokenizerEnv::default()), 45 | arg.grammar, 46 | Logger::new(0, log_level), 47 | inf, 48 | ) 49 | .expect("invalid guidance protobuf"); 50 | 51 | let reporter = Reporter::new(&tok_parser); 52 | Runner { 53 | tok_parser, 54 | reporter, 55 | } 56 | } 57 | } 58 | 59 | fn json_out(obj: &T) { 60 | println!("JSON-OUT: {}", serde_json::to_string(obj).unwrap()); 61 | } 62 | 63 | impl AiciCtrl for Runner { 64 | fn init_prompt(&mut self, arg: InitPromptArg) -> InitPromptResult { 65 | InitPromptResult { 66 | prompt: self.tok_parser.process_prompt(arg.prompt), 67 | } 68 | } 69 | fn mid_process(&mut self, arg: MidProcessArg) -> MidProcessResult { 70 | let r = self.tok_parser.mid_process(StepArg { 71 | backtrack: arg.backtrack, 72 | tokens: arg.tokens, 73 | sampled: arg.sampled, 74 | }); 75 | for v in self.reporter.get_progress(&mut self.tok_parser, &r) { 76 | json_out(&v); 77 | } 78 | MidProcessResult::from_branch(r) 79 | } 80 | } 81 | 82 | fn main() {} 83 | 84 | aici_abi::aici_expose_all!(Runner, Runner::new()); 85 | -------------------------------------------------------------------------------- /controllers/llguidance_ctrl/text_req.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import os 3 | 4 | base = os.getenv("AICI_API_BASE", "http://localhost:4242/v1") 5 | url = base + '/completions' 6 | 7 | headers = { 8 | 'Content-Type': 'application/json', 9 | } 10 | 11 | data = { 12 | 'model': 'model', 13 | 'prompt': 'Once upon a time,', 14 | 'max_tokens': 5, 15 | 'temperature': 0, 16 | 'stream': True 17 | } 18 | 19 | # read tmp/prompt.txt 20 | with open('tmp/prompt.txt', 'r') as file: 21 | data['prompt'] = file.read() 22 | 23 | response = requests.post(url, headers=headers, json=data, stream=True) 24 | 25 | if response.status_code == 200: 26 | for line in response.iter_lines(): 27 | if line: 28 | decoded_line = line.decode('utf-8') 29 | print(decoded_line) 30 | else: 31 | print(f"Request failed with status code {response.status_code}") 32 | print(response.text) 33 | -------------------------------------------------------------------------------- /controllers/pyctrl/.cargo/config.toml: -------------------------------------------------------------------------------- 1 | [build] 2 | target = "wasm32-wasi" 3 | 4 | [profile.dev] 5 | strip = "debuginfo" 6 | 7 | [profile.release] 8 | strip = "debuginfo" 9 | -------------------------------------------------------------------------------- /controllers/pyctrl/Cargo.toml: 
-------------------------------------------------------------------------------- 1 | [package] 2 | name = "aici_pyctrl" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | aici_abi = { path = "../aici_abi" } 8 | serde = { version = "1.0.192", features = ["derive"] } 9 | serde_json = "1.0.108" 10 | anyhow = "1.0.75" 11 | rustpython-vm = { git = "https://github.com/RustPython/RustPython", rev = "317f44945420e", default-features = false, features = ["compiler"] } 12 | rustpython-derive = { git = "https://github.com/RustPython/RustPython", rev = "317f44945420e" } 13 | lazy_static = "1.4.0" 14 | num-traits = "0.2.17" 15 | crossbeam-utils = "0.8.16" 16 | once_cell = "1.18.0" 17 | 18 | [[bin]] 19 | name = "aici_pyctrl" 20 | path = "src/pyctrl.rs" 21 | 22 | [build-dependencies] 23 | glob = "0.3.1" 24 | -------------------------------------------------------------------------------- /controllers/pyctrl/Lib/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 RustPython Team 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /controllers/pyctrl/Lib/collections/_defaultdict.py: -------------------------------------------------------------------------------- 1 | from reprlib import recursive_repr as _recursive_repr 2 | 3 | class defaultdict(dict): 4 | def __init__(self, *args, **kwargs): 5 | if len(args) >= 1: 6 | default_factory = args[0] 7 | if default_factory is not None and not callable(default_factory): 8 | raise TypeError("first argument must be callable or None") 9 | args = args[1:] 10 | else: 11 | default_factory = None 12 | super().__init__(*args, **kwargs) 13 | self.default_factory = default_factory 14 | 15 | def __missing__(self, key): 16 | if self.default_factory is not None: 17 | val = self.default_factory() 18 | else: 19 | raise KeyError(key) 20 | self[key] = val 21 | return val 22 | 23 | @_recursive_repr() 24 | def __repr_factory(factory): 25 | return repr(factory) 26 | 27 | def __repr__(self): 28 | return f"{type(self).__name__}({defaultdict.__repr_factory(self.default_factory)}, {dict.__repr__(self)})" 29 | 30 | def copy(self): 31 | return type(self)(self.default_factory, self) 32 | 33 | __copy__ = copy 34 | 35 | def __reduce__(self): 36 | if self.default_factory is not None: 37 | args = self.default_factory, 38 | else: 39 | args = () 40 | return type(self), args, None, None, iter(self.items()) 41 | 42 | def __or__(self, other): 43 | if not isinstance(other, dict): 44 | return NotImplemented 45 | 46 | new = defaultdict(self.default_factory, self) 47 | new.update(other) 48 | return new 49 | 50 | def __ror__(self, other): 51 | if not isinstance(other, dict): 52 | return NotImplemented 53 | 54 | new = defaultdict(self.default_factory, other) 55 | new.update(self) 56 | return new 57 | 58 | defaultdict.__module__ = 'collections' 59 | -------------------------------------------------------------------------------- /controllers/pyctrl/Lib/collections/abc.py: -------------------------------------------------------------------------------- 1 | from _collections_abc import * 2 | from _collections_abc import __all__ 3 | from _collections_abc import _CallableGenericAlias 4 | -------------------------------------------------------------------------------- /controllers/pyctrl/Lib/copy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | for f in `find -name \*.py | xargs cat | grep -E '^(import |from \S+ import)' | awk '{print $2}' | sort | uniq` ; do 4 | test -f $f.py && continue 5 | test -f $f/__init__.py && continue 6 | if test -f ../../RustPython/pylib/Lib/$f.py ; then 7 | echo cp ../../RustPython/pylib/Lib/$f.py $f.py 8 | continue 9 | fi 10 | echo "? $f" 11 | done 12 | -------------------------------------------------------------------------------- /controllers/pyctrl/Lib/keyword.py: -------------------------------------------------------------------------------- 1 | """Keywords (from "Grammar/python.gram") 2 | 3 | This file is automatically generated; please don't muck it up! 4 | 5 | To update the symbols in this file, 'cd' to the top directory of 6 | the python source tree and run: 7 | 8 | PYTHONPATH=Tools/peg_generator python3 -m pegen.keywordgen \ 9 | Grammar/python.gram \ 10 | Grammar/Tokens \ 11 | Lib/keyword.py 12 | 13 | Alternatively, you can run 'make regen-keyword'. 
14 | """ 15 | 16 | __all__ = ["iskeyword", "issoftkeyword", "kwlist", "softkwlist"] 17 | 18 | kwlist = [ 19 | 'False', 20 | 'None', 21 | 'True', 22 | 'and', 23 | 'as', 24 | 'assert', 25 | 'async', 26 | 'await', 27 | 'break', 28 | 'class', 29 | 'continue', 30 | 'def', 31 | 'del', 32 | 'elif', 33 | 'else', 34 | 'except', 35 | 'finally', 36 | 'for', 37 | 'from', 38 | 'global', 39 | 'if', 40 | 'import', 41 | 'in', 42 | 'is', 43 | 'lambda', 44 | 'nonlocal', 45 | 'not', 46 | 'or', 47 | 'pass', 48 | 'raise', 49 | 'return', 50 | 'try', 51 | 'while', 52 | 'with', 53 | 'yield' 54 | ] 55 | 56 | softkwlist = [ 57 | '_', 58 | 'case', 59 | 'match' 60 | ] 61 | 62 | iskeyword = frozenset(kwlist).__contains__ 63 | issoftkeyword = frozenset(softkwlist).__contains__ 64 | -------------------------------------------------------------------------------- /controllers/pyctrl/build.rs: -------------------------------------------------------------------------------- 1 | fn main() { 2 | for entry in glob::glob("Lib/**/*.py").expect("Lib/ exists?").flatten() { 3 | let display = entry.display(); 4 | println!("cargo:rerun-if-changed={display}"); 5 | } 6 | for entry in glob::glob("../../py/pyaici/server*.py") 7 | .expect("exists?") 8 | .flatten() 9 | { 10 | let display = entry.display(); 11 | println!("cargo:rerun-if-changed={display}"); 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /controllers/pyctrl/samples/forkbomb.py: -------------------------------------------------------------------------------- 1 | import pyaici.server as aici 2 | 3 | 4 | async def fork_bomb(): 5 | await aici.FixedTokens("The value of") 6 | id = await aici.fork(20) 7 | await aici.FixedTokens(f" {id} is") 8 | await aici.gen_text(max_tokens=5, store_var=f"x{id}") 9 | 10 | 11 | async def deadlock(): 12 | await aici.wait_vars("foo") 13 | 14 | 15 | async def burn_tokens(): 16 | id = await aici.fork(10) 17 | await aici.FixedTokens(f"The value of {id} in the universe is and everything is ") 18 | await aici.gen_text(max_tokens=200, store_var=f"foo{id}") 19 | 20 | aici.test(burn_tokens()) 21 | -------------------------------------------------------------------------------- /controllers/pyctrl/samples/idents.py: -------------------------------------------------------------------------------- 1 | import pyaici.server as aici 2 | import re 3 | 4 | # asserts for microsoft/Orca-2-13b 5 | 6 | aici.log_level = 1 7 | 8 | async def test_id(): 9 | await aici.FixedTokens("Here's a fib function\n```python\n") 10 | 11 | max_tokens = 60 12 | dyn_lex = aici.DynamicLexer("") 13 | for id in ["def", "fibo", "n", "return", "if"]: 14 | dyn_lex.add(id) 15 | next_token = aici.ConstrainedToken(lambda: dyn_lex.constraint()) 16 | res = [] 17 | text = "" 18 | for _ in range(max_tokens): 19 | tokens = await next_token 20 | if tokens: 21 | res += tokens 22 | print("GEN-STEP:", aici.tokens_repr(tokens)) 23 | text = aici.detokenize(res).decode(errors="replace") 24 | if next_token.finished: 25 | break 26 | print("RESULT:", text) 27 | 28 | 29 | aici.test(test_id()) 30 | -------------------------------------------------------------------------------- /controllers/pyctrl/samples/phi.py: -------------------------------------------------------------------------------- 1 | import pyaici.server as aici 2 | 3 | 4 | async def test_42(): 5 | await aici.FixedTokens( 6 | "The ultimate answer to life, the universe and everything is" 7 | ) 8 | s = await aici.gen_text(regex=r" \d+\.", max_tokens=5, store_var="x") 9 | aici.check_vars({"x": " 
42."}) 10 | 11 | 12 | aici.test(test_42()) 13 | -------------------------------------------------------------------------------- /controllers/pyctrl/samples/substr.py: -------------------------------------------------------------------------------- 1 | import pyaici.server as aici 2 | 3 | earth = """ 4 | Earth is rounded into an ellipsoid with a circumference of about 40,000 km. It is the densest planet in the Solar System. Of the four rocky planets, it is the largest and most massive. Earth is about eight light-minutes away from the Sun and orbits it, taking a year (about 365.25 days) to complete one revolution. Earth rotates around its own axis in slightly less than a day (in about 23 hours and 56 minutes). Earth's axis of rotation is tilted with respect to the perPendicular to its orbital plane around the Sun, producing seasons. Earth is orbited by one permanent natural satellite, the Moon, which orbits Earth at 384,400 km (1.28 light seconds) and is roughly a quarter as wide as Earth. Through tidal locking, the Moon always faces Earth with the same side, which causes tides, stabilizes Earth's axis, and gradually slows its rotation. 5 | """ 6 | 7 | prompt = f"""[INST] Here's some text: 8 | {earth} 9 | 10 | Based on the text answer the question: Is Earth axis aligned with its orbit? 11 | Provide a quote from the text, prefixed by 'Source: "', to support your answer. 12 | [/INST] 13 | """ 14 | 15 | async def test_substr(): 16 | await aici.FixedTokens(prompt) 17 | await aici.gen_tokens(max_tokens=60, stop_at="Source: \"", store_var="answer") 18 | await aici.gen_tokens(substring=earth, substring_end="\"", max_tokens=60, store_var="source") 19 | # make sure we can continue generating afterwards 20 | await aici.FixedTokens("\nThe tilt is") 21 | await aici.gen_tokens(max_tokens=6, store_var="tilt") 22 | 23 | aici.test(test_substr()) -------------------------------------------------------------------------------- /controllers/pyctrl/samples/tla.py: -------------------------------------------------------------------------------- 1 | import pyaici.server as aici 2 | 3 | apalache = r""" 4 | 5 | %start tnl 6 | %% 7 | 8 | List 9 | : T 10 | | List ", " T 11 | ; 12 | 13 | tnl: T "/\n/"; 14 | 15 | T 16 | // integers 17 | : "Int" 18 | // immutable constant strings 19 | | "Str" 20 | // boolean 21 | | "Bool" 22 | // functions 23 | | T " -> " T 24 | // sets 25 | | "Set" "(" T ")" 26 | // sequences 27 | | "Seq" "(" T ")" 28 | // tuples 29 | | "<<" List ">>" 30 | // parentheses, e.g., to change associativity of functions 31 | | "(" T ")" 32 | // operators 33 | | "(" List ") => " T 34 | ; 35 | 36 | %% 37 | """ 38 | 39 | 40 | async def gen_and_test_grammar(): 41 | aici.log_level = 3 42 | await aici.FixedTokens( 43 | """ 44 | Here's a TLA+ spec: 45 | 46 | ---- MODULE Counter ---- 47 | VARIABLE 48 | b 49 | q 50 | 51 | Init == b = TRUE 52 | Next == q = q + 1 53 | ==== 54 | 55 | Now with added types: 56 | 57 | ---- MODULE Counter ---- 58 | VARIABLE 59 | b: """ 60 | ) 61 | await aici.gen_tokens(yacc=apalache, max_tokens=10, store_var="b") 62 | await aici.FixedTokens(" q: ") 63 | await aici.gen_tokens(yacc=apalache, max_tokens=10, store_var="q") 64 | 65 | 66 | aici.test(gen_and_test_grammar()) -------------------------------------------------------------------------------- /controllers/pyctrl/samples/warsaw.py: -------------------------------------------------------------------------------- 1 | import pyaici.server as aici 2 | 3 | system_message = "You are a helpful assistant." 
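# The <|im_start|>/<|im_end|> wrapping built by fixed() below follows the
# instruction format described in docs/FAQ.md (used e.g. by microsoft/Orca-2-13b).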
4 | 5 | async def fixed(user_message: str): 6 | prompt = f"<|im_start|>system\n{system_message}<|im_end|>\n" 7 | prompt += f"<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant\n" 8 | await aici.FixedTokens(prompt) 9 | 10 | async def main(): 11 | await fixed("What is the capital of Poland?") 12 | await aici.gen_tokens(max_tokens=10, store_var="capital") 13 | 14 | aici.start(main()) 15 | -------------------------------------------------------------------------------- /controllers/pyctrl/samples/yesno.py: -------------------------------------------------------------------------------- 1 | import pyaici.server as aici 2 | 3 | async def main(): 4 | tokens = await aici.GetPrompt() 5 | assert len(tokens) > 2, "prompt too short" 6 | await aici.FixedTokens("\n") 7 | await aici.gen_tokens(options=["Yes", "No"]) 8 | 9 | aici.start(main()) 10 | -------------------------------------------------------------------------------- /controllers/pyctrl/wasm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -x 4 | set -e 5 | cargo build --release 6 | BIN=$(cd ../target; pwd) 7 | cp $BIN/wasm32-wasi/release/aici_pyctrl.wasm $BIN/opt.wasm 8 | ls -l $BIN/opt.wasm 9 | if [ "X$1" = "Xbuild" ] ; then 10 | exit 11 | fi 12 | if [ "X$1" = "Xsize" ] ; then 13 | node size.js 14 | fx $BIN/dominators.json 15 | exit 16 | fi 17 | 18 | (cd ../aicirt; cargo build --release) 19 | 20 | mkdir -p tmp 21 | if [ "X$1" = "Xcache" ] ; then 22 | $BIN/release/aicirt --module $BIN/opt.wasm | tee tmp/runlog.txt 23 | exit 24 | fi 25 | 26 | PERF= 27 | if [ `uname` = Linux ] ; then 28 | _PERF="perf stat" 29 | fi 30 | RUST_LOG=info $PERF $BIN/release/aicirt --tokenizer gpt4 --module $BIN/opt.wasm 31 | RUST_LOG=info $PERF $BIN/release/aicirt \ 32 | --tokenizer gpt4 --module $BIN/opt.wasm --run --run-arg ./samples/test.py | tee tmp/runlog.txt 33 | ls -l $BIN/opt.wasm 34 | -------------------------------------------------------------------------------- /controllers/uppercase/.cargo/config.toml: -------------------------------------------------------------------------------- 1 | [build] 2 | target = "wasm32-wasi" 3 | 4 | [profile.dev] 5 | strip = "debuginfo" 6 | 7 | [profile.release] 8 | strip = "debuginfo" 9 | -------------------------------------------------------------------------------- /controllers/uppercase/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "aici_uppercase" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | aici_abi = { path = "../aici_abi" } 8 | anyhow = "1.0.75" 9 | -------------------------------------------------------------------------------- /controllers/uppercase/README.md: -------------------------------------------------------------------------------- 1 | # AICI Uppercase 2 | 3 | This folder provides a simple, hello-world-like AICI Controller 4 | that you can clone to build your own controller. 5 | 6 | When you clone it, make sure to keep [.cargo/config.toml](.cargo/config.toml) 7 | as it sets up the linker flags for the Wasm target. 8 | 9 | The [main.rs](src/main.rs) shows usage of the `FunctionalRecognizer` interface. 10 | It forces every 4th letter of the model output to be uppercase. 11 | 12 | ``` 13 | $ ../../aici.sh run --build . 
14 | will build aici_uppercase from /workspaces/aici/uppercase/Cargo.toml 15 | Compiling aici_abi v0.1.0 (/workspaces/aici/aici_abi) 16 | Compiling aici_uppercase v0.1.0 (/workspaces/aici/uppercase) 17 | Finished release [optimized + debuginfo] target(s) in 1.81s 18 | built: /workspaces/aici/target/wasm32-wasi/release/aici_uppercase.wasm, 0.189 MiB 19 | upload module... 193kB -> 675kB id:a4000d9b 20 | [DONE] 21 | [Response] Here's a tweet: 22 | I'm SO EXCITED! I'm GoinG toBe aMom!I'm GoinG toHaVeA BaBy! 23 | ``` 24 | 25 | This is only meant as a sample - it could be done with [PyCtrl](../pyctrl) or 26 | [JsCtrl](../jsctrl) and a simple regex. 27 | 28 | ## Yes/no controller 29 | 30 | You can also take a look at the [yes/no controller](../aici_abi/src/yesno.rs), which 31 | only allows the model to say "Yes" or "No" in answer to the question in the prompt. 32 | 33 | ``` 34 | $ ./scripts/sample-yesno.sh "Can orcas sing?" 35 | + echo 'Can orcas sing?' 36 | + ./aici.sh run --build aici_abi::yesno - 37 | will build yesno from /workspaces/aici/aici_abi/Cargo.toml 38 | Compiling aici_abi v0.1.0 (/workspaces/aici/aici_abi) 39 | Finished release [optimized + debuginfo] target(s) in 0.58s 40 | upload module... 192kB -> 671kB id:c65e78e9 41 | [DONE] 42 | [Response] Can orcas sing? 43 | 44 | Yes 45 | ``` 46 | 47 | A similar effect can be achieved with PyCtrl and [10x less lines of code](../pyctrl/samples/yesno.py), 48 | but it illustrates the raw token APIs. 49 | 50 | 51 | ``` 52 | $ ./aici.sh run pyctrl/samples/yesno.py 53 | Running with tagged AICI Controller: pyctrl-latest 54 | [0]: FIXED 'Are dolphins fish?\n' 55 | [0]: GEN 'No' 56 | [DONE] 57 | [Response] Are dolphins fish? 58 | No 59 | ``` 60 | 61 | ## DeclCtrl 62 | 63 | For a more full-fledged example, take a look at the [DeclCtrl](../declctrl/src/declctrl.rs). 
64 | -------------------------------------------------------------------------------- /controllers/uppercase/src/main.rs: -------------------------------------------------------------------------------- 1 | use aici_abi::{ 2 | host_trie, 3 | recognizer::{FunctionalRecognizer, StackRecognizer}, 4 | tokenize, 5 | toktrie::{SpecialToken, TokTrie}, 6 | AiciCtrl, InitPromptArg, InitPromptResult, MidProcessArg, MidProcessResult, 7 | }; 8 | 9 | // This constraints enforces an upper case letter every 4th byte 10 | // The state is the position in the output stream 11 | struct QuadUpper {} 12 | impl FunctionalRecognizer for QuadUpper { 13 | fn initial(&self) -> usize { 14 | 0 15 | } 16 | 17 | fn try_append(&self, state: usize, byte: u8) -> Option { 18 | if state % 4 == 0 && !byte.is_ascii_uppercase() { 19 | None 20 | } else { 21 | Some(state + 1) 22 | } 23 | } 24 | 25 | fn special_allowed(&self, _state: usize, tok: SpecialToken) -> bool { 26 | match tok { 27 | SpecialToken::EndOfSentence => false, 28 | _ => false, 29 | } 30 | } 31 | } 32 | 33 | pub struct Runner { 34 | toktrie: TokTrie, 35 | tokens: Vec, 36 | recognizer: StackRecognizer, 37 | } 38 | 39 | impl Runner { 40 | pub fn new() -> Self { 41 | Runner { 42 | toktrie: host_trie(), 43 | tokens: Vec::new(), 44 | recognizer: StackRecognizer::from(QuadUpper {}), 45 | } 46 | } 47 | } 48 | 49 | impl AiciCtrl for Runner { 50 | fn init_prompt(&mut self, arg: InitPromptArg) -> InitPromptResult { 51 | if arg.prompt.len() <= 1 { 52 | // in case no prompt was provided, invent some 53 | InitPromptResult { 54 | prompt: tokenize("Here's a tweet:\n"), 55 | } 56 | } else { 57 | InitPromptResult::from_arg(arg) 58 | } 59 | } 60 | 61 | fn mid_process(&mut self, arg: MidProcessArg) -> MidProcessResult { 62 | // store our tokens 63 | arg.save_tokens(&mut self.tokens); 64 | // and update the state of our recognizer 65 | self.toktrie 66 | .append_tokens(&mut self.recognizer, &arg.tokens) 67 | .unwrap(); 68 | 69 | // stop after 50 tokens 70 | if self.tokens.len() > 50 || arg.has_eos() { 71 | return MidProcessResult::stop(); 72 | } 73 | 74 | // otherwise, compute bias according to our recognizer 75 | let mut set = self.toktrie.alloc_token_set(); 76 | self.toktrie.compute_bias(&mut self.recognizer, &mut set); 77 | MidProcessResult::sample(set) 78 | } 79 | } 80 | 81 | fn main() { 82 | // test code here? 83 | } 84 | 85 | aici_abi::aici_expose_all!(Runner, Runner::new()); 86 | -------------------------------------------------------------------------------- /docs/FAQ.md: -------------------------------------------------------------------------------- 1 | # Frequently Asked Questions 2 | 3 | ## How does system prompt or chat mode work with AICI? 4 | 5 | AICI interacts with models at the level of sequences of tokens. 6 | Models themselves do not have a distinct input for "system prompt" or "chat message", 7 | instead they are wrapped in model-specific tokens. 8 | You need to find the model's "Instruction format", typically on model's page on HuggingFace. 
9 | 10 | For example, the [Orca-2-13b model](https://huggingface.co/microsoft/Orca-2-13b) has the following instruction format: 11 | ``` 12 | <|im_start|>system 13 | {system_message}<|im_end|> 14 | <|im_start|>user 15 | {user_message}<|im_end|> 16 | <|im_start|>assistant 17 | ``` 18 | 19 | The [Mistral-Instruct](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) and [Mixtral-Instruct](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1), as well as [CodeLlama-Instruct](https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf) models use: 20 | ``` 21 | [INST]{instruction}[/INST] 22 | ``` 23 | 24 | Intrestingly, `<|im_start|>` and `<|im_end|>` are special tokens, while `[INST]` and `[/INST]` are regular strings. 25 | 26 | The start token (typically denoted ``) is always implicit in AICI! 27 | 28 | For example, for Orca model you can use the following: 29 | 30 | ```python 31 | import pyaici.server as aici 32 | 33 | system_message = "You are a helpful assistant." 34 | 35 | async def ask(user_message: str): 36 | prompt = f"<|im_start|>system\n{system_message}<|im_end|>\n" 37 | prompt += f"<|im_start|>user\n{user_message}<|im_end|>\n" 38 | prompt += "<|im_start|>assistant\n" 39 | await aici.FixedTokens(prompt) 40 | 41 | async def main(): 42 | await ask("What is the capital of Poland?") 43 | await aici.gen_tokens(max_tokens=10, store_var="capital") 44 | 45 | aici.start(main()) 46 | ``` -------------------------------------------------------------------------------- /docs/proxy.md: -------------------------------------------------------------------------------- 1 | # Client-side access to AICI 2 | 3 | The [Artificial Intelligence Controller Interface (AICI)](https://github.com/microsoft/aici) 4 | can be used to constrain output of an LLM in real time. 5 | While the GPU is working on the next token of the output, the AICI runtime can use the CPU to 6 | compute a user-provided constraint on the next token. 7 | This adds minimal latency to the LLM generation. 8 | 9 | ## Setup 10 | 11 | Install the `pyaici` package, export credentials, and see if the connection is working: 12 | 13 | ```bash 14 | pip uninstall pyaici 15 | pip install -e "git+https://github.com/microsoft/aici#egg=pyaici&subdirectory=py" 16 | export AICI_API_BASE="https://inference.example.com/v1/#key=wht_..." 17 | aici infer --max-tokens=10 "Answer to the Ultimate Question of Life, the Universe, and Everything is" 18 | ``` 19 | 20 | To test out the `pyctrl`, create `answer.py` file with: 21 | 22 | ```python 23 | import pyaici.server as aici 24 | 25 | async def main(): 26 | await aici.FixedTokens("The ultimate answer to the universe is ") 27 | await aici.gen_text(regex=r'\d\d', max_tokens=2) 28 | 29 | aici.start(main()) 30 | ``` 31 | 32 | You can run it with `aici run answer.py`. Try `aici run --help` for available options. 33 | 34 | You can use `aici --log-level=5 run answer.py` to see arguments to the REST requests, 35 | if you want to do them yourself. 36 | -------------------------------------------------------------------------------- /py/promptlib/README.md: -------------------------------------------------------------------------------- 1 | Promptlib is a sample client-side Python library that interacts with the DeclCtrl AIController, using the pyaici package for communication. 2 | See the notebooks folder for examples of using promptlib. 
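For readers who do not have the notebooks at hand, here is a rough sketch of the overall flow, using only the `AICI` wrapper defined in `promptlib/aici.py` and a DeclCtrl plan in the same JSON shape as `controllers/declctrl/arg2.json`. The endpoint URL and module id below are placeholders, and in real use the plan would normally be built with promptlib's `PromptProgram` and `gen` helpers rather than loaded from a file.

```python
# Hypothetical sketch: assumes a reachable AICI REST endpoint and an already
# uploaded DeclCtrl Wasm module; both values below are placeholders.
import json

from promptlib import AICI

aici = AICI(
    base_url="https://inference.example.com/v1/",  # placeholder endpoint
    wasm_runner_id="c65e78e9",                     # placeholder module id
)

# A declarative plan with Fixed/Gen steps, shaped like controllers/declctrl/arg2.json.
with open("controllers/declctrl/arg2.json") as f:
    plan = json.load(f)

# Forwards to pyaici.rest.run_controller() and returns its result.
result = aici.run(plan)
print(result)
```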
3 | -------------------------------------------------------------------------------- /py/promptlib/__init__.py: -------------------------------------------------------------------------------- 1 | from .promptlib import PromptNode, append, begin, end 2 | from .promptlib import PromptProgram 3 | from .promptlib import gen, choose, wait 4 | from .promptlib import AICI 5 | -------------------------------------------------------------------------------- /py/promptlib/promptlib/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .prompt import PromptNode, append, begin, end, PromptProgram 3 | from .gen import gen, choose, wait 4 | 5 | from .aici import AICI 6 | 7 | setattr(PromptNode, "append", append) 8 | setattr(PromptNode, "begin", begin) 9 | setattr(PromptNode, "end", end) 10 | setattr(PromptNode, "gen", gen) 11 | setattr(PromptNode, "choose", choose) 12 | setattr(PromptNode, "wait", wait) 13 | 14 | -------------------------------------------------------------------------------- /py/promptlib/promptlib/aici.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import requests 3 | import json 4 | import sys 5 | import os 6 | import re 7 | from typing import Optional 8 | 9 | import pyaici.rest as aici_rest 10 | 11 | class AICI: 12 | def __init__(self, base_url=None, wasm_runner_id=None, wasm_runner_path=None, wasm_runner_buildsh=None ): 13 | self.base_url = base_url 14 | 15 | if wasm_runner_id is None: 16 | if wasm_runner_path is None: 17 | if wasm_runner_buildsh is None: 18 | raise RuntimeError("Must specify wasm_runner_id or wasm_runner_path or wasm_runner_buildsh") 19 | wasm_runner_path = _compile_wasm(wasm_runner_buildsh) 20 | wasm_runner_id = _upload_wasm(self.base_url, wasm_runner_path) 21 | self.wasm_runner_id = wasm_runner_id 22 | 23 | def run(self, prompt_plan): 24 | return _submit_program(self.base_url, self.wasm_runner_id, prompt_plan, log=True) 25 | 26 | 27 | def _compile_wasm(wasm_runner_buildsh, scriptargs=["build"]): 28 | # separate wasm_runner_buildsh into the script filename and the directory 29 | # containing the script 30 | script_dir = os.path.dirname(wasm_runner_buildsh) 31 | script_name = os.path.basename(wasm_runner_buildsh) 32 | 33 | r = subprocess.run(["sh", script_name].extend(scriptargs), cwd=script_dir) 34 | if r.returncode != 0: 35 | raise RuntimeError(f"error compiling aici promptlib module") 36 | 37 | file_path = script_dir + "/target/strip.wasm" 38 | return file_path 39 | 40 | 41 | def _upload_wasm(base_url, wasm_runner_path): 42 | print("upload module... 
", end="") 43 | with open(wasm_runner_path, "rb") as f: 44 | resp = requests.post(base_url + "aici_modules", data=f) 45 | if resp.status_code == 200: 46 | d = resp.json() 47 | dd = d["data"] 48 | mod_id = dd["module_id"] 49 | print( 50 | f"{dd['wasm_size']//1024}kB -> {dd['compiled_size']//1024}kB id:{mod_id[0:8]}" 51 | ) 52 | return mod_id 53 | else: 54 | raise RuntimeError( 55 | f"bad response to model upload: {resp.status_code} {resp.reason}: {resp.text}" 56 | ) 57 | 58 | 59 | def _submit_program(base_url, aici_module, aici_arg, temperature=0, max_tokens=None, log=False): 60 | return aici_rest.run_controller(controller=aici_module, controller_arg=aici_arg, temperature=temperature, max_tokens=max_tokens, base_url=base_url) 61 | -------------------------------------------------------------------------------- /py/promptlib/promptlib/vars.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/py/promptlib/promptlib/vars.py -------------------------------------------------------------------------------- /py/pyaici/__init__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def runner_from_cli(args, dtype: str = 'f32'): 5 | from pyaici.comms import AiciRunner 6 | 7 | tokenizer = args.aici_tokenizer 8 | 9 | # when no explicit --aici-tokenizer, we look for: 10 | # --tokenizer + --tokenizer-revision 11 | # --model + --revision 12 | if not tokenizer: 13 | model_tokenizer = getattr(args, "tokenizer", None) 14 | if model_tokenizer: 15 | rev = getattr(args, "tokenizer_revision", None) 16 | if rev: 17 | model_tokenizer += f"@{rev}" 18 | tokenizer = model_tokenizer 19 | else: 20 | model = getattr(args, "model", None) 21 | if model: 22 | rev = getattr(args, "revision", None) 23 | if rev: 24 | model += f"@{rev}" 25 | tokenizer = model 26 | 27 | if not tokenizer: 28 | raise ValueError("No AICIrt tokenizer specified") 29 | if not args.aici_rt: 30 | raise ValueError("No AICIrt path specified") 31 | 32 | aici = AiciRunner( 33 | rtpath=args.aici_rt, 34 | tokenizer=tokenizer, 35 | trace_file=args.aici_trace, 36 | rtargs=args.aici_rtarg, 37 | pref=args.aici_shm_prefix, 38 | dtype=dtype, 39 | ) 40 | return aici 41 | 42 | 43 | def add_cli_args(parser: argparse.ArgumentParser, single=False): 44 | parser.add_argument( 45 | "--aici-rt", 46 | type=str, 47 | required=True, 48 | help="path to aicirt", 49 | ) 50 | parser.add_argument( 51 | "--aici-shm-prefix", 52 | type=str, 53 | default="/aici0-", 54 | help="prefix for shared memory communication channels", 55 | ) 56 | parser.add_argument( 57 | "--aici-tokenizer", 58 | type=str, 59 | default="", 60 | help= 61 | "tokenizer to use; llama, phi, ...; can also use HF tokenizer name", 62 | ) 63 | parser.add_argument( 64 | "--aici-trace", 65 | type=str, 66 | help="save trace of aicirt interaction to a JSONL file", 67 | ) 68 | parser.add_argument( 69 | "--aici-rtarg", 70 | "-A", 71 | type=str, 72 | default=[], 73 | action="append", 74 | help="pass argument to aicirt process", 75 | ) 76 | 77 | if single: 78 | parser.add_argument( 79 | "--controller", 80 | type=str, 81 | required=True, 82 | help="id of the module to run", 83 | ) 84 | parser.add_argument( 85 | "--controller-arg", 86 | type=str, 87 | default="", 88 | help="arg passed to module (filename)", 89 | ) 90 | 91 | return parser 92 | -------------------------------------------------------------------------------- 
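A rough sketch of how the two helpers above are intended to be combined: `add_cli_args` populates an `argparse` parser with the `--aici-*` flags, and the parsed namespace is handed back to `runner_from_cli`. The aicirt path, tokenizer name and controller tag below are placeholders rather than values shipped with this repository.

```python
# Minimal sketch, assuming aicirt has been built and a controller is available;
# everything marked "placeholder" is illustrative only.
import argparse

import pyaici

parser = argparse.ArgumentParser(description="run a single AICI controller")
# adds --aici-rt, --aici-tokenizer, --aici-rtarg, ... and, with single=True, --controller
pyaici.add_cli_args(parser, single=True)

args = parser.parse_args([
    "--aici-rt", "target/release/aicirt",  # placeholder path to the aicirt binary
    "--aici-tokenizer", "llama",           # tokenizer name understood by aicirt
    "--controller", "pyctrl-latest",       # placeholder controller id/tag
])

# Builds an AiciRunner (pyaici.comms) that talks to the aicirt process
# over the shared-memory channels configured by --aici-shm-prefix.
runner = pyaici.runner_from_cli(args)
```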
/py/pyaici/_vllm_protocol.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from typing import List, Optional, Union 3 | from vllm.sampling_params import SamplingParams 4 | from vllm.entrypoints.chat_utils import ChatCompletionMessageParam 5 | 6 | EPSILON_TEMP = 1.5e-5 7 | 8 | class RunRequest(BaseModel): 9 | model: Optional[str] = None 10 | prompt: Optional[str] = None 11 | messages: Optional[List[ChatCompletionMessageParam]] = None 12 | controller: str 13 | controller_arg: Union[str, dict] 14 | temperature: float = 0.0 15 | top_p: float = 1.0 16 | top_k: int = -1 17 | max_tokens: Optional[int] = None 18 | 19 | def to_sampling_params(self): 20 | r = SamplingParams( 21 | temperature=max(self.temperature, EPSILON_TEMP), 22 | top_p=self.top_p, 23 | top_k=self.top_k, 24 | max_tokens=self.max_tokens, 25 | ignore_eos=False, 26 | logprobs=10, 27 | ) 28 | r.has_aici = True # type: ignore 29 | return r 30 | 31 | 32 | class SetTagsRequest(BaseModel): 33 | module_id: str 34 | tags: List[str] -------------------------------------------------------------------------------- /py/pyaici/util.py: -------------------------------------------------------------------------------- 1 | llama_sys_prompt = """ 2 | You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. You are concise. 3 | """ 4 | 5 | 6 | def llama_prompt(prompt: str) -> str: 7 | return f"[INST] <>\n{llama_sys_prompt}\n<>\n\n [/INST]\n[INST] {prompt} [/INST]\n" 8 | 9 | 10 | def codellama_prompt(prompt: str) -> str: 11 | return f"[INST] {prompt} [/INST]\n" 12 | 13 | 14 | system_message = "You are a helpful assistant." 15 | orca_prefix = f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n" 16 | orca_suffix = "<|im_end|>\n<|im_start|>assistant\n" 17 | 18 | 19 | def orca_prompt(prompt: str) -> str: 20 | return orca_prefix + prompt + orca_suffix 21 | -------------------------------------------------------------------------------- /py/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='pyaici', 5 | version='0.1', 6 | packages=['pyaici'], 7 | entry_points={ 8 | 'console_scripts': [ 9 | 'aici = pyaici.cli:main' 10 | ] 11 | }, 12 | install_requires=[ 13 | 'requests' 14 | ], 15 | ) 16 | -------------------------------------------------------------------------------- /py/tests/conftest.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import subprocess 4 | 5 | prj_dir = os.path.dirname(os.path.abspath(__file__)) + "/../.." 
6 | sys.path.append(prj_dir + "/py") 7 | 8 | ast_module_path = prj_dir + "/tmp/ast_module.txt" 9 | 10 | 11 | def upload_wasm(): 12 | import pyaici.rest 13 | prog = prj_dir + "/controllers/declctrl" 14 | r = subprocess.run( 15 | ["cargo", "build", "--release"], 16 | cwd=prog, 17 | stdout=subprocess.DEVNULL, 18 | stderr=subprocess.DEVNULL, 19 | ) 20 | if r.returncode != 0: 21 | sys.exit(1) 22 | file_path = prj_dir + "/target/wasm32-wasi/release/aici_declctrl.wasm" 23 | pyaici.rest.log_level = 0 24 | ast_module = pyaici.rest.upload_module(file_path) 25 | 26 | os.makedirs(prj_dir + "/tmp", exist_ok=True) 27 | with open(ast_module_path, "w") as f: 28 | f.write(ast_module) 29 | 30 | 31 | def pytest_configure(config): 32 | import pyaici.rest 33 | 34 | if not hasattr(config, "workerinput"): 35 | upload_wasm() 36 | with open(ast_module_path, "r") as f: 37 | pyaici.rest.ast_module = f.read() 38 | -------------------------------------------------------------------------------- /py/tests/test-prompt.txt: -------------------------------------------------------------------------------- 1 | ```python 2 | def print_prime(n): 3 | """ 4 | Print all primes between 1 and n 5 | """ -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | ;[pytest] 2 | ;testpaths = py/tests 3 | ;addopts = -n 1 4 | -------------------------------------------------------------------------------- /rllm/llama-cpp-low/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "llama_cpp_low" 3 | version = "0.0.1" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | anyhow = "1.0.79" 8 | link-cplusplus = "1.0.9" 9 | log = "0.4.20" 10 | num_cpus = "1.16.0" 11 | 12 | [build-dependencies] 13 | bindgen = "0.69.2" 14 | cmake = "0.1.50" 15 | 16 | [features] 17 | default = [] 18 | cuda = [] 19 | -------------------------------------------------------------------------------- /rllm/llama-cpp-low/src/main.rs: -------------------------------------------------------------------------------- 1 | fn main() {} 2 | 3 | #[cfg(feature = "disabled")] 4 | fn main() { 5 | use llama_cpp_low::*; 6 | 7 | let mparams = ModelParams::default(); 8 | let mut cparams = ContextParams::default(); 9 | cparams.n_ctx = 2048; 10 | let model = Model::from_file("tmp/llama-2-7b-chat.Q5_K_M.gguf", mparams, cparams); 11 | let mut batch = Batch::new(512); 12 | let s = model.new_sequence(); 13 | for (idx, tok) in model 14 | .tokenize("Hello, my name is".as_bytes(), true, true) 15 | .iter() 16 | .enumerate() 17 | { 18 | batch.add_token(*tok, idx, &s, false) 19 | } 20 | 21 | let mut logit_idx = batch.len() - 1; 22 | let mut pos = batch.len(); 23 | 24 | batch.enable_logits(logit_idx); 25 | 26 | for _ in 0..10 { 27 | model.decode(&mut batch).unwrap(); 28 | let logits = model.get_logits(logit_idx); 29 | let top_idx = logits 30 | .iter() 31 | .enumerate() 32 | .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) 33 | .map(|(index, _)| index) 34 | .unwrap(); 35 | println!( 36 | "top_idx: {:?} {:?}", 37 | top_idx, 38 | String::from_utf8_lossy(&model.token_to_bytes(top_idx as u32)) 39 | ); 40 | 41 | logit_idx = 0; 42 | batch.clear(); 43 | batch.add_token(top_idx as u32, pos, &s, true); 44 | pos += 1; 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /rllm/rllm-base/Cargo.toml: -------------------------------------------------------------------------------- 1 | 
[package] 2 | name = "rllm" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | anyhow = "1.0.75" 8 | clap = "4.4.8" 9 | hf-hub = "0.3.2" 10 | tokenizers = { version = "0.15.0", features = ["hf-hub"] } 11 | serde_json = "1.0.108" 12 | serde = { version = "1.0.193", features = ["derive"] } 13 | rand = "0.8.5" 14 | half = "2.3.1" 15 | log = "0.4.20" 16 | actix-web = "4.4.0" 17 | tokio = { version = "1.34.0", features = ["sync"] } 18 | futures = "0.3.29" 19 | uuid = { version = "1.6.1", features = ["v4"] } 20 | 21 | aicirt = { path = "../../aicirt" } 22 | aici_abi = { path = "../../controllers/aici_abi" } 23 | libc = "0.2.150" 24 | base64 = "0.21.5" 25 | memmap2 = "0.9.0" 26 | safetensors = "0.4.1" 27 | lazy_static = "1.4.0" 28 | percent-encoding = "2.3.1" 29 | -------------------------------------------------------------------------------- /rllm/rllm-base/src/exec.rs: -------------------------------------------------------------------------------- 1 | use std::{fmt::Display, sync::Arc}; 2 | 3 | use aicirt::TimerRef; 4 | use anyhow::Result; 5 | 6 | use crate::{ 7 | config::{ModelMeta, RllmConfig}, 8 | scheduler::SchedulerOutputs, 9 | seq::{Sequence, SequenceGroup}, 10 | HashMap, LoaderArgs, LogitsProcessor, RllmEngine, 11 | }; 12 | 13 | #[derive(Debug, Clone, Copy)] 14 | pub enum BlockLocation { 15 | GPU, 16 | CPU, 17 | } 18 | 19 | pub trait AiciBias { 20 | fn apply(&self, logits: &mut T, seq_id: usize); 21 | } 22 | 23 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 24 | pub struct SeqId(pub usize); 25 | 26 | impl SeqId { 27 | pub fn to_num(&self) -> usize { 28 | self.0 29 | } 30 | } 31 | 32 | impl Display for SeqId { 33 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { 34 | write!(f, "{}", self.0) 35 | } 36 | } 37 | 38 | pub trait SequenceManager { 39 | fn new_sequence(&self) -> SeqId; 40 | fn copy(&self, src: SeqId, dst: SeqId, length: usize); 41 | fn trim(&self, seq: SeqId, length: usize); 42 | fn delete(&self, seq: SeqId); 43 | } 44 | 45 | pub trait ModelExec: Sized { 46 | type Tensor; 47 | type BlockSpaceManager: TBlockSpaceManager; 48 | type AiciBias: AiciBias; 49 | type ModelConfig; 50 | type ModelLoaderArgs: Send + 'static; 51 | type SequenceManager: SequenceManager; 52 | 53 | fn tensor_to_vec1(tensor: &Self::Tensor) -> Vec; 54 | 55 | fn load_model_config( 56 | args: &LoaderArgs, 57 | model_args: &mut Self::ModelLoaderArgs, 58 | ) -> Result<(ModelMeta, Self::ModelConfig)>; 59 | fn verify_args(args: &RllmConfig) -> Result<()>; 60 | fn load_rllm_engine( 61 | args: LoaderArgs, 62 | model_args: Self::ModelLoaderArgs, 63 | ) -> Result>; 64 | 65 | fn sequence_manager(&self) -> Arc; 66 | 67 | fn run( 68 | &mut self, 69 | _vocab_size: usize, 70 | tim: &TimerRef, 71 | step_no: usize, 72 | sched_out: &mut SchedulerOutputs, 73 | ) -> Result<()>; 74 | fn get_logits(&self, seq_id: usize) -> Self::Tensor; 75 | fn finalize_run(&mut self) -> Result<()>; 76 | 77 | fn empty_bias(&self, vocab_size: usize) -> Self::AiciBias; 78 | fn new_bias(&self, slice: &'static [f32], num_seqs: usize, vocab_size: usize) 79 | -> Self::AiciBias; 80 | 81 | fn sample(&self, processor: &mut LogitsProcessor, logits: &Self::Tensor) -> Result; 82 | } 83 | 84 | pub trait TBlockSpaceManager { 85 | fn can_allocate(&self, _seq_group: &SequenceGroup) -> bool; 86 | fn allocate(&mut self, seq_group: &mut SequenceGroup); 87 | 88 | fn can_append_slot(&self, _seq_group: &SequenceGroup) -> bool; 89 | fn append_slots(&mut self, _seq: &mut Sequence, _outputs: &mut SchedulerOutputs); 90 | fn 
get_num_free_gpu_blocks(&self) -> usize; 91 | fn get_num_free_cpu_blocks(&self) -> usize; 92 | 93 | fn can_swap_in(&self, _seq_group: &SequenceGroup) -> bool { 94 | false 95 | } 96 | 97 | fn swap_in(&mut self, _seq_group: &mut SequenceGroup) -> HashMap { 98 | panic!("no swap_in") 99 | } 100 | 101 | fn swap_out(&mut self, _seq_group: &mut SequenceGroup) -> HashMap { 102 | panic!("no swap_out") 103 | } 104 | 105 | fn can_swap_out(&self, _seq_group: &SequenceGroup) -> bool { 106 | false 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /rllm/rllm-base/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod seq; 2 | 3 | // vllm modules 4 | pub mod config; 5 | mod engine; 6 | mod exec; 7 | mod expected; 8 | pub mod iface; 9 | mod logits; 10 | mod scheduler; 11 | pub mod server; 12 | pub mod util; 13 | 14 | use config::AiciConfig; 15 | pub use engine::*; 16 | pub use exec::*; 17 | pub use logits::LogitsProcessor; 18 | pub use scheduler::*; 19 | use std::sync::atomic::AtomicBool; 20 | 21 | pub use aicirt::HashMap; 22 | pub use aicirt::HashSet; 23 | 24 | pub struct LoaderArgs { 25 | pub tokenizer: String, // one of aici_tokenizer; eg "llama" 26 | pub model_id: String, 27 | pub revision: Option, 28 | pub file: Option, 29 | pub local_weights: Option, 30 | pub alt: usize, 31 | pub aici: AiciConfig, 32 | } 33 | 34 | impl Default for LoaderArgs { 35 | fn default() -> Self { 36 | Self { 37 | tokenizer: "llama".to_string(), 38 | model_id: "NousResearch/Llama-2-7b-hf".to_string(), 39 | revision: None, 40 | local_weights: None, 41 | file: None, 42 | aici: AiciConfig::default(), 43 | alt: 0, 44 | } 45 | } 46 | } 47 | 48 | static mut TRACE: AtomicBool = AtomicBool::new(false); 49 | 50 | pub fn set_trace(trace_enabled: bool) { 51 | unsafe { 52 | TRACE.store(trace_enabled, std::sync::atomic::Ordering::Relaxed); 53 | } 54 | } 55 | 56 | pub fn get_trace() -> bool { 57 | unsafe { TRACE.load(std::sync::atomic::Ordering::Relaxed) } 58 | } 59 | 60 | #[macro_export] 61 | macro_rules! 
rtrace { 62 | ($($arg:tt)*) => {{ 63 | if $crate::get_trace() { 64 | println!($($arg)*); 65 | } 66 | }}; 67 | } 68 | -------------------------------------------------------------------------------- /rllm/rllm-base/src/logits.rs: -------------------------------------------------------------------------------- 1 | // based on https://github.com/huggingface/candle/blob/main/candle-transformers/src/generation/mod.rs 2 | 3 | use crate::config::{SamplingParams, SAMPLING_EPS}; 4 | use rand::SeedableRng; 5 | 6 | pub struct LogitsProcessor { 7 | pub rng: rand::rngs::StdRng, 8 | pub temperature: Option, 9 | pub top_p: f32, 10 | } 11 | 12 | impl LogitsProcessor { 13 | pub fn new(sampling_params: &SamplingParams) -> Self { 14 | let temperature = if sampling_params.temperature < SAMPLING_EPS { 15 | None 16 | } else { 17 | Some(sampling_params.temperature) 18 | }; 19 | 20 | Self { 21 | rng: rand::rngs::StdRng::from_entropy(), 22 | // seed_from_u64(42), 23 | temperature, 24 | top_p: sampling_params.top_p, 25 | } 26 | } 27 | 28 | pub fn set_temperature(&mut self, temperature: f32) { 29 | if temperature < SAMPLING_EPS { 30 | self.temperature = None; 31 | } else { 32 | self.temperature = Some(temperature); 33 | } 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /rllm/rllm-base/src/server/api.rs: -------------------------------------------------------------------------------- 1 | use aici_abi::StorageCmd; 2 | use serde::{Deserialize, Serialize}; 3 | 4 | #[derive(Debug, Clone, Serialize, Deserialize)] 5 | pub struct RunRequest { 6 | pub controller: String, 7 | pub controller_arg: serde_json::Value, 8 | #[serde(default)] 9 | pub prompt: String, 10 | pub temperature: Option, // defl 0.0 11 | pub top_p: Option, // defl 1.0 12 | pub top_k: Option, // defl -1 13 | pub max_tokens: Option, // defl context size 14 | } 15 | 16 | #[derive(Debug, Clone, Serialize, Deserialize)] 17 | pub struct RunUsageResponse { 18 | pub sampled_tokens: usize, 19 | pub ff_tokens: usize, 20 | pub cost: usize, 21 | } 22 | 23 | #[derive(Debug, Clone, Serialize, Deserialize)] 24 | pub struct InitialRunResponse { 25 | pub id: String, 26 | pub object: &'static str, // "initial-run" 27 | pub created: u64, 28 | pub model: String, 29 | } 30 | 31 | #[derive(Debug, Clone, Serialize, Deserialize)] 32 | pub struct RunResponse { 33 | pub object: &'static str, // "run" 34 | pub forks: Vec, 35 | pub usage: RunUsageResponse, 36 | } 37 | 38 | #[derive(Debug, Clone, Serialize, Deserialize)] 39 | pub struct RunForkResponse { 40 | pub index: usize, 41 | #[serde(skip_serializing_if = "Option::is_none")] 42 | pub finish_reason: Option, 43 | pub text: String, 44 | pub error: String, 45 | pub logs: String, 46 | pub storage: Vec, 47 | pub micros: u64, 48 | } 49 | -------------------------------------------------------------------------------- /rllm/rllm-base/src/server/openai/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Eric Buehler 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and 
this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /rllm/rllm-base/src/server/openai/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod requests; 2 | pub mod responses; 3 | -------------------------------------------------------------------------------- /rllm/rllm-base/src/server/openai/requests.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | use crate::HashMap; 3 | 4 | #[derive(Debug, Clone, Serialize, Deserialize)] 5 | #[serde(untagged)] 6 | pub enum Messages { 7 | Map(Vec>), 8 | Literal(String), 9 | } 10 | 11 | #[derive(Debug, Clone, Serialize, Deserialize)] 12 | pub enum StopTokens { 13 | Multi(Vec), 14 | Single(String), 15 | } 16 | 17 | #[derive(Debug, Clone, Serialize, Deserialize)] 18 | pub struct ChatCompletionRequest { 19 | pub model: String, 20 | pub messages: Messages, 21 | #[serde(default)] 22 | pub temperature: Option, //0.7 23 | #[serde(default)] 24 | pub top_p: Option, //1.0 25 | #[serde(default)] 26 | pub n: Option, //1 27 | #[serde(default)] 28 | pub max_tokens: Option, //None 29 | #[serde(default)] 30 | pub stop: Option, 31 | #[serde(default)] 32 | pub stream: Option, //false 33 | #[serde(default)] 34 | pub presence_penalty: Option, //0.0 35 | #[serde(default)] 36 | pub frequency_penalty: Option, //0.0 37 | #[serde(default)] 38 | pub logit_bias: Option>, //None 39 | #[serde(default)] 40 | pub user: Option, //None 41 | #[serde(default)] 42 | //Additional candle-vllm params 43 | pub top_k: Option, //-1 44 | #[serde(default)] 45 | pub best_of: Option, //None 46 | #[serde(default)] 47 | pub use_beam_search: Option, //false 48 | #[serde(default)] 49 | pub ignore_eos: Option, //false 50 | #[serde(default)] 51 | pub skip_special_tokens: Option, //false 52 | #[serde(default)] 53 | pub stop_token_ids: Option>, //[] 54 | } 55 | 56 | #[derive(Debug, Clone, Serialize, Deserialize)] 57 | pub struct CompletionRequest { 58 | pub model: String, 59 | pub prompt: String, 60 | 61 | #[serde(default)] 62 | pub temperature: Option, //0.7 63 | #[serde(default)] 64 | pub top_p: Option, //1.0 65 | #[serde(default)] 66 | pub n: Option, //1 67 | #[serde(default)] 68 | pub max_tokens: Option, //None 69 | #[serde(default)] 70 | pub stop: Option>, 71 | #[serde(default)] 72 | pub stream: Option, //false 73 | #[serde(default)] 74 | pub presence_penalty: Option, //0.0 75 | #[serde(default)] 76 | pub frequency_penalty: Option, //0.0 77 | #[serde(default)] 78 | pub logit_bias: Option>, //None 79 | #[serde(default)] 80 | pub user: Option, //None 81 | #[serde(default)] 82 | pub top_k: Option, //-1 83 | #[serde(default)] 84 | pub best_of: Option, //None 85 | #[serde(default)] 86 | pub use_beam_search: Option, //false 87 | #[serde(default)] 88 | pub ignore_eos: Option, //false 89 | #[serde(default)] 90 | pub skip_special_tokens: Option, //false 91 | 
#[serde(default)] 92 | pub stop_token_ids: Option>, //[] 93 | } 94 | -------------------------------------------------------------------------------- /rllm/rllm-base/src/server/openai/responses.rs: -------------------------------------------------------------------------------- 1 | use aici_abi::StorageCmd; 2 | use serde::{Deserialize, Serialize}; 3 | use std::fmt::Debug; 4 | 5 | #[derive(Debug, Clone, Serialize, Deserialize)] 6 | pub struct ChatCompletionUsageResponse { 7 | pub completion_tokens: usize, 8 | pub prompt_tokens: usize, 9 | pub total_tokens: usize, 10 | pub fuel_tokens: usize, 11 | } 12 | 13 | // tool_calls, function_call not supported! 14 | #[derive(Debug, Clone, Serialize, Deserialize)] 15 | pub struct ChatChoiceData { 16 | pub content: Option, 17 | pub role: String, 18 | } 19 | 20 | #[derive(Debug, Clone, Serialize, Deserialize)] 21 | pub struct ChatChoice { 22 | pub message: ChatChoiceData, 23 | pub finish_reason: Option, 24 | pub index: usize, 25 | } 26 | 27 | #[derive(Debug, Clone, Serialize, Deserialize)] 28 | pub struct ChatCompletionResponse { 29 | pub id: String, 30 | pub choices: Vec, 31 | pub created: u64, 32 | pub model: String, 33 | pub object: &'static str, 34 | pub usage: ChatCompletionUsageResponse, 35 | } 36 | 37 | #[derive(Debug, Clone, Serialize, Deserialize)] 38 | pub struct CompletionChoice { 39 | pub text: String, 40 | pub finish_reason: Option, 41 | pub index: usize, 42 | } 43 | 44 | #[derive(Debug, Clone, Serialize, Deserialize)] 45 | pub struct CompletionResponse { 46 | pub id: String, 47 | pub choices: Vec, 48 | pub created: u64, 49 | pub model: String, 50 | pub object: &'static str, // "text_completion" 51 | pub usage: ChatCompletionUsageResponse, 52 | } 53 | 54 | // tool_calls, function_call not supported! 
55 | #[derive(Debug, Clone, Serialize, Deserialize)] 56 | pub struct StreamingChoiceData { 57 | pub content: Option, 58 | pub role: String, 59 | } 60 | 61 | #[derive(Debug, Clone, Serialize, Deserialize)] 62 | pub struct StreamingChatChoice { 63 | pub delta: StreamingChoiceData, 64 | pub finish_reason: Option, 65 | pub index: usize, 66 | } 67 | 68 | #[derive(Debug, Clone, Serialize, Deserialize)] 69 | pub struct StreamingChatCompletionResponse { 70 | pub id: String, 71 | pub choices: Vec, 72 | pub created: u64, 73 | pub model: String, 74 | pub object: &'static str, 75 | } 76 | 77 | #[derive(Debug, Clone, Serialize, Deserialize)] 78 | pub struct List { 79 | pub object: &'static str, // "list" 80 | pub data: Vec, 81 | } 82 | 83 | impl List { 84 | pub fn new(data: Vec) -> Self { 85 | Self { 86 | object: "list", 87 | data, 88 | } 89 | } 90 | } 91 | 92 | #[derive(Debug, Clone, Serialize, Deserialize)] 93 | pub struct Model { 94 | pub object: &'static str, // "model" 95 | pub id: String, 96 | pub created: u64, 97 | pub owned_by: String, 98 | } 99 | 100 | #[derive(Debug, Clone, Serialize, Deserialize)] 101 | pub struct StreamingCompletionChoice { 102 | pub index: usize, 103 | pub finish_reason: Option, 104 | pub text: String, 105 | 106 | pub error: String, 107 | pub logs: String, 108 | pub storage: Vec, 109 | // pub logprobs: Option, 110 | } 111 | 112 | #[derive(Debug, Clone, Serialize, Deserialize)] 113 | pub struct StreamingCompletionResponse { 114 | pub object: &'static str, // "text_completion" 115 | pub id: String, 116 | pub model: String, 117 | pub created: u64, 118 | pub choices: Vec, 119 | pub usage: ChatCompletionUsageResponse, 120 | } 121 | -------------------------------------------------------------------------------- /rllm/rllm-cuda/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rllm-cuda" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | anyhow = "1.0.75" 8 | clap = "4.4.8" 9 | serde_json = "1.0.108" 10 | serde = { version = "1.0.193", features = ["derive"] } 11 | rand = "0.8.5" 12 | log = "0.4.20" 13 | actix-web = "4.4.0" 14 | tch = { version = "0.14.0" } 15 | 16 | cudarc = { version = "0.10.0", features = ["f16"], optional = true } 17 | tch-cuda = { path = "../tch-cuda", optional = true } 18 | 19 | rllm = { path = "../rllm-base" } 20 | aicirt = { path = "../../aicirt" } 21 | indicatif = "0.17.7" 22 | memmap2 = "0.9.0" 23 | safetensors = "0.4.1" 24 | 25 | [[bin]] 26 | name = "rllm-cuda" 27 | path = "src/rllm-cuda.rs" 28 | 29 | [features] 30 | default = ["cuda"] 31 | cuda = ["dep:tch-cuda", "dep:cudarc"] 32 | -------------------------------------------------------------------------------- /rllm/rllm-cuda/README.md: -------------------------------------------------------------------------------- 1 | # rLLM 2 | 3 | This is a partial port of [vLLM](https://github.com/vllm-project/vllm) 4 | to Rust and [tch-rs](https://github.com/LaurentMazare/tch-rs) 5 | (bindings for [libtorch](https://github.com/pytorch/pytorch/blob/main/docs/libtorch.rst) 6 | which is basis of [PyTorch](https://github.com/pytorch/pytorch)). 7 | It is mostly meant as a proving ground for AICI (AI Controller Interface) integration. 
8 | 9 | 10 | ## Building 11 | 12 | If you're not using the supplied docker container make sure to check 13 | that the following environment variables are set: 14 | 15 | ```bash 16 | export CUDA_COMPUTE_CAP="80" 17 | export LIBTORCH_USE_PYTORCH="1" 18 | ``` 19 | 20 | You can run the server with `./server.sh` script; have a look inside to figure out 21 | how to run with different options. 22 | 23 | ## Tests 24 | 25 | The `expected/` directory contains sample prompts along with expected model output - 26 | top 128 logits for the first few tokens of output. 27 | Running `./expected/tests.sh` will run rLLM on these testcases and make sure it gets the 28 | same logits with some tolerance. 29 | 30 | You can inspect test cases like so: 31 | 32 | ``` 33 | $ python scripts/testgen.py show expected/phi-1_5/lighthouse.safetensors 34 | Prompt: 'Write a detailed analogy between mathematics and a lighthouse.\n\nAnswer:' 35 | Output: ' In mathematics, logic is like a beacon of the lighthouse. It guides us' 36 | logits: torch.Size([15, 128]) min: tensor(12.7188) avg: tensor(17.6671) max: tensor(36.0938) 37 | prob_mass: torch.Size([15]) min: tensor(0.9795) avg: tensor(0.9944) max: tensor(0.9999) 38 | $ 39 | ``` 40 | 41 | `prob_mass` refers to the sum of probiblites of the top 128 logits after softmax 42 | (for every token of output). It should be very close to 1. 43 | 44 | ## Models 45 | 46 | The following models have been tested: 47 | 48 | * [CodeLlama-13b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf) 49 | * [CodeLlama-34b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf) 50 | - this barely fits in the 80GB A100, not much space for KV-cache 51 | * [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) 52 | * [Orca-2-13b](https://huggingface.co/microsoft/Orca-2-13b) 53 | * [phi-1_5](https://huggingface.co/microsoft/phi-1_5) 54 | * [phi-2](https://huggingface.co/microsoft/phi-2) 55 | 56 | In general all Llama models should work. 57 | 58 | ## Acknowledgements 59 | 60 | See [top-level README.md](../../README.md#acknowledgements). 
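As a clarifying aside, not part of the test harness: the `prob_mass` statistic described in the Tests section can be recomputed from a full logits matrix as the sum of the top-128 softmax probabilities for each output position. The sketch below is illustrative only; the function name and tensor shapes are assumptions, not code from `scripts/testgen.py`.

```python
# Sketch only: recompute the prob_mass statistic described above.
# `full_logits` is assumed to be [num_tokens, vocab_size] raw logits, one row per output token.
import torch

def prob_mass(full_logits: torch.Tensor, k: int = 128) -> torch.Tensor:
    probs = torch.softmax(full_logits.float(), dim=-1)   # per-position distribution over the vocab
    top_probs = probs.topk(k, dim=-1).values             # the k largest probabilities per position
    return top_probs.sum(dim=-1)                         # close to 1.0 when the model is confident

# e.g. prob_mass(torch.randn(15, 32000)) returns a tensor of shape [15]
```

Since the `expected/*.safetensors` files only store the top-128 logits, this sum is presumably computed at test-generation time against the full distribution and saved alongside them.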
61 | -------------------------------------------------------------------------------- /rllm/rllm-cuda/expected/codellama/args.txt: -------------------------------------------------------------------------------- 1 | --model codellama/CodeLlama-13b-Instruct-hf --tokenizer llama16 2 | -s test_maxtol=0.20 -s test_avgtol=0.10 3 | -------------------------------------------------------------------------------- /rllm/rllm-cuda/expected/codellama/cats.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/codellama/cats.safetensors -------------------------------------------------------------------------------- /rllm/rllm-cuda/expected/codellama/lighthouse.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/codellama/lighthouse.safetensors -------------------------------------------------------------------------------- /rllm/rllm-cuda/expected/codellama/primes.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/codellama/primes.safetensors -------------------------------------------------------------------------------- /rllm/rllm-cuda/expected/codellama34/args.txt: -------------------------------------------------------------------------------- 1 | --model codellama/CodeLlama-34b-Instruct-hf --tokenizer llama 2 | -s test_maxtol=8 -s test_avgtol=0.5 3 | -------------------------------------------------------------------------------- /rllm/rllm-cuda/expected/codellama34/cats.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/codellama34/cats.safetensors -------------------------------------------------------------------------------- /rllm/rllm-cuda/expected/codellama34/lighthouse.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/codellama34/lighthouse.safetensors -------------------------------------------------------------------------------- /rllm/rllm-cuda/expected/codellama34/primes.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/codellama34/primes.safetensors -------------------------------------------------------------------------------- /rllm/rllm-cuda/expected/go.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e 4 | 5 | BIN=$(cd ../target; pwd) 6 | 7 | # the PORT is in fact unused 8 | COMMON_ARGS="--verbose --aicirt $BIN/release/aicirt" 9 | 10 | (cd ../aicirt; cargo build --release) 11 | 12 | RLLM_LOG=debug 13 | 14 | FILES= 15 | for f in "$@" ; do 16 | if [ -f "$f" ] ; then 17 | FILES="$FILES $f" 18 | elif [ -f "$f/args.txt" ] ; then 19 | FILES="$FILES $f/args.txt" 20 | else 21 | echo "File $f not found" 22 | exit 1 23 | fi 24 | done 25 | 26 | for A in $FILES ; do 27 | echo 28 | echo 29 | echo 30 | echo 
"*** $A ***" 31 | echo 32 | ARGS="$COMMON_ARGS `cat $A`" 33 | for S in $(dirname $A)/*.safetensors ; do 34 | ARGS="$ARGS --test $S" 35 | done 36 | RUST_BACKTRACE=1 \ 37 | RUST_LOG=info,rllm=$RLLM_LOG,aicirt=info \ 38 | cargo run $REL -- $ARGS 39 | done 40 | 41 | echo "All OK!" 42 | -------------------------------------------------------------------------------- /rllm/rllm-cuda/expected/llama/args.txt: -------------------------------------------------------------------------------- 1 | --model NousResearch/Llama-2-7b-hf --tokenizer llama 2 | -s test_maxtol=0.30 -s test_avgtol=0.10 3 | -------------------------------------------------------------------------------- /rllm/rllm-cuda/expected/llama/cats.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/llama/cats.safetensors -------------------------------------------------------------------------------- /rllm/rllm-cuda/expected/llama/lighthouse.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/llama/lighthouse.safetensors -------------------------------------------------------------------------------- /rllm/rllm-cuda/expected/llama/primes.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/llama/primes.safetensors -------------------------------------------------------------------------------- /rllm/rllm-cuda/expected/orca/args.txt: -------------------------------------------------------------------------------- 1 | --model microsoft/Orca-2-13b@refs/pr/22 --tokenizer orca 2 | -s test_maxtol=0.50 -s test_avgtol=0.10 3 | -------------------------------------------------------------------------------- /rllm/rllm-cuda/expected/orca/cats.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/orca/cats.safetensors -------------------------------------------------------------------------------- /rllm/rllm-cuda/expected/orca/lighthouse.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/orca/lighthouse.safetensors -------------------------------------------------------------------------------- /rllm/rllm-cuda/expected/orca/primes.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/orca/primes.safetensors -------------------------------------------------------------------------------- /rllm/rllm-cuda/expected/phi-1_5/args.txt: -------------------------------------------------------------------------------- 1 | --model microsoft/phi-1_5@refs/pr/66 --tokenizer phi 2 | -s test_maxtol=0.05 -s test_avgtol=0.01 -------------------------------------------------------------------------------- /rllm/rllm-cuda/expected/phi-1_5/cats.safetensors: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/phi-1_5/cats.safetensors -------------------------------------------------------------------------------- /rllm/rllm-cuda/expected/phi-1_5/lighthouse.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/phi-1_5/lighthouse.safetensors -------------------------------------------------------------------------------- /rllm/rllm-cuda/expected/phi-1_5/primes.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/phi-1_5/primes.safetensors -------------------------------------------------------------------------------- /rllm/rllm-cuda/expected/phi-2/args.txt: -------------------------------------------------------------------------------- 1 | --model microsoft/phi-2 --tokenizer phi -s test_maxtol=2.0 -s test_avgtol=0.15 -------------------------------------------------------------------------------- /rllm/rllm-cuda/expected/phi-2/cats.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/phi-2/cats.safetensors -------------------------------------------------------------------------------- /rllm/rllm-cuda/expected/phi-2/lighthouse.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/phi-2/lighthouse.safetensors -------------------------------------------------------------------------------- /rllm/rllm-cuda/expected/phi-2/primes.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/phi-2/primes.safetensors -------------------------------------------------------------------------------- /rllm/rllm-cuda/scripts/cmp2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -x 4 | set -e 5 | cargo run -- --sample-len 1 6 | mv step-1.safetensor single.safetensor 7 | cargo run -- --sample-len 1 --alt=7 8 | mv step-1.safetensor multi.safetensor 9 | python3 tensorcmp.py single.safetensor multi.safetensor 10 | -------------------------------------------------------------------------------- /rllm/rllm-cuda/scripts/tensorcmp.py: -------------------------------------------------------------------------------- 1 | import safetensors 2 | import sys 3 | import torch 4 | 5 | args = sys.argv[1:] 6 | 7 | 8 | def check_all_close(tensor1: torch.Tensor, tensor2: torch.Tensor, rtol=1e-05, atol=1e-10): 9 | # assert torch.allclose(tensor1, tensor2, rtol=rtol, atol=atol) 10 | if torch.equal(tensor1, tensor2): 11 | return 12 | d = tensor1.sub(tensor2).abs() 13 | print(d) 14 | print(d.max()) 15 | assert False 16 | 17 | 18 | def cmp(a: torch.Tensor, b: torch.Tensor): 19 | # if b.shape[0] == 27: 20 | # b = b[4:16, :, :] 21 | # elif b.shape[1] == 27: 22 | # b = b[:, 4:16, :] 23 | 24 | if a.shape == b.shape: 25 | print(a.shape) 26 | check_all_close(a, b) 27 | else: 28 | print("Size wrong", a.shape, b.shape) 
29 | assert False 30 | 31 | def load_all(fn: str) -> dict[str, torch.Tensor]: 32 | r = {} 33 | with safetensors.safe_open(fn, framework="pt", device="cuda") as a: 34 | keys = a.keys() 35 | keys.sort() 36 | for key in keys: 37 | kk = key.split("_")[1] 38 | if kk not in r: 39 | r[kk] = a.get_tensor(key) 40 | return r 41 | 42 | def main(): 43 | a = load_all(args[0]) 44 | b = load_all(args[1]) 45 | 46 | x1 = a["x1"] 47 | w1 = a["w1"].t().unsqueeze(0) 48 | 49 | m1 = a["m1"] 50 | print("X", x1.shape, w1.shape) 51 | m1p = x1.matmul(w1) 52 | #cmp(m1, m1p) 53 | 54 | bx1 = b["x1"] 55 | bw1 = b["w1"].t().unsqueeze(0) 56 | cmp(w1, bw1) 57 | cmp(bx1[:,4:16,:], x1) 58 | bm1 = bx1.matmul(bw1) 59 | cmp(bm1[:,4:16,:], m1p) 60 | 61 | 62 | 63 | 64 | # for key in a.keys(): 65 | # print(key) 66 | # cmp(a[key], b[key]) 67 | 68 | main() -------------------------------------------------------------------------------- /rllm/rllm-cuda/src/llm/kernels.rs: -------------------------------------------------------------------------------- 1 | #[cfg(not(feature = "cuda"))] 2 | pub use super::refkernels::*; 3 | use tch::{Device, Tensor}; 4 | #[cfg(feature = "cuda")] 5 | pub use tch_cuda::flash_attn_varlen as varlen_attn; 6 | #[cfg(feature = "cuda")] 7 | pub use tch_cuda::*; 8 | 9 | /// Convert a vector of lengths into a tensor of offsets, as expected by flash attn. 10 | pub fn to_offsets(seqlens: impl Iterator, device: Device) -> (usize, Tensor) { 11 | let mut offsets = Vec::new(); 12 | let mut offset = 0; 13 | let mut max = 0; 14 | for len in seqlens { 15 | max = std::cmp::max(len, max); 16 | offsets.push(offset as i32); 17 | offset += len; 18 | } 19 | offsets.push(offset as i32); 20 | (max, Tensor::from_slice(offsets.as_slice()).to(device)) 21 | } 22 | -------------------------------------------------------------------------------- /rllm/rllm-cuda/src/llm/paged/cuda_stub.rs: -------------------------------------------------------------------------------- 1 | use tch::Device; 2 | 3 | pub struct CudaEvent {} 4 | 5 | impl CudaEvent { 6 | pub fn new() -> Self { 7 | CudaEvent {} 8 | } 9 | 10 | pub fn wait(&self, _stream: &CudaStream) {} 11 | } 12 | 13 | pub struct CudaStream {} 14 | 15 | impl CudaStream { 16 | pub fn new(_device: Device) -> Self { 17 | CudaStream {} 18 | } 19 | 20 | pub fn current(_device: Device) -> Self { 21 | CudaStream {} 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /rllm/rllm-cuda/src/llm/paged/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(not(feature = "cuda"))] 2 | mod cuda_stub; 3 | 4 | mod batch_info; 5 | mod blocks; 6 | mod cache_engine; 7 | 8 | pub use batch_info::*; 9 | pub use blocks::*; 10 | pub use cache_engine::*; 11 | -------------------------------------------------------------------------------- /rllm/rllm-cuda/src/rllm-cuda.rs: -------------------------------------------------------------------------------- 1 | mod llm; 2 | 3 | use clap::Parser; 4 | use llm::{ 5 | tmodel::{TModel, TchLoaderArgs}, 6 | DType, 7 | }; 8 | use rllm::util::parse_with_settings; 9 | use tch::Device; 10 | 11 | /// Serve LLMs with AICI over HTTP with tch (torch) backend. 
12 | #[derive(Parser, Debug)] 13 | #[command(version, about, long_about = None)] 14 | pub struct DriverArgs { 15 | #[clap(flatten)] 16 | pub args: rllm::server::RllmCliArgs, 17 | 18 | /// Specify which type to use in the model (bf16, f16, f32) 19 | #[arg(long, default_value = "", help_heading = "Model")] 20 | pub dtype: String, 21 | 22 | /// Enable nvprof profiling for given engine step (if available) 23 | #[arg(long, default_value_t = 0, help_heading = "Development")] 24 | pub profile_step: usize, 25 | } 26 | 27 | #[actix_web::main] 28 | async fn main() -> () { 29 | let args = parse_with_settings::(); 30 | let _ = args; 31 | 32 | let (device, dtype) = if tch::Cuda::is_available() { 33 | (Device::Cuda(0), None) 34 | } else { 35 | // At least on AMD 5500m MPS is 3x slower than CPU 36 | // #[cfg(target_os = "macos")] 37 | // let r = (Device::Mps, DType::Half); 38 | // #[cfg(not(target_os = "macos"))] 39 | let r = (Device::Cpu, Some(DType::Float)); 40 | r 41 | }; 42 | 43 | let dtype = match args.dtype.as_str() { 44 | "bf16" => Some(DType::BFloat16), 45 | "f16" => Some(DType::Half), 46 | "f32" => Some(DType::Float), 47 | "" => dtype, 48 | _ => panic!("invalid dtype; try one of bf16, f16, f32"), 49 | }; 50 | 51 | let model_args = TchLoaderArgs { 52 | device, 53 | dtype, 54 | profile_step_no: args.profile_step, 55 | }; 56 | rllm::server::server_main::(args.args, model_args).await; 57 | } 58 | -------------------------------------------------------------------------------- /rllm/rllm-cuda/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e 4 | 5 | ./expected/go.sh \ 6 | expected/phi-1_5 \ 7 | expected/orca 8 | 9 | if [ "$1" = "all" ] ; then 10 | ./expected/go.sh \ 11 | expected/codellama34 \ 12 | expected/codellama \ 13 | expected/phi-2 \ 14 | expected/llama 15 | fi 16 | 17 | -------------------------------------------------------------------------------- /rllm/rllm-llamacpp/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rllm-llamacpp" 3 | version = "0.1.0" 4 | edition = "2021" 5 | rust-version = "1.75.0" 6 | 7 | [dependencies] 8 | actix-web = "4.4.0" 9 | anyhow = "1.0.79" 10 | clap = { version = "4.4.18", features = ["derive"] } 11 | llama_cpp_low = { path = "../llama-cpp-low" } 12 | log = "0.4.20" 13 | rllm = { path = "../rllm-base" } 14 | aicirt = { path = "../../aicirt" } 15 | rand = "0.8.5" 16 | 17 | [[bin]] 18 | name = "rllm-llamacpp" 19 | path = "src/rllm-llamacpp.rs" 20 | 21 | [features] 22 | default = [] 23 | cuda = ["llama_cpp_low/cuda"] 24 | -------------------------------------------------------------------------------- /rllm/rllm-llamacpp/README.md: -------------------------------------------------------------------------------- 1 | # rLLM for llama.cpp 2 | 3 | This is similar to the [CUDA-based rLLM](../rllm-cuda/) 4 | but built on top of [llama.cpp](https://github.com/ggerganov/llama.cpp). 5 | 6 | ## Building 7 | 8 | If you're not using the supplied docker container follow the 9 | [build setup instructions](../../README.md#development-environment-setup). 10 | 11 | To compile and run first aicirt and then the rllm server, run: 12 | 13 | ```bash 14 | ./server.sh phi2 15 | ``` 16 | 17 | Run `./server.sh --help` for more options. 18 | 19 | You can also try passing `--cuda` before `phi2`, which will enable cuBLASS in llama.cpp. 
20 | Note that this is different from [rllm-cuda](../rllm-cuda/), 21 | which may give you better performance when doing batched inference. 22 | -------------------------------------------------------------------------------- /rllm/rllm-llamacpp/server.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | CPP=1 exec `dirname $0`/../rllm-cuda/server.sh "$@" 4 | -------------------------------------------------------------------------------- /rllm/rllm-llamacpp/src/llamacpp/blocks.rs: -------------------------------------------------------------------------------- 1 | use rllm::{ 2 | seq::{Sequence, SequenceGroup}, 3 | SchedulerOutputs, TBlockSpaceManager, 4 | }; 5 | 6 | use super::tmodel::TModel; 7 | 8 | /// Manages the mapping between logical and physical token blocks. 9 | pub struct CppBlockSpaceManager {} 10 | 11 | impl TBlockSpaceManager for CppBlockSpaceManager { 12 | fn can_allocate(&self, _seq_group: &SequenceGroup) -> bool { 13 | true 14 | } 15 | 16 | fn allocate(&mut self, seq_group: &mut SequenceGroup) { 17 | let seq = seq_group.only_seq(); 18 | assert!(seq.num_kv_computed == 0); 19 | } 20 | 21 | fn can_append_slot(&self, _seq_group: &SequenceGroup) -> bool { 22 | true 23 | } 24 | 25 | fn append_slots(&mut self, _seq: &mut Sequence, _outputs: &mut SchedulerOutputs) {} 26 | 27 | fn get_num_free_gpu_blocks(&self) -> usize { 28 | 0 29 | } 30 | 31 | fn get_num_free_cpu_blocks(&self) -> usize { 32 | 0 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /rllm/rllm-llamacpp/src/llamacpp/loader.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use anyhow::{bail, Result}; 4 | use rllm::{config::ModelMeta, LoaderArgs, Repo, RllmEngine}; 5 | 6 | use llama_cpp_low as cpp; 7 | 8 | use super::{ 9 | blocks::CppBlockSpaceManager, 10 | tmodel::{CppLoaderArgs, TModel}, 11 | }; 12 | 13 | pub(super) fn load_rllm_engine( 14 | args: LoaderArgs, 15 | mut model_args: CppLoaderArgs, 16 | ) -> Result> { 17 | let model = do_load(&args, &mut model_args)?; 18 | let rllm_config = RllmEngine::::build_config(&args, &mut model_args)?; 19 | 20 | let mut cparams = cpp::ContextParams::default(); 21 | cparams.n_batch = rllm_config.scheduler.max_num_batched_tokens as u32; 22 | cparams.n_ctx = 10000; // TODO 23 | model.setup_context(cparams); 24 | 25 | let rllm_config = Arc::new(rllm_config); 26 | let tmodel = TModel::new(rllm_config.clone(), model); 27 | let block_mgr = CppBlockSpaceManager {}; 28 | RllmEngine::build(args, tmodel, block_mgr, rllm_config) 29 | } 30 | 31 | fn do_load(args: &LoaderArgs, model_args: &mut CppLoaderArgs) -> Result { 32 | if model_args.cached_model.is_none() { 33 | let repo = Repo::from(args)?; 34 | log::info!("loading the model from {}", repo); 35 | 36 | let gguf = match args.file.as_ref() { 37 | Some(gguf) => gguf, 38 | None => { 39 | bail!("--gguf file.gguf or --model user/model::file.gguf is required for loading the model") 40 | } 41 | }; 42 | 43 | let file = repo.get(gguf)?; 44 | 45 | let mut mparams = cpp::ModelParams::default(); 46 | // TODO: make this configurable 47 | mparams.set_split_mode(cpp::SplitMode::Layer); 48 | match model_args.n_gpu_layers { 49 | Some(n) => mparams.n_gpu_layers = n as i32, 50 | None => { 51 | mparams.n_gpu_layers = 999; 52 | // by default, don't GPU offload on Intel macs - it's much slower than CPU 53 | #[cfg(all(target_os = "macos", target_arch = "x86_64"))] 54 | { 55 | mparams.n_gpu_layers = 
0; 56 | } 57 | } 58 | } 59 | log::info!("{} layer(s) offloaded to GPU", mparams.n_gpu_layers); 60 | 61 | let m = cpp::Model::from_file(file.to_str().unwrap(), mparams)?; 62 | model_args.cached_model = Some(m); 63 | } 64 | 65 | let model = model_args.cached_model.as_ref().unwrap().clone(); 66 | Ok(model) 67 | } 68 | 69 | pub(super) fn load_model_config( 70 | args: &LoaderArgs, 71 | model_args: &mut CppLoaderArgs, 72 | ) -> Result { 73 | let model = do_load(args, model_args)?; 74 | 75 | let info = model.model_info(); 76 | let vocab_size = info.n_vocab.try_into().unwrap(); 77 | let max_sequence_length = info.n_ctx_train.try_into().unwrap(); 78 | 79 | let mut meta = ModelMeta { 80 | id: args.model_id.clone(), 81 | max_sequence_length, 82 | vocab_size, 83 | tok_vocab_size: vocab_size, 84 | }; 85 | 86 | // hidden_size: info.n_embd.try_into().unwrap(), 87 | // rope_theta: info.rope, 88 | // rotary_dim: max_sequence_length, 89 | 90 | let tok = aicirt::bintokens::find_tokenizer(&args.tokenizer)?; 91 | meta.tok_vocab_size = tok.tokrx_info().vocab_size as usize; 92 | 93 | Ok(meta) 94 | } 95 | -------------------------------------------------------------------------------- /rllm/rllm-llamacpp/src/llamacpp/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod blocks; 2 | pub mod loader; 3 | pub mod tmodel; 4 | pub mod seqid; 5 | 6 | #[derive(Clone)] 7 | pub struct Tensor { 8 | ptr: *mut f32, 9 | size: usize, 10 | } 11 | 12 | impl Tensor { 13 | pub fn from_slice(slice: &'static [f32]) -> Self { 14 | Tensor { 15 | ptr: slice.as_ptr() as *mut f32, 16 | size: slice.len(), 17 | } 18 | } 19 | 20 | pub fn as_slice(&self) -> &[f32] { 21 | unsafe { std::slice::from_raw_parts(self.ptr, self.size) } 22 | } 23 | 24 | pub fn as_mut_slice(&mut self) -> &mut [f32] { 25 | unsafe { std::slice::from_raw_parts_mut(self.ptr, self.size) } 26 | } 27 | 28 | pub fn to_vec1(&self) -> Vec { 29 | self.as_slice().to_vec() 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /rllm/rllm-llamacpp/src/llamacpp/seqid.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Mutex; 2 | 3 | use rllm::{HashMap, SeqId, SequenceManager}; 4 | use llama_cpp_low as cpp; 5 | 6 | pub struct CppSequenceManager { 7 | model: cpp::Model, 8 | seqs: Mutex>, 9 | } 10 | 11 | impl CppSequenceManager { 12 | pub fn new(model: cpp::Model) -> Self { 13 | Self { 14 | model, 15 | seqs: Mutex::new(HashMap::default()), 16 | } 17 | } 18 | 19 | pub fn with_cpp(&self, seq: SeqId, cb: impl FnOnce(&cpp::Sequence)) { 20 | let seqs = self.seqs.lock().unwrap(); 21 | let seq = seqs.get(&seq).unwrap(); 22 | cb(seq); 23 | } 24 | } 25 | 26 | impl SequenceManager for CppSequenceManager { 27 | fn new_sequence(&self) -> SeqId { 28 | let r = self.model.new_sequence(); 29 | let id = SeqId(r.id() as usize); 30 | self.seqs.lock().unwrap().insert(id, r); 31 | id 32 | } 33 | 34 | fn copy(&self, src: SeqId, dst: SeqId, length: usize) { 35 | let seqs = self.seqs.lock().unwrap(); 36 | let src = seqs.get(&src).unwrap(); 37 | let dst = seqs.get(&dst).unwrap(); 38 | dst.cp_from(src, 0, length as i32); 39 | } 40 | 41 | fn trim(&self, seq: SeqId, length: usize) { 42 | let seqs = self.seqs.lock().unwrap(); 43 | let seq = seqs.get(&seq).unwrap(); 44 | seq.rm(length as i32, -1); 45 | } 46 | 47 | fn delete(&self, seq: SeqId) { 48 | self.seqs.lock().unwrap().remove(&seq); 49 | } 50 | } 51 | 
-------------------------------------------------------------------------------- /rllm/rllm-llamacpp/src/rllm-llamacpp.rs: -------------------------------------------------------------------------------- 1 | mod llamacpp; 2 | use clap::Parser; 3 | use llamacpp::tmodel::{CppLoaderArgs, TModel}; 4 | use rllm::util::parse_with_settings; 5 | 6 | /// Serve LLMs with AICI over HTTP with llama.cpp backend. 7 | #[derive(Parser, Debug)] 8 | #[command(version, about, long_about = None)] 9 | pub struct CppArgs { 10 | #[clap(flatten)] 11 | pub args: rllm::server::RllmCliArgs, 12 | 13 | /// Name of .gguf file inside of the model folder/repo. 14 | #[arg(long, help_heading = "Model")] 15 | pub gguf: Option, 16 | 17 | /// How many model layers to offload to GPU (if available) 18 | #[arg(long, short = 'g', help_heading = "Model")] 19 | pub gpu_layers: Option, 20 | } 21 | 22 | #[actix_web::main] 23 | async fn main() -> () { 24 | let mut args = parse_with_settings::(); 25 | args.args.file = args.gguf; 26 | let model_args = CppLoaderArgs::new(args.gpu_layers); 27 | rllm::server::server_main::(args.args, model_args).await; 28 | } 29 | -------------------------------------------------------------------------------- /rllm/tch-cuda/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "tch-cuda" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | description = "Flash attention layer for the tch-rs" 7 | license = "MIT" 8 | readme = "README.md" 9 | 10 | [dependencies] 11 | half = { version = "2.3.1", features = ["num-traits"] } 12 | libc = "0.2.151" 13 | tch = "0.14.0" 14 | torch-sys = "0.14.0" 15 | rustc-hash = "2.0.0" 16 | 17 | [build-dependencies] 18 | anyhow = { version = "1", features = ["backtrace"] } 19 | num_cpus = "1.15.0" 20 | rayon = "1.7.0" 21 | glob = "0.3.1" 22 | 23 | [dev-dependencies] 24 | anyhow = { version = "1", features = ["backtrace"] } 25 | -------------------------------------------------------------------------------- /rllm/tch-cuda/README.md: -------------------------------------------------------------------------------- 1 | # tch-cuda 2 | 3 | based on 4 | https://github.com/Dao-AILab/flash-attention/tree/9356a1c0389660d7e231ff3163c1ac17d9e3824a/csrc/flash_attn/src 5 | 6 | 7 | -------------------------------------------------------------------------------- /rllm/tch-cuda/convhd.js: -------------------------------------------------------------------------------- 1 | const fs = require('fs'); 2 | 3 | const args = process.argv.slice(2); 4 | 5 | let all_ext_c = ""; 6 | let all_rust = ""; 7 | 8 | const tpmap = { 9 | "torch::Tensor&": ["tensor", "*mut C_tensor"], 10 | "const torch::Tensor&": ["tensor", "*const C_tensor"], 11 | "const c10::optional&": ["tensor", "*const C_tensor"], 12 | }; 13 | 14 | const r_tpmap = { 15 | "float": "f32", 16 | "int": "i32", 17 | "bool": "bool", 18 | } 19 | 20 | let left = fs.readFileSync(args[0], 'utf-8').replace(/^(\S+) (\w+)\(([^)]*)\);/mg, (_, rettp, fname, args) => { 21 | let skip = false; 22 | 23 | if (rettp != "void") 24 | skip = true; 25 | 26 | if (args.indexOf("std::map<") >= 0) { 27 | return "" 28 | } 29 | 30 | let ext_c = `char* ${fname}_C(\n`; 31 | let rust = `fn ${fname}_C(\n`; 32 | let ext_c_inner = `${fname}(`; 33 | 34 | args.split(/,\s+/).forEach((x, i, arr) => { 35 | const m = /^(.*)\s+(\w+)$/.exec(x.trim()); 36 | let tp = m[1]; 37 | let aname = m[2]; 38 | 39 | if (tp == "torch::Tensor") 40 | skip = true; 41 | 42 | 43 | const tp0 = tpmap[tp]; 44 | 45 | let rtp = tp0?.[1] ?? 
r_tpmap[tp] ?? `C__${tp}`; 46 | let ctp = tp0?.[0] ?? tp; 47 | 48 | ext_c += ` ${ctp} ${aname}`; 49 | 50 | if (tp0) { 51 | ext_c_inner += `*${aname}` 52 | } else { 53 | ext_c_inner += `${aname}` 54 | } 55 | 56 | rust += ` ${aname}: ${rtp}`; 57 | 58 | if (i < arr.length - 1) { 59 | ext_c += ",\n"; 60 | ext_c_inner += ",\n"; 61 | rust += ",\n"; 62 | } 63 | }) 64 | 65 | ext_c += ") {\nPROTECT(" + ext_c_inner + "));\n}\n\n"; 66 | rust += ") -> *mut libc::c_char;\n\n"; 67 | 68 | if (!skip) { 69 | all_ext_c += ext_c; 70 | all_rust += rust; 71 | } 72 | 73 | return ""; 74 | }); 75 | 76 | left = left.replace(/^\s*#.*/mg, "").trim(); 77 | if (left) { 78 | console.log("left", left); 79 | } else { 80 | console.log(all_ext_c); 81 | console.log(all_rust); 82 | } 83 | -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/flash_attn/AUTHORS: -------------------------------------------------------------------------------- 1 | Tri Dao, trid@cs.stanford.edu -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/flash_attn/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/flash_attn/block_info.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | namespace flash { 8 | 9 | //////////////////////////////////////////////////////////////////////////////////////////////////// 10 | 11 | template 12 | struct BlockInfo { 13 | 14 | template 15 | __device__ BlockInfo(const Params ¶ms, const int bidb) 16 | : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]) 17 | , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]) 18 | , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) 19 | // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. 20 | // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. 21 | , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) 22 | , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) 23 | { 24 | } 25 | 26 | template 27 | inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { 28 | return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; 29 | } 30 | 31 | template 32 | inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { 33 | return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; 34 | } 35 | 36 | const int sum_s_q; 37 | const int sum_s_k; 38 | const int actual_seqlen_q; 39 | // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. 40 | const int seqlen_k_cache; 41 | const int actual_seqlen_k; 42 | }; 43 | 44 | //////////////////////////////////////////////////////////////////////////////////////////////////// 45 | 46 | } // namespace flash 47 | -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim128_bf16_sm80.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Tri Dao. 2 | // Splitting the different head dimensions to different files to speed up compilation. 3 | // This file is auto-generated. See "generate_kernels.py" 4 | 5 | #include "flash_fwd_launch_template.h" 6 | 7 | template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 8 | -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim128_fp16_sm80.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Tri Dao. 2 | // Splitting the different head dimensions to different files to speed up compilation. 3 | // This file is auto-generated. See "generate_kernels.py" 4 | 5 | #include "flash_fwd_launch_template.h" 6 | 7 | template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 8 | -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim160_bf16_sm80.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Tri Dao. 
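// Note: each of these auto-generated files holds a single explicit template
// instantiation of run_mha_fwd_splitkv_dispatch for one (element type, head
// dimension) pair; judging by this file name, that is presumably
// cutlass::bfloat16_t with head dimension 160. Keeping one pair per translation
// unit is what lets the per-head-dim kernels compile in parallel.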
2 | // Splitting the different head dimensions to different files to speed up compilation. 3 | // This file is auto-generated. See "generate_kernels.py" 4 | 5 | #include "flash_fwd_launch_template.h" 6 | 7 | template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 8 | -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim160_fp16_sm80.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Tri Dao. 2 | // Splitting the different head dimensions to different files to speed up compilation. 3 | // This file is auto-generated. See "generate_kernels.py" 4 | 5 | #include "flash_fwd_launch_template.h" 6 | 7 | template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 8 | -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim192_bf16_sm80.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Tri Dao. 2 | // Splitting the different head dimensions to different files to speed up compilation. 3 | // This file is auto-generated. See "generate_kernels.py" 4 | 5 | #include "flash_fwd_launch_template.h" 6 | 7 | template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 8 | -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim192_fp16_sm80.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Tri Dao. 2 | // Splitting the different head dimensions to different files to speed up compilation. 3 | // This file is auto-generated. See "generate_kernels.py" 4 | 5 | #include "flash_fwd_launch_template.h" 6 | 7 | template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 8 | -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim224_bf16_sm80.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Tri Dao. 2 | // Splitting the different head dimensions to different files to speed up compilation. 3 | // This file is auto-generated. See "generate_kernels.py" 4 | 5 | #include "flash_fwd_launch_template.h" 6 | 7 | template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 8 | -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim224_fp16_sm80.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Tri Dao. 2 | // Splitting the different head dimensions to different files to speed up compilation. 3 | // This file is auto-generated. See "generate_kernels.py" 4 | 5 | #include "flash_fwd_launch_template.h" 6 | 7 | template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 8 | -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim256_bf16_sm80.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Tri Dao. 2 | // Splitting the different head dimensions to different files to speed up compilation. 3 | // This file is auto-generated. 
See "generate_kernels.py" 4 | 5 | #include "flash_fwd_launch_template.h" 6 | 7 | template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 8 | -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim256_fp16_sm80.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Tri Dao. 2 | // Splitting the different head dimensions to different files to speed up compilation. 3 | // This file is auto-generated. See "generate_kernels.py" 4 | 5 | #include "flash_fwd_launch_template.h" 6 | 7 | template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 8 | -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim32_bf16_sm80.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Tri Dao. 2 | // Splitting the different head dimensions to different files to speed up compilation. 3 | // This file is auto-generated. See "generate_kernels.py" 4 | 5 | #include "flash_fwd_launch_template.h" 6 | 7 | template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 8 | -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim32_fp16_sm80.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Tri Dao. 2 | // Splitting the different head dimensions to different files to speed up compilation. 3 | // This file is auto-generated. See "generate_kernels.py" 4 | 5 | #include "flash_fwd_launch_template.h" 6 | 7 | template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 8 | -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim64_bf16_sm80.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Tri Dao. 2 | // Splitting the different head dimensions to different files to speed up compilation. 3 | // This file is auto-generated. See "generate_kernels.py" 4 | 5 | #include "flash_fwd_launch_template.h" 6 | 7 | template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 8 | -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim64_fp16_sm80.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Tri Dao. 2 | // Splitting the different head dimensions to different files to speed up compilation. 3 | // This file is auto-generated. See "generate_kernels.py" 4 | 5 | #include "flash_fwd_launch_template.h" 6 | 7 | template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 8 | -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim96_bf16_sm80.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Tri Dao. 2 | // Splitting the different head dimensions to different files to speed up compilation. 3 | // This file is auto-generated. 
See "generate_kernels.py" 4 | 5 | #include "flash_fwd_launch_template.h" 6 | 7 | template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 8 | -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim96_fp16_sm80.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2023, Tri Dao. 2 | // Splitting the different head dimensions to different files to speed up compilation. 3 | // This file is auto-generated. See "generate_kernels.py" 4 | 5 | #include "flash_fwd_launch_template.h" 6 | 7 | template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); 8 | -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/flash_attn/static_switch.h: -------------------------------------------------------------------------------- 1 | // Inspired by 2 | // https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h 3 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h 4 | 5 | #pragma once 6 | 7 | /// @param COND - a boolean expression to switch by 8 | /// @param CONST_NAME - a name given for the constexpr bool variable. 9 | /// @param ... - code to execute for true and false 10 | /// 11 | /// Usage: 12 | /// ``` 13 | /// BOOL_SWITCH(flag, BoolConst, [&] { 14 | /// some_function(...); 15 | /// }); 16 | /// ``` 17 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \ 18 | [&] { \ 19 | if (COND) { \ 20 | constexpr static bool CONST_NAME = true; \ 21 | return __VA_ARGS__(); \ 22 | } else { \ 23 | constexpr static bool CONST_NAME = false; \ 24 | return __VA_ARGS__(); \ 25 | } \ 26 | }() 27 | 28 | #define FP16_SWITCH(COND, ...) \ 29 | [&] { \ 30 | if (COND) { \ 31 | using elem_type = cutlass::half_t; \ 32 | return __VA_ARGS__(); \ 33 | } else { \ 34 | using elem_type = cutlass::bfloat16_t; \ 35 | return __VA_ARGS__(); \ 36 | } \ 37 | }() 38 | 39 | #define FWD_HEADDIM_SWITCH(HEADDIM, ...) 
\ 40 | [&] { \ 41 | if (HEADDIM <= 32) { \ 42 | constexpr static int kHeadDim = 32; \ 43 | return __VA_ARGS__(); \ 44 | } else if (HEADDIM <= 64) { \ 45 | constexpr static int kHeadDim = 64; \ 46 | return __VA_ARGS__(); \ 47 | } else if (HEADDIM <= 96) { \ 48 | constexpr static int kHeadDim = 96; \ 49 | return __VA_ARGS__(); \ 50 | } else if (HEADDIM <= 128) { \ 51 | constexpr static int kHeadDim = 128; \ 52 | return __VA_ARGS__(); \ 53 | } else if (HEADDIM <= 160) { \ 54 | constexpr static int kHeadDim = 160; \ 55 | return __VA_ARGS__(); \ 56 | } else if (HEADDIM <= 192) { \ 57 | constexpr static int kHeadDim = 192; \ 58 | return __VA_ARGS__(); \ 59 | } else if (HEADDIM <= 224) { \ 60 | constexpr static int kHeadDim = 224; \ 61 | return __VA_ARGS__(); \ 62 | } else if (HEADDIM <= 256) { \ 63 | constexpr static int kHeadDim = 256; \ 64 | return __VA_ARGS__(); \ 65 | } \ 66 | }() 67 | -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/vllm/attention/attention_dtypes.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "attention_generic.cuh" 4 | #include "dtype_float16.cuh" 5 | #include "dtype_float32.cuh" 6 | #include "dtype_bfloat16.cuh" 7 | -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/vllm/attention/attention_generic.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h 3 | * Copyright (c) 2023, The vLLM team. 4 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | #pragma once 19 | 20 | #include 21 | 22 | namespace vllm { 23 | 24 | // A vector type to store Q, K, V elements. 25 | template 26 | struct Vec {}; 27 | 28 | // A vector type to store FP32 accumulators. 29 | template 30 | struct FloatVec {}; 31 | 32 | // Template vector operations. 
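// The Vec and FloatVec containers above, together with the mul/sum/dot helpers
// declared below, are only generic interfaces here; the per-dtype
// specializations (fp16, fp32, bf16) come from the dtype_*.cuh headers pulled
// in through attention_dtypes.h, so the attention kernels can be written
// against these primitives regardless of element type.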
33 | template 34 | inline __device__ Acc mul(A a, B b); 35 | 36 | template 37 | inline __device__ float sum(T v); 38 | 39 | template 40 | inline __device__ float dot(T a, T b) { 41 | return sum(mul(a, b)); 42 | } 43 | 44 | template 45 | inline __device__ float dot(T a, T b) { 46 | return sum(mul(a, b)); 47 | } 48 | 49 | template 50 | inline __device__ void zero(T& dst) { 51 | constexpr int WORDS = sizeof(T) / 4; 52 | union { 53 | T raw; 54 | uint32_t words[WORDS]; 55 | } tmp; 56 | 57 | #pragma unroll 58 | for (int ii = 0; ii < WORDS; ++ii) { 59 | tmp.words[ii] = 0u; 60 | } 61 | dst = tmp.raw; 62 | } 63 | 64 | } // namespace vllm 65 | -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/vllm/attention/attention_utils.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp 3 | * Copyright (c) 2023, The vLLM team. 4 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | #pragma once 19 | 20 | #include "../cuda_compat.h" 21 | #include "attention_dtypes.h" 22 | 23 | #include 24 | #include 25 | 26 | namespace vllm { 27 | 28 | // Q*K^T operation. 29 | template 30 | inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { 31 | using A_vec = typename FloatVec::Type; 32 | // Compute the parallel products for Q*K^T (treat vector lanes separately). 33 | A_vec qk_vec = mul(q[0], k[0]); 34 | #pragma unroll 35 | for (int ii = 1; ii < N; ++ii) { 36 | qk_vec = fma(q[ii], k[ii], qk_vec); 37 | } 38 | 39 | // Finalize the reduction across lanes. 
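// Each thread first collapses its vector-wide partial products into a single
// float (sum(qk_vec)); the per-thread partials are then combined across the
// THREAD_GROUP_SIZE lanes of the thread group with a butterfly (XOR) shuffle
// reduction, halving the mask each step, so every participating lane ends up
// holding the full Q*K^T value.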
40 |   float qk = sum(qk_vec);
41 | #pragma unroll
42 |   for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
43 |     qk += VLLM_SHFL_XOR_SYNC(qk, mask);
44 |   }
45 |   return qk;
46 | }
47 | 
48 | template<typename T, int THREAD_GROUP_SIZE>
49 | struct Qk_dot {
50 |   template<typename Vec, int N>
51 |   static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
52 |     return qk_dot_<THREAD_GROUP_SIZE>(q, k);
53 |   }
54 | };
55 | 
56 | } // namespace vllm
57 | 
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/vllm/cache.h:
--------------------------------------------------------------------------------
 1 | #include <torch/extension.h>
 2 | 
 3 | #include <map>
 4 | #include <vector>
 5 | 
 6 | void swap_blocks(
 7 |   torch::Tensor& src,
 8 |   torch::Tensor& dst,
 9 |   const std::map<int64_t, int64_t>& block_mapping);
10 | 
11 | void copy_blocks(
12 |   std::vector<torch::Tensor>& key_caches,
13 |   std::vector<torch::Tensor>& value_caches,
14 |   const std::map<int64_t, std::vector<int64_t>>& block_mapping);
15 | 
16 | void reshape_and_cache(
17 |   torch::Tensor& key,
18 |   torch::Tensor& value,
19 |   torch::Tensor& key_cache,
20 |   torch::Tensor& value_cache,
21 |   torch::Tensor& slot_mapping);
22 | 
23 | void gather_cached_kv(
24 |   torch::Tensor& key,
25 |   torch::Tensor& value,
26 |   torch::Tensor& key_cache,
27 |   torch::Tensor& value_cache,
28 |   torch::Tensor& slot_mapping);
29 | 
30 | void copy_blocks_2(
31 |   torch::Tensor& key_cache_ptrs_tensor,
32 |   torch::Tensor& value_cache_ptrs_tensor,
33 |   torch::Tensor& block_mapping_tensor,
34 |   torch::Tensor& key0);
35 | 
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/vllm/cuda_compat.h:
--------------------------------------------------------------------------------
 1 | #pragma once
 2 | 
 3 | #ifndef USE_ROCM
 4 |   #define VLLM_LDG(arg) __ldg(arg)
 5 | #else
 6 |   #define VLLM_LDG(arg) *(arg)
 7 | #endif
 8 | 
 9 | #ifndef USE_ROCM
10 |   #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
11 | #else
12 |   #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
13 | #endif
14 | 
15 | #ifndef USE_ROCM
16 |   #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
17 | #else
18 |   #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
19 | #endif
20 | 
21 | #ifndef USE_ROCM
22 |   #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
23 |     cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
24 | #else
25 |   #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
26 |     hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
27 | #endif
28 | 
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/vllm/cuda_utils.h:
--------------------------------------------------------------------------------
 1 | #include <torch/extension.h>
 2 | 
 3 | int get_device_attribute(
 4 |   int attribute,
 5 |   int device_id);
 6 | 
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/vllm/cuda_utils_kernels.cu:
--------------------------------------------------------------------------------
 1 | #ifdef USE_ROCM
 2 | #include <hip/hip_runtime.h>
 3 | #endif
 4 | int get_device_attribute(
 5 |   int attribute,
 6 |   int device_id)
 7 | {
 8 |   int device, value;
 9 |   if (device_id < 0) {
10 |     cudaGetDevice(&device);
11 |   }
12 |   else {
13 |     device = device_id;
14 |   }
15 |   cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
16 |   return value;
17 | }
18 | 
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/vllm/dispatch_utils.h:
-------------------------------------------------------------------------------- 1 | /* 2 | * Adapted from 3 | * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h 4 | */ 5 | #include 6 | 7 | #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ 8 | AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ 9 | AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ 10 | AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) 11 | 12 | #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ 13 | AT_DISPATCH_SWITCH( \ 14 | TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) 15 | -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/vllm/ops.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | void paged_attention_v1( 4 | torch::Tensor& out, 5 | torch::Tensor& query, 6 | torch::Tensor& key_cache, 7 | torch::Tensor& value_cache, 8 | int num_kv_heads, 9 | float scale, 10 | torch::Tensor& block_tables, 11 | torch::Tensor& context_lens, 12 | int block_size, 13 | int max_context_len, 14 | const c10::optional& alibi_slopes); 15 | 16 | void paged_attention_v2( 17 | torch::Tensor& out, 18 | torch::Tensor& exp_sums, 19 | torch::Tensor& max_logits, 20 | torch::Tensor& tmp_out, 21 | torch::Tensor& query, 22 | torch::Tensor& key_cache, 23 | torch::Tensor& value_cache, 24 | int num_kv_heads, 25 | float scale, 26 | torch::Tensor& block_tables, 27 | torch::Tensor& context_lens, 28 | int block_size, 29 | int max_context_len, 30 | const c10::optional& alibi_slopes); 31 | 32 | void rms_norm( 33 | torch::Tensor& out, 34 | torch::Tensor& input, 35 | torch::Tensor& weight, 36 | float epsilon); 37 | 38 | void fused_add_rms_norm( 39 | torch::Tensor& input, 40 | torch::Tensor& residual, 41 | torch::Tensor& weight, 42 | float epsilon); 43 | 44 | void rotary_embedding( 45 | torch::Tensor& positions, 46 | torch::Tensor& query, 47 | torch::Tensor& key, 48 | int head_size, 49 | torch::Tensor& cos_sin_cache, 50 | bool is_neox); 51 | 52 | void silu_and_mul( 53 | torch::Tensor& out, 54 | torch::Tensor& input); 55 | 56 | void gelu_new( 57 | torch::Tensor& out, 58 | torch::Tensor& input); 59 | 60 | void gelu_fast( 61 | torch::Tensor& out, 62 | torch::Tensor& input); 63 | 64 | #ifndef USE_ROCM 65 | torch::Tensor awq_gemm( 66 | torch::Tensor _in_feats, 67 | torch::Tensor _kernel, 68 | torch::Tensor _scaling_factors, 69 | torch::Tensor _zeros, 70 | int split_k_iters); 71 | #endif 72 | 73 | void squeezellm_gemm( 74 | torch::Tensor vec, 75 | torch::Tensor mat, 76 | torch::Tensor mul, 77 | torch::Tensor lookup_table); 78 | -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/vllm/pybind.cpp: -------------------------------------------------------------------------------- 1 | #include "cache.h" 2 | #include "cuda_utils.h" 3 | #include "ops.h" 4 | #include 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | // vLLM custom ops 8 | pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); 9 | 10 | // Attention ops 11 | ops.def( 12 | "paged_attention_v1", 13 | &paged_attention_v1, 14 | "Compute the attention between an input query and the cached keys/values using PagedAttention."); 15 | ops.def( 16 | "paged_attention_v2", 17 | &paged_attention_v2, 18 | "PagedAttention V2."); 19 | 20 | // Activation ops 21 | ops.def( 22 | "silu_and_mul", 23 | &silu_and_mul, 24 | "Activation function used in SwiGLU."); 25 | ops.def( 26 | "gelu_new", 27 | &gelu_new, 28 | 
"GELU implementation used in GPT-2."); 29 | ops.def( 30 | "gelu_fast", 31 | &gelu_fast, 32 | "Approximate GELU implementation."); 33 | 34 | // Layernorm 35 | ops.def( 36 | "rms_norm", 37 | &rms_norm, 38 | "Apply Root Mean Square (RMS) Normalization to the input tensor."); 39 | 40 | ops.def( 41 | "fused_add_rms_norm", 42 | &fused_add_rms_norm, 43 | "In-place fused Add and RMS Normalization"); 44 | 45 | // Rotary embedding 46 | ops.def( 47 | "rotary_embedding", 48 | &rotary_embedding, 49 | "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); 50 | 51 | #ifndef USE_ROCM 52 | // Quantization ops 53 | ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); 54 | #endif 55 | 56 | 57 | ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); 58 | 59 | // Cache ops 60 | pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); 61 | cache_ops.def( 62 | "swap_blocks", 63 | &swap_blocks, 64 | "Swap in (out) the cache blocks from src to dst"); 65 | cache_ops.def( 66 | "copy_blocks", 67 | ©_blocks, 68 | "Copy the cache blocks from src to dst"); 69 | cache_ops.def( 70 | "reshape_and_cache", 71 | &reshape_and_cache, 72 | "Reshape the key and value tensors and cache them"); 73 | cache_ops.def( 74 | "gather_cached_kv", 75 | &gather_cached_kv, 76 | "Gather key and value from the cache into contiguous QKV tensors"); 77 | 78 | // Cuda utils 79 | pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils"); 80 | cuda_utils.def( 81 | "get_device_attribute", 82 | &get_device_attribute, 83 | "Gets the specified device attribute."); 84 | } 85 | -------------------------------------------------------------------------------- /rllm/tch-cuda/kernels/vllm/reduction_utils.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh 3 | * Copyright (c) 2023, The vLLM team. 4 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | #pragma once 19 | 20 | #include "cuda_compat.h" 21 | 22 | namespace vllm { 23 | 24 | template 25 | __inline__ __device__ T warpReduceSum(T val) { 26 | #pragma unroll 27 | for (int mask = 16; mask > 0; mask >>= 1) 28 | val += VLLM_SHFL_XOR_SYNC(val, mask); 29 | return val; 30 | } 31 | 32 | /* Calculate the sum of all elements in a block */ 33 | template 34 | __inline__ __device__ T blockReduceSum(T val) { 35 | static __shared__ T shared[32]; 36 | int lane = threadIdx.x & 0x1f; 37 | int wid = threadIdx.x >> 5; 38 | 39 | val = warpReduceSum(val); 40 | 41 | if (lane == 0) 42 | shared[wid] = val; 43 | 44 | __syncthreads(); 45 | 46 | // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent 47 | // blockDim.x is not divided by 32 48 | val = (threadIdx.x < (blockDim.x / 32.f)) ? 
shared[lane] : (T)(0.0f); 49 | val = warpReduceSum(val); 50 | return val; 51 | } 52 | 53 | } // namespace vllm 54 | -------------------------------------------------------------------------------- /rllm/tch-cuda/tests/flash_attn_tests.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use tch::{kind::Element, Device, IndexOp, Kind, Tensor}; 3 | 4 | fn to_vec3(t: &Tensor) -> Vec>> { 5 | let (d0, d1, d2) = t.size3().unwrap(); 6 | (0..d0) 7 | .map(|i| { 8 | (0..d1) 9 | .map(|j| { 10 | let mut dst = vec![T::ZERO; d2 as usize]; 11 | t.i((i, j, ..)) 12 | .to_kind(T::KIND) 13 | .copy_data::(&mut dst, d2 as usize); 14 | dst 15 | }) 16 | .collect::>() 17 | }) 18 | .collect::>() 19 | } 20 | 21 | fn to_vec3_round(t: &Tensor, digits: i32) -> Vec>> { 22 | let b = 10f32.powi(digits); 23 | let t = to_vec3::(t); 24 | let t = t 25 | .iter() 26 | .map(|t| { 27 | t.iter() 28 | .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect()) 29 | .collect() 30 | }) 31 | .collect(); 32 | t 33 | } 34 | 35 | #[test] 36 | fn flash_attn_varlen() -> Result<()> { 37 | let device = Device::Cuda(0); 38 | let q = Tensor::arange(48, (Kind::BFloat16, device)).reshape(&[3, 2, 8]); 39 | let k = &q / 40.; 40 | let v = &q / 50.; 41 | let q = &q / 30.; 42 | 43 | let seqlens_q = Tensor::from_slice(&[0i32, 2i32]).to_device(device); 44 | let seqlens_k = Tensor::from_slice(&[0i32, 2i32]).to_device(device); 45 | 46 | let ys = { 47 | let q = q.transpose(0, 1); 48 | let k = k.transpose(0, 1); 49 | let v = v.transpose(0, 1); 50 | tch_cuda::flash_attn_varlen(&q, &k, &v, &seqlens_q, &seqlens_k, 32, 32, 0.5, false) 51 | .transpose(0, 1) 52 | }; 53 | let ys = ys.to_kind(Kind::Float); 54 | 55 | assert_eq!(ys.size(), &[3, 2, 8]); 56 | assert_eq!( 57 | to_vec3_round(&ys, 4), 58 | &[ 59 | [ 60 | [0.084, 0.1035, 0.124, 0.1436, 0.1641, 0.1836, 0.2031, 0.2236], 61 | [0.0923, 0.1118, 0.1318, 0.1523, 0.1719, 0.1924, 0.2119, 0.2324] 62 | ], 63 | [ 64 | [0.4199, 0.4395, 0.459, 0.4805, 0.5, 0.5195, 0.543, 0.5625], 65 | [0.4277, 0.4473, 0.4668, 0.4883, 0.5078, 0.5273, 0.5508, 0.5703] 66 | ], 67 | [ 68 | [0.7539, 0.7734, 0.793, 0.8125, 0.832, 0.8516, 0.875, 0.8945], 69 | [0.7617, 0.7813, 0.8008, 0.8203, 0.8398, 0.8594, 0.8828, 0.9023] 70 | ] 71 | ] 72 | ); 73 | Ok(()) 74 | } 75 | -------------------------------------------------------------------------------- /rustfmt.toml: -------------------------------------------------------------------------------- 1 | imports_granularity = "Crate" 2 | group_imports = "One" 3 | -------------------------------------------------------------------------------- /scripts/aici.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | `dirname $0`/../aici.sh "$@" 4 | -------------------------------------------------------------------------------- /scripts/bench-comms.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -x 4 | ./scripts/kill-rt.sh 5 | 6 | set -e 7 | (cd aicirt && cargo run --release -- --bench --name /aicibench-) 8 | 9 | ./aici.sh benchrt \ 10 | --aici-rt ./aicirt/target/release/aicirt \ 11 | --aici-tokenizer llama 12 | -------------------------------------------------------------------------------- /scripts/bench-earley.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -x 4 | set -e 5 | (cd aicirt && cargo build --release) 6 | perf stat ./target/release/aicirt --tokenizer gpt4 
--earley-bench 7 | -------------------------------------------------------------------------------- /scripts/bench-server.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | HERE=`dirname $0` 4 | PYTHONPATH=$HERE/../py \ 5 | python3 $HERE/py/bench_server.py "$@" 6 | -------------------------------------------------------------------------------- /scripts/bump.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ "X`git status --porcelain --untracked-files=no`" != X ] ; then 4 | git status 5 | echo 6 | echo "*** You have local changes; cannot bump." 7 | exit 1 8 | fi 9 | 10 | git pull || exit 1 11 | 12 | eval `git describe --dirty --tags --match 'v[0-9]*' --always | sed -e 's/-.*//; s/v/v0=/; s/\./ v1=/; s/\./ v2=/'` 13 | defl=0.0.0 14 | if [ "X$v0" != X ] ; then 15 | defl=$v0.$v1.$(($v2 + 1)) 16 | fi 17 | set -e 18 | echo "Enter version [Enter = $defl; Ctrl-C to cancel]:" 19 | read ver 20 | if [ "X$ver" = "X" ] ; then 21 | ver="$defl" 22 | else 23 | ver=$(echo "$ver" | sed -e 's/v//i') 24 | fi 25 | if echo "$ver" | grep -qE '^[0-9]+\.[0-9]+\.[0-9]+$' ; then 26 | : 27 | else 28 | echo "Invalid version: $ver" 29 | exit 1 30 | fi 31 | 32 | set -x 33 | git tag "v$ver" 34 | git push --tags 35 | git push 36 | -------------------------------------------------------------------------------- /scripts/checkdeps.js: -------------------------------------------------------------------------------- 1 | const fs = require('fs'); 2 | const path = require('path'); 3 | 4 | const cargoToml = fs.readFileSync('Cargo.toml', 'utf8'); 5 | let m = /\[dependencies\]\n([^]*)/m.exec(cargoToml); 6 | const deps = m[1].replace(/\n\[[^]*/, "").split('\n').map(line => line.replace(/=.*/, "").trim()).filter(line => line.length > 0); 7 | const depset = new Set(deps); 8 | 9 | // read all *.rs files under src/ recursively 10 | const srcFiles = []; 11 | function readDir(dir) { 12 | const files = fs.readdirSync(dir); 13 | for (const file of files) { 14 | const filePath = path.join(dir, file); 15 | if (fs.statSync(filePath).isDirectory()) { 16 | readDir(filePath); 17 | } 18 | else if (file.endsWith('.rs')) { 19 | const content = fs.readFileSync(filePath, 'utf8'); 20 | for (const dep of depset) { 21 | const dn = dep.replace(/-/g, "_"); 22 | if (content.includes(`${dn}::`)) { 23 | depset.delete(dep); 24 | } 25 | } 26 | srcFiles.push(filePath); 27 | } 28 | } 29 | } 30 | readDir('src'); 31 | console.log('Unused dependencies:', depset); 32 | -------------------------------------------------------------------------------- /scripts/checklinks.js: -------------------------------------------------------------------------------- 1 | const fs = require('fs'); 2 | const path = require('path'); 3 | const exclude = ["tmp/", "node_modules/", "target/", "controllers/RustPython/"] 4 | fs.readFileSync('.gitmodules', 'utf8').replace(/^\s*path = (.*)$/gm, (_, path) => { 5 | exclude.push(path + "/") 6 | }) 7 | 8 | const links = {} 9 | const files = {} 10 | for (let file of process.argv.slice(2)) { 11 | if (exclude.some(path => file.startsWith(path))) { 12 | continue 13 | } 14 | const content = fs.readFileSync(file, 'utf8') 15 | files[file] = content 16 | file = path.resolve(file) 17 | content.replace(/^#+ (.*)$/gm, (_, title) => { 18 | const anchor = "#" + title.toLowerCase().replace(/[^a-z0-9 \-]+/g, '').replace(/ /g, '-') 19 | links[file + anchor] = true 20 | }) 21 | } 22 | 23 | let numerr = 0 24 | let numlinks = 0 25 | let numanchors = 
0 26 | let numhttp = 0 27 | 28 | for (const [filename, content] of Object.entries(files)) { 29 | let lineNo = 0 30 | for (const line of content.split("\n")) { 31 | lineNo++ 32 | line.replace(/\[([^\]]+)\]\(([^\)]+)\)/g, (_, title, link) => { 33 | if (link.startsWith("https://") || link.startsWith("http://") || link.startsWith("mailto:")) { 34 | // console.log(link) 35 | numhttp++ 36 | return 37 | } 38 | numlinks++ 39 | if (link.startsWith("#")) { 40 | link = filename + link 41 | } 42 | // split anchor 43 | let [linkfile, anchor] = link.split("#") 44 | linkfile = path.resolve(path.dirname(filename), linkfile) 45 | 46 | if (!fs.existsSync(linkfile)) { 47 | numerr++ 48 | console.log(`${filename}:${lineNo}: Broken link '${title}': ${link}`) 49 | return 50 | } 51 | 52 | if (anchor) { 53 | numanchors++ 54 | if (!links[linkfile + "#" + anchor]) { 55 | numerr++ 56 | console.log(`${filename}:${lineNo}: Broken link to anchor '${title}': ${link}`) 57 | } else { 58 | // console.log(`${filename}:${lineNo}: Found link to anchor '${title}': ${link}`) 59 | } 60 | } 61 | }) 62 | } 63 | } 64 | 65 | if (numerr > 0) { 66 | console.log(`Found ${numerr} broken links`) 67 | process.exit(1) 68 | } else { 69 | console.log(`Exclude: ${exclude.join(", ")}`) 70 | console.log(`Checked ${numlinks} links (incl. ${numanchors} anchors). Skipped ${numhttp} http links.`) 71 | } -------------------------------------------------------------------------------- /scripts/checklinks.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | node scripts/checklinks.js *.md */*.md */*/*.md 4 | -------------------------------------------------------------------------------- /scripts/docker-build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e 4 | 5 | DOCKER_TARGET="$1" 6 | if test -z "$DOCKER_TARGET" ; then 7 | DOCKER_TARGET=vllm-general 8 | fi 9 | 10 | DOCKERFILE="$2" 11 | if test -z "$DOCKERFILE" ; then 12 | DOCKERFILE=.devcontainer/Dockerfile-prod-vllm 13 | fi 14 | 15 | D=`date +%Y%m%d-%H%M` 16 | TAG=`git describe --dirty --tags --match 'v[0-9]*' --always | sed -e 's/^v//; s/-dirty/-'"$D/"` 17 | 18 | set -x 19 | 20 | DOCKER_BUILDKIT=1 \ 21 | docker build . 
-f $DOCKERFILE \ 22 | --target $DOCKER_TARGET \ 23 | --tag $DOCKER_TARGET \ 24 | --build-arg tag="$TAG" \ 25 | --progress=plain 26 | 27 | if [ "X$DOCKER_PUSH" != X ] ; then 28 | if test -z "$DOCKER_TAG" ; then 29 | DOCKER_TAG=v$(date '+%Y%m%d-%H%M') 30 | fi 31 | docker tag $DOCKER_TARGET $DOCKER_PUSH:$DOCKER_TAG 32 | docker push $DOCKER_PUSH:$DOCKER_TAG 33 | 34 | set +x 35 | echo 36 | echo "Pushed $DOCKER_PUSH:$DOCKER_TAG" 37 | echo 38 | fi 39 | -------------------------------------------------------------------------------- /scripts/docker-cpp-run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | PORT=4242 4 | ADD_ARGS= 5 | VLLM_ARGS="--port $PORT" 6 | DOCKER_ARGS="" 7 | 8 | case "$1" in 9 | --phi2) 10 | shift 11 | ADD_ARGS="-m https://huggingface.co/TheBloke/phi-2-GGUF/blob/main/phi-2.Q8_0.gguf -t phi" 12 | ;; 13 | --folder) 14 | shift 15 | D=`cd $1; pwd` 16 | DOCKER_ARGS="--mount type=bind,source=$D,target=/workspace/model" 17 | ADD_ARGS="-m ./model --aici-tokenizer ./model/tokenizer.json --tokenizer ./model" 18 | shift 19 | ;; 20 | --shell) 21 | shift 22 | DOCKER_ARGS="--entrypoint /bin/bash -it" 23 | VLLM_ARGS="" 24 | ;; 25 | esac 26 | 27 | set -x 28 | docker run \ 29 | --mount source=profile,target=/root,type=volume \ 30 | -p $PORT:$PORT \ 31 | $DOCKER_ARGS \ 32 | aici/rllm-llamacpp \ 33 | $VLLM_ARGS \ 34 | $ADD_ARGS \ 35 | "$@" 36 | -------------------------------------------------------------------------------- /scripts/docker-run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | PORT=4242 4 | ADD_ARGS= 5 | VLLM_ARGS="--port $PORT" 6 | DOCKER_ARGS="" 7 | 8 | if test -z "$DOCKER_TARGET" ; then 9 | DOCKER_TARGET=vllm-general 10 | fi 11 | 12 | 13 | case "$1" in 14 | --orca) 15 | shift 16 | ADD_ARGS="--model microsoft/Orca-2-13b --revision refs/pr/22 --aici-tokenizer=orca" 17 | ;; 18 | --folder) 19 | shift 20 | D=`cd $1; pwd` 21 | DOCKER_ARGS="--mount type=bind,source=$D,target=/vllm-workspace/model" 22 | ADD_ARGS="--model ./model --aici-tokenizer ./model/tokenizer.json --tokenizer ./model" 23 | shift 24 | ;; 25 | --shell) 26 | shift 27 | DOCKER_ARGS="--entrypoint /bin/bash -it" 28 | VLLM_ARGS="" 29 | ;; 30 | esac 31 | 32 | set -x 33 | docker run \ 34 | --privileged \ 35 | --gpus=all \ 36 | --shm-size=8g \ 37 | --mount source=profile,target=/root,type=volume \ 38 | -p $PORT:$PORT \ 39 | $DOCKER_ARGS \ 40 | $DOCKER_TARGET \ 41 | $VLLM_ARGS \ 42 | $ADD_ARGS \ 43 | "$@" 44 | -------------------------------------------------------------------------------- /scripts/docker-vllm-build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | cd py/vllm 4 | DOCKER_BUILDKIT=1 \ 5 | docker build . 
\ 6 | --target vllm-openai \ 7 | --tag vllm/vllm-openai \ 8 | --build-arg max_jobs=16 \ 9 | --build-arg nvcc_threads=16 \ 10 | --build-arg torch_cuda_arch_list="8.0 9.0+PTX" 11 | -------------------------------------------------------------------------------- /scripts/hf.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e 4 | set -x 5 | 6 | RUST_LOG=info,tokenizers=error,rllm=debug,aicirt=info \ 7 | PYTHONPATH=py \ 8 | python3 scripts/py/run_hf.py \ 9 | --aici-rt ./target/release/aicirt \ 10 | --controller gh:microsoft/aici/pyctrl \ 11 | --controller-arg controllers/pyctrl/samples/test.py \ 12 | --aici-tokenizer phi \ 13 | --model microsoft/phi-2 14 | -------------------------------------------------------------------------------- /scripts/kill-rt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | P=`ps -ax|grep 'aicir[t]' | awk '{print $1}'` 4 | 5 | if [ "X$P" != "X" ] ; then 6 | echo "KILL $P" 7 | kill $P 8 | fi 9 | -------------------------------------------------------------------------------- /scripts/kill-server.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | P=`ps -ax|grep 'aicir[t]\|rllm-serve[r]\|/serve[r]\.sh\|node.*/buil[t]/worker\|python[3] -m vllm.entrypoints' | awk '{print $1}' | xargs echo` 4 | 5 | if [ "X$P" != "X" ] ; then 6 | echo "KILL $P" 7 | kill $P 8 | sleep 1 9 | kill -9 $P 2>/dev/null 10 | true 11 | fi 12 | -------------------------------------------------------------------------------- /scripts/random/disasm.js: -------------------------------------------------------------------------------- 1 | const { exec } = require("child_process"); 2 | 3 | let args = process.argv.slice(2); 4 | if (args.length != 1 || !args[0].endsWith(".wasm")) { 5 | console.log("Usage: node scripts/disasm.js cache/.wasm"); 6 | console.log("Expects cache/.elf to also exist"); 7 | process.exit(1); 8 | } 9 | 10 | function run(cmd, cb) { 11 | exec(cmd, { 12 | encoding: "utf-8", 13 | maxBuffer: 128 * 1024 * 1024, 14 | }, (error, stdout, stderr) => { 15 | if (error) { 16 | console.log(`error: ${error.message}`) 17 | process.exit(1) 18 | } 19 | if (stderr) { 20 | console.log(`stderr: ${stderr}`); 21 | process.exit(1) 22 | } 23 | cb(stdout) 24 | }); 25 | } 26 | 27 | run(`wasm-objdump -x ${args[0]} | rustfilt -h`, stdout => { 28 | const repl = {} 29 | stdout.split("\n").forEach(line => { 30 | // - func[6] sig=0 ::reserve_for_push::h27d4ac8d729e40c6> 31 | const m = /^\s*- func\[(\d+)\] .* <(.*)>$/.exec(line); 32 | if (m) { 33 | repl[m[1]] = m[2]; 34 | } 35 | return "" 36 | }) 37 | 38 | run(`objdump -d ${args[0].replace(".wasm", ".elf")}`, stdout => { 39 | console.log(stdout 40 | .replace( 41 | /<[^<>]*(\+0x[a-f0-9]+)>/g, 42 | (_, addr) => addr) 43 | .replace( 44 | /wasm\[0\]::function\[(\d+)\]/g, 45 | (_, no) => { 46 | return repl[no] || `wasm[0]::function[${no}]` 47 | } 48 | )) 49 | }) 50 | }); -------------------------------------------------------------------------------- /scripts/random/parse-tokenizer-automaton.js: -------------------------------------------------------------------------------- 1 | const fs = require("fs") 2 | 3 | const nevers = [0, 1, 2, 16, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 
101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 32000] 4 | 5 | const MIN_DELTA = 0 6 | const MAX_DELTA = MIN_DELTA + 2 7 | 8 | function diff(a, b) { 9 | let i = 0 10 | let j = 0 11 | let delta = 0 12 | while (i < a.length || j < b.length) { 13 | if (i < a.length && j < b.length) { 14 | if (a[i] == b[j]) { 15 | i++; 16 | j++; 17 | } else if (a[i] < b[j]) { 18 | delta++; 19 | i++; 20 | } else { 21 | delta++; 22 | j++; 23 | } 24 | } else if (i < a.length) { 25 | delta++; 26 | i++; 27 | } else { 28 | delta++; 29 | j++; 30 | } 31 | 32 | if (delta > MAX_DELTA) { 33 | return delta; 34 | } 35 | } 36 | return delta; 37 | } 38 | 39 | const buckets = [] 40 | let no_bucket_size = 0 41 | 42 | fs.readFileSync("tmp/tokens.txt", "utf8").split("\n").forEach((line, i) => { 43 | const m = /^(\d+) ==> (true|false) (.*)/.exec(line); 44 | if (!m) return 45 | const tokid = +m[1]; 46 | let elts = Array.from(JSON.parse(m[3])); 47 | const neg = m[2] == "true"; 48 | const isAllowed = (e) => { 49 | if (neg) return !elts.includes(e); 50 | return elts.includes(e); 51 | } 52 | 53 | const nev = nevers.find(e => isAllowed(e)); 54 | if (nev) { 55 | console.log(tokid, "N", nev); 56 | } 57 | 58 | const empty = elts.length == 0 && !neg; 59 | if (empty) { 60 | //console.log(tokid, "E"); 61 | } else { 62 | if (!neg) { 63 | console.log(tokid, "A", elts.length); 64 | } else { 65 | let existing = false 66 | elts = elts.filter(e => !nevers.includes(e)); 67 | for (const b of buckets) { 68 | if (diff(elts, b) <= MIN_DELTA) { 69 | existing = true; 70 | break; 71 | } 72 | } 73 | if (!existing) { 74 | buckets.push(elts); 75 | } 76 | no_bucket_size += elts.length; 77 | console.log(tokid, "F", elts.length, buckets.length); 78 | } 79 | } 80 | }) 81 | 82 | console.log(buckets.reduce((a, b) => a + b.length, 0), no_bucket_size) 83 | -------------------------------------------------------------------------------- /scripts/random/tokenizer-stats.js: -------------------------------------------------------------------------------- 1 | const fs = require("fs"); 2 | const folder = "src/tokenizers/"; 3 | 4 | for (const fn of fs.readdirSync(folder)) { 5 | if (!fn.endsWith(".json")) continue; 6 | // if (fn != "gpt4.json") continue; 7 | console.log(fn); 8 | 9 | const obj = JSON.parse(fs.readFileSync(folder + fn, "utf-8")); 10 | 11 | const keys = []; 12 | 13 | for (const k of Object.keys(obj.binary)) { 14 | keys.push(Buffer.from(k, "hex").toString("utf-8")); 15 | } 16 | keys.push(...Object.keys(obj.text)); 17 | keys.sort((a, b) => b.length - a.length); 18 | 19 | // console.log( keys.filter(k => /^\s*$/.test(k)).map(x => JSON.stringify(x)).join("\n") ) 20 | // console.log( 21 | // keys 22 | // .filter((k) => /^\s*$/.test(k)) 23 | // .filter((k) => !/^[ \t\r\n]*$/.test(k) && k.length != 1) 24 | // .map((x) => JSON.stringify(x)) 25 | // .join("\n") 26 | // ); 27 | 28 | const nonws = keys.filter((k) => !/^\s*$/.test(k)); 29 | 30 | console.log( 31 | nonws 32 | .filter((k) => / /.test(k)) 33 | .filter((k) => k[0] != " ") 34 | .map((x) => JSON.stringify(x)) 35 | .join("\n") 36 | ); 37 | } 38 | -------------------------------------------------------------------------------- /scripts/sample-uppercase.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | ./aici.sh run --build controllers/uppercase 3 | -------------------------------------------------------------------------------- /scripts/sample-yesno.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | PROMPT="$*" 3 | if [ -z "$PROMPT" ]; then 4 | PROMPT="Is coffee any good?" 5 | fi 6 | set -x 7 | echo "$PROMPT" | ./aici.sh run --build controllers/aici_abi::yesno - 8 | -------------------------------------------------------------------------------- /scripts/tag-ctrls.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | cd "`dirname $0`/.." 4 | 5 | TIMESTAMP=`date --utc '+%+4Y-%m-%d-%H%M'` 6 | 7 | CTRLS="$*" 8 | if [ X"$CTRLS" = X ]; then 9 | CTRLS="declctrl pyctrl jsctrl llguidance_ctrl" 10 | fi 11 | 12 | for ctrl in $CTRLS ; do 13 | bctrl=$(echo $ctrl | sed -e 's/_ctrl//') 14 | ./aici.sh build controllers/$ctrl -T $ctrl-latest -T $ctrl-$TIMESTAMP -T $bctrl 15 | done 16 | -------------------------------------------------------------------------------- /scripts/test-all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -x 4 | set -e 5 | ./scripts/test-pyctrl.sh 6 | ./scripts/test-jsctrl.sh 7 | pytest 8 | -------------------------------------------------------------------------------- /scripts/test-guidance.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ "X$AZURE_GUIDANCE_URL" = "X" ] ; then 4 | if [ "X$AICI_API_BASE" = "X" ] ; then 5 | AICI_API_BASE="http://127.0.0.1:4242/v1/" 6 | fi 7 | AZURE_GUIDANCE_URL="$AICI_API_BASE" 8 | fi 9 | export AZURE_GUIDANCE_URL 10 | 11 | FILES="tests/need_credentials/test_azure_guidance.py tests/model_integration/test_greedy.py" 12 | 13 | cd $(dirname $0)/../py/guidance 14 | 15 | if [ "X$1" != "X" ] ; then 16 | if [ "X${1:0:2}" = "X::" ] ; then 17 | FILES="tests/need_credentials/test_azure_guidance.py$1" 18 | shift 19 | pytest --selected_model azure_guidance --durations=10 $FILES "$@" 20 | exit $? 21 | fi 22 | 23 | if [ "X$1" = "X--ll" ] ; then 24 | shift 25 | pytest tests/unit/test_ll.py "$@" 26 | exit $? 27 | fi 28 | 29 | if [ "X$1" = "X--server" ] ; then 30 | shift 31 | pytest --selected_model azure_guidance --durations=10 $FILES "$@" 32 | exit $? 33 | fi 34 | fi 35 | 36 | function runtest() { 37 | pytest "$@" 38 | if [ $? -ne 0 -a $? 
-ne 5 ] ; then 39 | : 40 | exit 1 41 | fi 42 | } 43 | 44 | # quick tests first 45 | runtest tests/unit/test_ll.py "$@" 46 | runtest tests/unit "$@" 47 | runtest --selected_model azure_guidance --durations=10 $FILES "$@" 48 | runtest tests/model_integration "$@" 49 | -------------------------------------------------------------------------------- /scripts/test-infer1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ "X$AICI_API_BASE" = "X" ] ; then 4 | export AICI_API_BASE="http://127.0.0.1:4242/v1/" 5 | fi 6 | 7 | echo "using model name 'model'; can lead to 404" 8 | 9 | curl -X POST "${AICI_API_BASE}chat/completions" \ 10 | -H "Content-Type: application/json" \ 11 | -d '{ 12 | "model": "model", 13 | "messages": [{"role": "user", "content": "Hello, how are you?"}], 14 | "temperature": 0.7 15 | }' 16 | -------------------------------------------------------------------------------- /scripts/test-jsctrl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -x 4 | cd `dirname $0` 5 | HERE=`pwd` 6 | cd $HERE/../controllers/jsctrl 7 | tsc --version || npm install -g typescript 8 | tsc -p samples 9 | PYTHONPATH=$HERE/../py \ 10 | python3 ../pyctrl/driver.py samples/dist/test.js "$@" 11 | -------------------------------------------------------------------------------- /scripts/test-llg1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ "X$AICI_API_BASE" = "X" ] ; then 4 | export AICI_API_BASE="http://127.0.0.1:4242/v1/" 5 | fi 6 | 7 | curl -X POST "${AICI_API_BASE}run" \ 8 | -H "Content-Type: application/json" \ 9 | -d '{"controller": "llguidance", "controller_arg": {"grammar": {"grammars": [{"nodes": [{"Join": {"sequence": [1, 2]}}, {"String": {"literal": "2 + 2 = "}}, {"Join": {"sequence": [3]}}, {"Gen": {"body_rx": "[0-9]+", "stop_rx": " ", "lazy": true, "stop_capture_name": null, "temperature": 0.0}}], "rx_nodes": []}]}}, "prompt": "", "max_tokens": 3, "temperature": 0.0}' 10 | -------------------------------------------------------------------------------- /scripts/test-parallel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | N=10 4 | if [ -n "$1" ]; then 5 | N=$1 6 | fi 7 | 8 | mkdir -p tmp 9 | rm -f tmp/fail 10 | 11 | for n in $(seq $N) ; do 12 | echo "Start $n" 13 | if ./scripts/test-pyctrl.sh > tmp/logs-$n.txt 2>&1 ; then 14 | echo "Passed test $n" 15 | else 16 | echo "Failed test $n; see tmp/logs-$n.txt" 17 | echo $n >> tmp/fail 18 | fi & 19 | sleep 1 20 | done 21 | 22 | wait 23 | 24 | if [ -f tmp/fail ]; then 25 | echo "Some tests failed; see tmp/fail" 26 | exit 1 27 | fi 28 | -------------------------------------------------------------------------------- /scripts/test-pyctrl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -x 4 | cd `dirname $0` 5 | HERE=`pwd` 6 | cd $HERE/../controllers/pyctrl 7 | PYTHONPATH=$HERE/../py \ 8 | python3 driver.py samples/test*.py "$@" 9 | -------------------------------------------------------------------------------- /scripts/upload-all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | export AICI="$(cd `dirname $0`; pwd)/aici.sh --all-prefixes" 4 | 5 | case "$1" in 6 | *aici-controllers-wasm32-wasi-*.tar.?z) 7 | mkdir -p tmp/aici-controllers 8 | tar --strip-components=1 -xf "$1" -C 
tmp/aici-controllers 9 | if test -f tmp/aici-controllers/tag.sh ; then 10 | cd tmp/aici-controllers 11 | ./tag.sh --latest 12 | rm -rf tmp/aici-controllers 13 | else 14 | echo "No tag.sh found in tmp/aici-controllers" 15 | exit 1 16 | fi 17 | ;; 18 | *) 19 | echo "Usage: $0 aici-controllers-wasm32-wasi-....tar.[xz|gz]" 20 | exit 1 21 | ;; 22 | esac 23 | -------------------------------------------------------------------------------- /scripts/vllm-init.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -x 4 | set -e 5 | mkdir -p tmp 6 | if test -f py/vllm/setup.py; then 7 | : 8 | else 9 | git submodule update --init --recursive 10 | fi 11 | cd py/vllm 12 | python setup.py develop 13 | -------------------------------------------------------------------------------- /scripts/vllm-server.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e 4 | set -x 5 | 6 | if [ -z "$FOLDER" ]; then 7 | #MODEL_ARGS="--model microsoft/Orca-2-13b --revision refs/pr/22 --aici-tokenizer orca" 8 | MODEL_ARGS="--model microsoft/Phi-3-mini-128k-instruct --trust-remote-code" 9 | #MODEL_ARGS="--model microsoft/Phi-3-medium-128k-instruct --trust-remote-code" 10 | #MODEL_ARGS="--model microsoft/Phi-3-mini-4k-instruct --trust-remote-code" 11 | else 12 | MODEL_ARGS="--model ./$FOLDER --aici-tokenizer ./$FOLDER/tokenizer.json --tokenizer ./$FOLDER" 13 | fi 14 | 15 | (cd aicirt && cargo build --release) 16 | 17 | RUST_LOG=info,tokenizers=error,aicirt=info \ 18 | RUST_BACKTRACE=1 \ 19 | PYTHONPATH=py:py/vllm \ 20 | python3 -m pyaici.vllm_server \ 21 | --enforce-eager \ 22 | --use-v2-block-manager \ 23 | --enable-chunked-prefill \ 24 | --served-model-name=model \ 25 | --aici-rt ./target/release/aicirt \ 26 | -A--wasm-timer-resolution-us=10 \ 27 | $MODEL_ARGS \ 28 | --port 4242 --host 127.0.0.1 \ 29 | "$@" 30 | 31 | # --aici-rtarg="--wasm-max-step-time=50" \ 32 | # --aici-rtarg="--wasm-max-pre-step-time=2" \ 33 | # --aici-rtarg="--wasm-max-init-time=1000" \ 34 | # --aici-rtarg="--wasm-max-memory=64" \ 35 | # --aici-rtarg="--wasm-max-pre-step-time=10" \ 36 | --------------------------------------------------------------------------------