├── src
│   ├── styles
│   │   └── globals.css
│   ├── pages
│   │   ├── _document.tsx
│   │   ├── index.tsx
│   │   └── _app.tsx
│   ├── server
│   │   ├── api
│   │   │   ├── types.ts
│   │   │   ├── root.ts
│   │   │   ├── shared.ts
│   │   │   ├── trpc.ts
│   │   │   └── routers
│   │   │       └── llama.ts
│   │   ├── llama
│   │   │   ├── index.ts
│   │   │   └── adapters
│   │   │       └── llamacpp.ts
│   │   ├── plugins
│   │   │   └── next.ts
│   │   └── index.ts
│   ├── utils
│   │   ├── utils.ts
│   │   └── api.ts
│   ├── recoil
│   │   ├── states.ts
│   │   └── templates.ts
│   ├── partials
│   │   ├── Header.tsx
│   │   ├── PromptEditor.tsx
│   │   ├── Generate.tsx
│   │   ├── TemplateSelect.tsx
│   │   └── ModelControls.tsx
│   └── env.mjs
├── .vscode
│   └── settings.json
├── public
│   ├── demo.gif
│   ├── logo.png
│   └── favicon.ico
├── .dockerignore
├── postcss.config.cjs
├── prettier.config.cjs
├── tsup.config.ts
├── docker-compose.dev.yml
├── docker-compose.yml
├── next.config.mjs
├── .gitignore
├── tsconfig.json
├── .env.example
├── LICENSE
├── Dockerfile
├── .github
│   └── workflows
│       └── docker.yml
├── scripts
│   └── startup.mjs
├── package.json
└── README.md
/src/styles/globals.css:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "editor.tabSize": 2
3 | }
--------------------------------------------------------------------------------
/public/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ItzDerock/llama-playground/HEAD/public/demo.gif
--------------------------------------------------------------------------------
/public/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ItzDerock/llama-playground/HEAD/public/logo.png
--------------------------------------------------------------------------------
/public/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ItzDerock/llama-playground/HEAD/public/favicon.ico
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | /dist
2 | /bin
3 | /.next
4 | /.vscode
5 | /node_modules
6 | /docker-compose.yml
7 | /docker-compose.dev.yml
8 |
--------------------------------------------------------------------------------
/postcss.config.cjs:
--------------------------------------------------------------------------------
1 | const config = {
2 | plugins: {
3 | autoprefixer: {},
4 | },
5 | };
6 |
7 | module.exports = config;
8 |
--------------------------------------------------------------------------------
/prettier.config.cjs:
--------------------------------------------------------------------------------
1 | /** @type {import("prettier").Config} */
2 | const config = {
3 | plugins: [],
4 | };
5 |
6 | module.exports = config;
7 |
--------------------------------------------------------------------------------
/src/pages/_document.tsx:
--------------------------------------------------------------------------------
1 | import Document from "next/document";
2 | import { createGetInitialProps } from "@mantine/next";
3 |
4 | const getInitialProps = createGetInitialProps();
5 |
6 | export default class _Document extends Document {
7 | static getInitialProps = getInitialProps;
8 | }
9 |
--------------------------------------------------------------------------------
/tsup.config.ts:
--------------------------------------------------------------------------------
1 | import { defineConfig, Options } from "tsup";
2 |
3 | const opts: Options = {
4 | platform: "node",
5 | format: ["cjs"],
6 | treeshake: true,
7 | clean: true,
8 | sourcemap: true,
9 | };
10 |
11 | export default defineConfig([
12 | {
13 | entryPoints: ["src/server/index.ts"],
14 | ...opts,
15 | },
16 | ]);
17 |
--------------------------------------------------------------------------------
/src/server/api/types.ts:
--------------------------------------------------------------------------------
1 | export enum WSMessageType {
2 | // Identity -> server gives client a unique UUID
3 | IDENTITY,
4 | // Completion -> generated tokens are sent to the client
5 | COMPLETION,
6 | // Request Complete -> done generating tokens
7 | REQUEST_COMPLETE,
8 | }
9 |
10 | export type WSMessage = {
11 | type: WSMessageType;
12 | data: string;
13 | };
14 |
--------------------------------------------------------------------------------
/docker-compose.dev.yml:
--------------------------------------------------------------------------------
1 | version: '3'
2 |
3 | services:
4 | app:
5 | build: .
6 | ports:
7 | - 3000:3000
8 | environment:
9 | - USE_BUILT_IN_LLAMA_SERVER=true
10 | - LLAMA_TCP_BIN=auto
11 | - LLAMA_SERVER_HOST=auto
12 | - LLAMA_SERVER_PORT=auto
13 | - LLAMA_MODEL_PATH=/app/models/ggml-model-q4_0.bin
14 | - NODE_ENV=production
15 | volumes:
16 | - ./path/to/model/:/app/models
--------------------------------------------------------------------------------
/src/server/api/root.ts:
--------------------------------------------------------------------------------
1 | import { createTRPCRouter } from "~/server/api/trpc";
2 | import { llamaRouter } from "./routers/llama";
3 |
4 | /**
5 | * This is the primary router for your server.
6 | *
7 | * All routers added in /api/routers should be manually added here.
8 | */
9 | export const appRouter = createTRPCRouter({
10 | llama: llamaRouter
11 | });
12 |
13 | // export type definition of API
14 | export type AppRouter = typeof appRouter;
15 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: '3'
2 |
3 | services:
4 | app:
5 | image: ghcr.io/itzderock/llama-playground:latest
6 | ports:
7 | - 3000:3000
8 | - 3001:3001
9 | environment:
10 | - USE_BUILT_IN_LLAMA_SERVER=true
11 | - LLAMA_TCP_BIN=auto
12 | - LLAMA_SERVER_HOST=auto
13 | - LLAMA_SERVER_PORT=auto
14 | - LLAMA_MODEL_PATH=/app/models/ggml-model-q4_0.bin
15 | volumes:
16 | - ./path/to/7B/:/app/models
--------------------------------------------------------------------------------
/next.config.mjs:
--------------------------------------------------------------------------------
1 | /**
2 | * Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation.
3 | * This is especially useful for Docker builds.
4 | */
5 | !process.env.SKIP_ENV_VALIDATION && (await import("./src/env.mjs"));
6 |
7 | /** @type {import("next").NextConfig} */
8 | const config = {
9 | reactStrictMode: true,
10 |
11 | /**
12 | * If you have the "experimental: { appDir: true }" setting enabled, then you
13 | * must comment the below `i18n` config out.
14 | *
15 | * @see https://github.com/vercel/next.js/issues/41980
16 | */
17 | i18n: {
18 | locales: ["en"],
19 | defaultLocale: "en",
20 | },
21 | };
22 | export default config;
23 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
2 |
3 | # dependencies
4 | /node_modules
5 | /.pnp
6 | .pnp.js
7 |
8 | # testing
9 | /coverage
10 |
11 | # database
12 | /prisma/db.sqlite
13 | /prisma/db.sqlite-journal
14 |
15 | # next.js
16 | /.next/
17 | /out/
18 | next-env.d.ts
19 |
20 | # production
21 | /build
22 | /dist
23 |
24 | # misc
25 | .DS_Store
26 | *.pem
27 |
28 | # debug
29 | npm-debug.log*
30 | yarn-debug.log*
31 | yarn-error.log*
32 | .pnpm-debug.log*
33 |
34 | # local env files
35 | # do not commit any .env files to git, except for the .env.example file. https://create.t3.gg/en/usage/env-variables#using-environment-variables
36 | .env
37 | .env*.local
38 |
39 | # vercel
40 | .vercel
41 |
42 | # typescript
43 | *.tsbuildinfo
44 |
45 | # built binaries
46 | /bin/
--------------------------------------------------------------------------------
/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "target": "es2017",
4 | "lib": ["dom", "dom.iterable", "esnext"],
5 | "allowJs": true,
6 | "checkJs": true,
7 | "skipLibCheck": true,
8 | "strict": true,
9 | "forceConsistentCasingInFileNames": true,
10 | "noEmit": true,
11 | "esModuleInterop": true,
12 | "module": "esnext",
13 | "moduleResolution": "node",
14 | "resolveJsonModule": true,
15 | "isolatedModules": true,
16 | "jsx": "preserve",
17 | "incremental": true,
18 | "noUncheckedIndexedAccess": true,
19 | "baseUrl": ".",
20 | "paths": {
21 | "~/*": ["./src/*"]
22 | }
23 | },
24 | "include": [
25 | ".eslintrc.cjs",
26 | "next-env.d.ts",
27 | "**/*.ts",
28 | "**/*.tsx",
29 | "**/*.cjs",
30 | "**/*.mjs"
31 | ],
32 | "exclude": ["node_modules"]
33 | }
34 |
--------------------------------------------------------------------------------
/.env.example:
--------------------------------------------------------------------------------
1 | # The web-server's HOST and PORT
2 | HOST=127.0.0.1
3 | PORT=3000
4 |
5 | # Disable this if you want to run llama.cpp#tcp_server on your own
6 | # Uses the binary path set below
7 | USE_BUILT_IN_LLAMA_SERVER=true
8 |
9 | # Binary location for llama.cpp#tcp_server
10 | # Auto will automatically pull and build the latest version
11 | # Requires build-essential (or equivalent) and make
12 | # This does nothing if USE_BUILT_IN_LLAMA_SERVER is disabled
13 | LLAMA_TCP_BIN=auto
14 |
15 | # If USE_BUILT_IN_LLAMA_SERVER is disabled, enter the llama.cpp#tcp_server tcp details here
16 | # Otherwise, this app will automatically start a llama.cpp#tcp_server server
17 | # If port is set to auto, it will listen on a random open port
18 | LLAMA_SERVER_HOST=auto
19 | LLAMA_SERVER_PORT=auto
20 |
21 | # The path to a model's .bin file
22 | LLAMA_MODEL_PATH=/path/to/llama/ggml-model-q4_0.bin
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Derock
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/pages/index.tsx:
--------------------------------------------------------------------------------
1 | import { Flex, LoadingOverlay } from "@mantine/core";
2 | import { type NextPage } from "next";
3 | import Head from "next/head";
4 | import { useRecoilValue } from "recoil";
5 | import Generate from "~/partials/Generate";
6 | import Header from "~/partials/Header";
7 | import ModelControls from "~/partials/ModelControls";
8 | import PromptEditor from "~/partials/PromptEditor";
9 | import { ClientWSState, wsState } from "~/recoil/states";
10 |
11 | const Home: NextPage = () => {
12 | const state = useRecoilValue(wsState);
13 |
14 |   return (
15 |     <>
16 |       <Head>
17 |         <title>LLaMA Playground</title>
18 |         <meta name="description" content="A playground for LLaMA models" />
19 |       </Head>
20 |
21 |       {/* block the UI until the websocket connection is ready */}
22 |       <LoadingOverlay visible={state === ClientWSState.CONNECTING} />
23 |
24 |       <Flex direction="column" h="100vh">
25 |         <Header />
26 |
27 |         {/* main content */}
28 |         <Flex direction="column" sx={{ flexGrow: 1 }} p="md" gap="md">
29 |           <PromptEditor />
30 |
31 |           <Flex direction="row" justify="space-between" align="end">
32 |             <ModelControls />
33 |             <Generate />
34 |           </Flex>
35 |         </Flex>
36 |       </Flex>
37 |     </>
38 | );
39 | };
40 |
41 | export default Home;
42 |
--------------------------------------------------------------------------------
/src/server/llama/index.ts:
--------------------------------------------------------------------------------
1 | import { env } from "~/env.mjs";
2 | import LLaMATCPClient from "./adapters/llamacpp";
3 |
4 | // https://stackoverflow.com/questions/64093560/can-you-keep-a-postgresql-connection-alive-from-within-a-next-js-api
5 | export function getLLaMAClient() {
6 | // TODO: in the future, allow multiple adapters like llama.cpp and llama-rs and other similar models.
7 | if (!("llamaClient" in global)) {
8 | if (
9 | env.USE_BUILT_IN_LLAMA_SERVER === "false" &&
10 | env.LLAMA_SERVER_PORT === "auto"
11 | ) {
12 | throw new Error(
13 | "LLAMA_SERVER_PORT must be set to a port number when USE_BUILT_IN_LLAMA_SERVER is false."
14 | );
15 | }
16 |
17 | const LLaMAClient = new LLaMATCPClient(
18 | env.USE_BUILT_IN_LLAMA_SERVER === "true"
19 | ? {
20 | port: env.LLAMA_SERVER_PORT,
21 | modelPath: env.LLAMA_MODEL_PATH,
22 | binPath: env.LLAMA_TCP_BIN,
23 | debug: env.NODE_ENV === "development",
24 | }
25 | : {
26 | host: env.LLAMA_SERVER_HOST,
27 | port: env.LLAMA_SERVER_PORT as number,
28 | debug: env.NODE_ENV === "development",
29 | }
30 | );
31 |
32 | LLaMAClient.start();
33 |
34 | return ((global as any).llamaClient = LLaMAClient);
35 | }
36 |
37 | return (global as any).llamaClient as LLaMATCPClient;
38 | }
39 |
--------------------------------------------------------------------------------
/src/utils/utils.ts:
--------------------------------------------------------------------------------
1 | import { createServer, AddressInfo } from "net";
2 |
3 | /**
4 | * Searches randomly for an open port
5 | * @param min The minimum port number to search for
6 | * @param max The maximum port number to search for
7 | * @param _i The current iteration
8 | */
9 | export function findRandomOpenPort(min: number = 1000, max: number = 65535, _i: number = 0) {
10 | // if _i is a multiple of 10, log a warning
11 | if (_i > 0 && _i % 10 === 0) {
12 | console.warn(`[warn] findRandomOpenPort: ${_i} iterations, no open port found`);
13 | }
14 |
15 | // generate random port number
16 | const randomNumber = Math.floor(Math.random() * (max - min + 1)) + min;
17 |
18 | // check if port is open
19 |   return new Promise<number>((resolve, reject) => {
20 | const server = createServer();
21 | server.unref();
22 | server.on("error", (error) => {
23 | // check the error
24 | if (error.message.includes("EADDRINUSE")) {
25 | // port is in use, try again
26 | resolve(findRandomOpenPort(min, max, _i + 1));
27 | return;
28 | }
29 |
30 | // some other error, throw it
31 | reject(error);
32 | });
33 |
34 | // attempt to listen on the port
35 | server.listen(randomNumber, () => {
36 | // success! close the server and return the port
37 | const { port } = server.address() as AddressInfo;
38 | server.close(() => {
39 | resolve(port);
40 | });
41 | });
42 | });
43 | }
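// Illustrative usage (a sketch, not part of the original file): pick a free
// port for the built-in llama.cpp server when LLAMA_SERVER_PORT is "auto".
//
//   const port = await findRandomOpenPort(1024, 65535);
//   console.log(`starting llama.cpp tcp_server on port ${port}`);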
--------------------------------------------------------------------------------
/src/recoil/states.ts:
--------------------------------------------------------------------------------
1 | import { atom } from "recoil";
2 |
3 | // websocket state
4 | export const wsUUIDState = atom({
5 | key: "ws_uuid",
6 | default: "",
7 | });
8 |
9 | export enum ClientWSState {
10 | CONNECTING,
11 | READY,
12 | }
13 |
14 | export const wsState = atom({
15 | key: "ws_state",
16 | default: ClientWSState.CONNECTING,
17 | });
18 |
19 | // model options
20 | export const maxPredictionTokensState = atom({
21 | key: "num_predict",
22 | default: 2048,
23 | });
24 |
25 | export const nBatchState = atom({
26 | key: "n_batch",
27 | default: 8,
28 | });
29 |
30 | export const topKState = atom({
31 | key: "top_k",
32 | default: 40,
33 | });
34 |
35 | export const topPState = atom({
36 | key: "top_p",
37 | default: 0.95,
38 | });
39 |
40 | export const repeatPenaltyState = atom({
41 | key: "repeat_penalty",
42 | default: 1.3,
43 | });
44 |
45 | export const repeatLastNState = atom({
46 | key: "repeat_last_n",
47 | default: 64,
48 | });
49 |
50 | export const tempState = atom({
51 | key: "temp",
52 | default: 0.8,
53 | });
54 |
55 | export const promptState = atom({
56 | key: "prompt",
57 | default: "asdf",
58 | });
59 |
60 | export const promptTemplateState = atom({
61 | key: "prompt_template",
62 | default: "",
63 | });
64 |
65 | export const generatingState = atom({
66 | key: "generating",
67 | default: false,
68 | });
69 |
70 | export const generatedText = atom({
71 | key: "generated_text",
72 | default: "",
73 | });
74 |
--------------------------------------------------------------------------------
/src/recoil/templates.ts:
--------------------------------------------------------------------------------
1 | import { stripIndents } from "common-tags";
2 | import { atom } from "recoil";
3 |
4 | export type TemplateData = {
5 | name: string;
6 | prompt: string;
7 | options?: Partial<{
8 | maximum_tokens: number;
9 | temperature: number;
10 | top_p: number;
11 | top_k: number;
12 | repeat_penalty: number;
13 | repeat_last_n: number;
14 | }>;
15 | };
16 |
17 | export const defaultTemplates: TemplateData[] = [
18 | {
19 | name: "Question Answering",
20 | prompt: stripIndents`
21 | Below is an instruction that describes a task. Write a response that appropriately completes the task. The response must be accurate, concise, coherent, and evidence-based whenever possible.
22 |
23 | **Task:** Describe Machine Learning in your own words.
24 | `,
25 | },
26 | {
27 | name: "Chatbot",
28 | prompt: stripIndents`
29 | Below is a conversation between two people. Write a response that appropriately completes a response for the AI. The response must be accurate, concise, coherent, and evidence-based whenever possible. Do not complete the User's part.
30 |
31 | User: What is Machine Learning?
32 | AI:
33 | `,
34 | },
35 | {
36 | name: "Story",
37 | prompt: stripIndents`
38 | Below is a prompt for a story. The story must relate to the prompt and be creative and interesting.
39 |
40 | Prompt: It was a dark and stormy night...
41 | `,
42 | options: {
43 | temperature: 0.5,
44 | },
45 | },
46 | ];
47 |
48 | export const allTemplatesState = atom({
49 | key: "all_templates",
50 | default: defaultTemplates,
51 | });
52 |
--------------------------------------------------------------------------------
/src/server/api/shared.ts:
--------------------------------------------------------------------------------
1 | import EventEmitter from "events";
2 |
3 | // this is a set of client IDs that are currently generating text
4 | export const generating = new Set<string>();
5 | export const clients = new Set<string>();
6 |
7 | // EventEmitter but with types
8 | export class TypedEventEmitter<
9 |   TEvents extends Record<string, any[]>
10 | > extends EventEmitter {
11 |   emit<TEventName extends keyof TEvents & string>(
12 |     eventName: TEventName,
13 |     ...eventArg: TEvents[TEventName]
14 |   ) {
15 |     return super.emit(eventName, ...(eventArg as []));
16 |   }
17 |
18 |   on<TEventName extends keyof TEvents & string>(
19 |     eventName: TEventName,
20 |     handler: (...eventArg: TEvents[TEventName]) => void
21 |   ) {
22 |     return super.on(eventName, handler as any);
23 |   }
24 |
25 |   off<TEventName extends keyof TEvents & string>(
26 |     eventName: TEventName,
27 |     handler: (...eventArg: TEvents[TEventName]) => void
28 |   ) {
29 |     return super.off(eventName, handler as any);
30 |   }
31 |
32 | incrementMaxListeners(amount = 1) {
33 | return this.setMaxListeners(this.getMaxListeners() + amount);
34 | }
35 |
36 | decrementMaxListeners(amount = 1) {
37 | return this.setMaxListeners(this.getMaxListeners() - amount);
38 | }
39 | }
40 |
41 | // the events that can be emitted
42 | type Events = {
43 | // client id, message
44 | generate: [string, string];
45 |
46 | // client id
47 | "generate:done": [string];
48 |
49 | // client id, error
50 | "generate:error": [string, Error];
51 |
52 | // client id
53 | "generate:cancel": [string];
54 | };
55 |
56 | // global event emitter
57 | // can be replaced with a pubsub system if you want to scale
58 | export const events = new TypedEventEmitter<Events>();
59 |
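// Illustrative usage sketch (not part of the original file): the typed emitter
// type-checks event payloads against the `Events` map declared above.
//
//   events.on("generate", (clientId, token) => {
//     // `clientId` and `token` are both inferred as `string`
//   });
//
//   events.emit("generate:done", someClientId);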
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | ## Base image
2 | FROM node:18-alpine AS base
3 |
4 | # Install package manager
5 | RUN npm i --global --no-update-notifier --no-fund pnpm
6 |
7 | # Set the working directory
8 | WORKDIR /app
9 |
10 | ## Builder image
11 | FROM base AS web-builder
12 |
13 | # Copy the package.json and package-lock.json
14 | COPY package*.json ./
15 | COPY pnpm-lock.yaml ./
16 |
17 | # Install the dependencies
18 | RUN pnpm install --frozen-lockfile
19 |
20 | # Copy the source code
21 | COPY . .
22 |
23 | # Build the app
24 | RUN SKIP_ENV_VALIDATION=true npm run build
25 |
26 | ## Builder for llama-cpp
27 | FROM alpine:3.14 AS llama-cpp-builder
28 |
29 | # Install build dependencies
30 | RUN apk add --no-cache \
31 | build-base \
32 | make \
33 | git
34 |
35 | # Set the working directory
36 | WORKDIR /app
37 |
38 | # Clone the repository
39 | RUN git clone --depth 1 --branch tcp_server \
40 | https://github.com/ggerganov/llama.cpp.git .
41 |
42 | # Build
43 | RUN make
44 |
45 | # Production image based on alpine node:18
46 | FROM base AS production
47 |
48 | # Set the working directory
49 | WORKDIR /app
50 |
51 | # Copy the package.json and package-lock.json
52 | COPY package*.json ./
53 | COPY pnpm-lock.yaml ./
54 |
55 | # Install production dependencies
56 | RUN pnpm install --frozen-lockfile --production
57 |
58 | # Copy all the built files
59 | COPY --from=llama-cpp-builder /app/main ./bin/main
60 | COPY --from=web-builder /app/ ./
61 | # COPY --from=web-builder /app/dist ./dist
62 | # COPY --from=web-builder /app/.next ./.next
63 | # COPY --from=web-builder /app/public ./public
64 | # COPY --from=web-builder /app/scripts ./scripts
65 | # COPY --from=web-builder /app/src/env.mjs ./src/env.mjs
66 |
67 | # Default environment variables
68 | ENV NODE_ENV=production
69 | ENV HOST=0.0.0.0
70 |
71 | # Done
72 | CMD ["pnpm", "run", "start"]
73 |
--------------------------------------------------------------------------------
/src/server/plugins/next.ts:
--------------------------------------------------------------------------------
1 | import { FastifyInstance, FastifyReply, FastifyRequest } from "fastify";
2 | import fastifyPlugin from "fastify-plugin";
3 | import { IncomingMessage, OutgoingMessage, ServerResponse } from "http";
4 | import next from "next";
5 | import { NextServerOptions } from "next/dist/server/next";
6 |
7 | async function nextPlugin(
8 | fastify: FastifyInstance,
9 | options: NextServerOptions
10 | ) {
11 | const nextServer = next(options);
12 | const handle = nextServer.getRequestHandler();
13 |
14 | fastify
15 | .decorate("nextServer", nextServer)
16 | .decorate("nextHandle", handle)
17 | .decorate("next", route.bind(fastify));
18 |
19 | return nextServer.prepare();
20 |
21 | function route(
22 | path: string,
23 | opts: { method: string | string[] } = { method: "GET" }
24 | ) {
25 | if (typeof opts.method === "string")
26 | // @ts-ignore
27 | this[opts.method.toLowerCase()](path, opts, handler);
28 | else if (Array.isArray(opts.method)) {
29 | for (const method of opts.method) {
30 | // @ts-ignore
31 | this[method.toLowerCase()](path, opts, handler);
32 | }
33 | }
34 |
35 | async function handler(req: FastifyRequest, reply: FastifyReply) {
36 | for (const [key, value] of Object.entries(reply.getHeaders())) {
37 | if (value) reply.raw.setHeader(key, value);
38 | }
39 |
40 | await handle(req.raw, reply.raw);
41 |
42 | reply.hijack();
43 | }
44 | }
45 | }
46 |
47 | export default fastifyPlugin(nextPlugin, {
48 | name: "next",
49 | fastify: "4.x",
50 | });
51 |
52 | declare module "fastify" {
53 | interface FastifyInstance {
54 |     nextServer: ReturnType<typeof next>;
55 | next: (path: string, opts?: { method: string | string[] }) => void;
56 | nextHandle: (
57 | req: IncomingMessage,
58 | res: OutgoingMessage | ServerResponse
59 |     ) => Promise<void>;
60 | }
61 | }
62 |
--------------------------------------------------------------------------------
/.github/workflows/docker.yml:
--------------------------------------------------------------------------------
1 | name: Release to GitHub Packages # https://ghcr.io
2 | on: # https://docs.github.com/en/actions/reference/workflow-syntax-for-github-actions#on
3 | workflow_dispatch:
4 | push:
5 |     tags: # Builds whenever a semantically versioned tag (e.g. v0.12.3) is pushed.
6 |       - 'v*' # https://semver.org — the workflow will fail for any other tag format
7 |
8 | #release: # Builds only releases. Includes draft and pre-releases.
9 | #types: [created]
10 |
11 | #pull_request: # Run 'tests' for any PRs. Default is to not run for first-time contributors: see /settings/actions
12 |
13 | env:
14 | TAG_LATEST: true # Encourage users to use a major version (foobar:1) instead of :latest.
15 |   # By semantic versioning standards, major changes are 'backwards incompatible'. Major upgrades are often rare and perhaps need attention from the user.
16 | jobs:
17 | # Push image to GitHub Packages.
18 | push:
19 | runs-on: ubuntu-latest
20 | permissions:
21 | packages: write
22 | contents: read
23 |
24 | steps:
25 | - uses: actions/checkout@v2
26 |
27 | -
28 | name: Set up QEMU
29 | uses: docker/setup-qemu-action@v2
30 | -
31 | name: Set up Docker Buildx
32 | uses: docker/setup-buildx-action@v2
33 | -
34 | name: Login to GHCR
35 | if: github.event_name != 'pull_request'
36 | uses: docker/login-action@v2
37 | with:
38 | registry: ghcr.io
39 | username: ${{ github.repository_owner }}
40 | password: ${{ secrets.PAT }}
41 | -
42 | name: Extract Builder meta
43 | id: builder-meta
44 | uses: docker/metadata-action@v4
45 | with:
46 | images: ghcr.io/itzderock/llama-playground
47 | tags: |
48 | type=semver,pattern={{version}}
49 | type=semver,pattern={{major}}.{{minor}}
50 | -
51 | name: Build and push
52 | uses: docker/build-push-action@v3
53 | with:
54 | context: .
55 | platforms: linux/amd64,linux/arm64
56 | push: ${{ github.event_name != 'pull_request' }}
57 | tags: ${{ steps.builder-meta.outputs.tags }}
58 | labels: ${{ steps.builder-meta.outputs.labels }}
59 |
--------------------------------------------------------------------------------
/src/pages/_app.tsx:
--------------------------------------------------------------------------------
1 | import { type AppType } from "next/app";
2 | import {
3 | ColorScheme,
4 | ColorSchemeProvider,
5 | MantineProvider,
6 | } from "@mantine/core";
7 | import { api } from "~/utils/api";
8 | import { RecoilRoot } from "recoil";
9 |
10 | import "~/styles/globals.css";
11 | import { useColorScheme, useLocalStorage } from "@mantine/hooks";
12 | import Head from "next/head";
13 |
14 | const App: AppType = ({ Component, pageProps }) => {
15 | // load color scheme from local storage
16 | const preferredColorScheme = useColorScheme();
17 |   const [colorScheme, setColorScheme] = useLocalStorage<ColorScheme>({
18 | key: "colorScheme",
19 | defaultValue: preferredColorScheme,
20 | });
21 |
22 | const toggleColorScheme = (value?: ColorScheme) =>
23 | setColorScheme(value ?? (colorScheme === "light" ? "dark" : "light"));
24 |
25 | return (
26 |     <>
27 |       <Head>
28 |         <meta
29 |           name="viewport"
30 |           content="minimum-scale=1, initial-scale=1, width=device-width"
31 |         />
32 |
33 |         {/* title and description */}
34 |         <title>LLaMA Playground</title>
35 |         <meta
36 |           name="description"
37 |           content="A simple playground for LLaMA models"
38 |         />
39 |
40 |         {/* favicon */}
41 |         <link rel="icon" href="/favicon.ico" />
42 |
43 |         {/* thumbnail (large) */}
44 |         <meta name="twitter:card" content="summary_large_image" />
45 |         <meta property="og:title" content="LLaMA Playground" />
46 |         <meta property="og:image" content="/logo.png" />
47 |       </Head>
48 |
49 |       {/* color scheme + mantine theme providers */}
50 |       <ColorSchemeProvider
51 |         colorScheme={colorScheme}
52 |         toggleColorScheme={toggleColorScheme}
53 |       >
54 |         <MantineProvider
55 |           withGlobalStyles
56 |           withNormalizeCSS
57 |           theme={{ colorScheme }}
58 |         >
59 |           {/* recoil root for client-side state atoms */}
60 |           <RecoilRoot>
61 |             <Component {...pageProps} />
62 |           </RecoilRoot>
63 |         </MantineProvider>
64 |       </ColorSchemeProvider>
65 |     </>
66 | );
67 | };
68 |
69 | export default api.withTRPC(App);
70 |
--------------------------------------------------------------------------------
/src/partials/Header.tsx:
--------------------------------------------------------------------------------
1 | import {
2 | createStyles,
3 | Header,
4 | Group,
5 | Button,
6 | Box,
7 | Text,
8 | useMantineColorScheme,
9 | Flex,
10 | } from "@mantine/core";
11 | import { IconBrandGithub, IconMoon, IconSun } from "@tabler/icons-react";
12 | import TemplateSelect from "./TemplateSelect";
13 |
14 | const useStyles = createStyles((theme) => ({
15 | text: {
16 | color: theme.colorScheme === "dark" ? theme.colors.dark[0] : theme.black,
17 | fontWeight: 500,
18 | },
19 |
20 | hiddenMobile: {
21 | [theme.fn.smallerThan("sm")]: {
22 | display: "none",
23 | },
24 | },
25 |
26 | hiddenDesktop: {
27 | [theme.fn.largerThan("sm")]: {
28 | display: "none",
29 | },
30 | },
31 | }));
32 |
33 | export default function FullHeader() {
34 | const { colorScheme, toggleColorScheme } = useMantineColorScheme();
35 | const { classes } = useStyles();
36 |
37 |   return (
38 |     <Box>
39 |       <Header height={60} px="md">
40 |         <Flex justify="space-between" align="center" h="100%">
41 |           {/* logo */}
42 |           <Text size="lg" className={classes.text}>
43 |             LLaMA Playground
44 |           </Text>
45 |
46 |           {/* template selector */}
47 |           <Box className={classes.hiddenMobile}>
48 |             <TemplateSelect />
49 |           </Box>
50 |
51 |           {/* right side buttons */}
52 |           <Group>
53 |             {/* github link */}
54 |             <Button
55 |               component="a"
56 |               href="https://github.com/ItzDerock/llama-playground"
57 |               target="_blank"
58 |               variant="default"
59 |               leftIcon={<IconBrandGithub size={18} />}
60 |               className={classes.hiddenMobile}
61 |               title="View source on GitHub"
62 |             >
63 |               GitHub
64 |             </Button>
65 |
66 |             {/* color scheme toggle */}
67 |             <Button
68 |               variant="default"
69 |               px="xs"
70 |               onClick={() => toggleColorScheme()}
71 |               title="Toggle color scheme"
72 |             >
73 |               {colorScheme === "dark" ? (
74 |                 <IconSun size={18} />
75 |               ) : (
76 |                 <IconMoon size={18} />
77 |               )}
78 |             </Button>
79 |           </Group>
80 |         </Flex>
81 |       </Header>
82 |     </Box>
83 |   );
84 | }
85 |
--------------------------------------------------------------------------------
/src/server/index.ts:
--------------------------------------------------------------------------------
1 | import fastify, { FastifyLoggerOptions } from "fastify";
2 | import { LoggerOptions } from "pino";
3 | import fastifyNextJS from "./plugins/next";
4 | import { fastifyTRPCPlugin } from "@trpc/server/adapters/fastify";
5 | import "dotenv/config";
6 | import { env } from "~/env.mjs";
7 | import { appRouter } from "./api/root";
8 | import { createTRPCContext } from "./api/trpc";
9 | import ws from "@fastify/websocket";
10 | import path from "path";
11 | import { fastifyStatic } from "@fastify/static";
12 |
13 | // logger
14 | const envToLogger = {
15 | development: {
16 | transport: {
17 | target: "pino-pretty",
18 | options: {
19 | translateTime: "HH:MM:ss Z",
20 | ignore: "pid,hostname",
21 | },
22 | },
23 | },
24 | production: true,
25 | test: false,
26 | } satisfies {
27 | [key: string]: (FastifyLoggerOptions & LoggerOptions) | boolean;
28 | };
29 |
30 | // create fastify app
31 | export const app = fastify({
32 | // prevent errors during large batch requests
33 | maxParamLength: 5000,
34 |
35 | // logger
36 | logger: envToLogger[env.NODE_ENV] ?? true,
37 | });
38 |
39 | // register the next.js plugin
40 | app.register(fastifyNextJS, {
41 | dev: env.NODE_ENV !== "production",
42 | dir: ".",
43 | });
44 |
45 | // serve static files in /public
46 | app.register(fastifyStatic, {
47 | root: path.join(__dirname, "..", "public"),
48 | prefix: "/public/",
49 | });
50 |
51 | // register ws for websocket support
52 | // needed for subscriptions
53 | app.register(ws);
54 |
55 | // register the trpc plugin
56 | app.register(fastifyTRPCPlugin, {
57 | prefix: "/api/trpc",
58 | useWSS: true,
59 | trpcOptions: {
60 | router: appRouter,
61 | createContext: createTRPCContext,
62 | },
63 | });
64 |
65 | // pass to next.js if no route is defined
66 | app.addHook("onRequest", async (req, rep) => {
67 | // if a route is defined, skip
68 | if (req.routerPath) {
69 | return;
70 | }
71 |
72 | // pass along headers
73 | for (const [key, value] of Object.entries(rep.getHeaders())) {
74 | if (value) rep.raw.setHeader(key, value);
75 | }
76 |
77 | // otherwise, pass to next.js
78 | await app.nextHandle(req.raw, rep.raw);
79 | rep.hijack();
80 | });
81 |
82 | // start the server
83 | app.listen({ port: parseInt(env.PORT), host: env.HOST }).then((address) => {
84 | console.log(`🚀 Server ready at ${address}`);
85 | });
86 |
--------------------------------------------------------------------------------
/scripts/startup.mjs:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env zx
2 | import "dotenv/config";
3 | import "zx/globals";
4 | import os from "os";
5 | import path from "path";
6 | import { env } from "../src/env.mjs";
7 |
8 | // don't run if user is providing their own server.
9 | if (String(env.USE_BUILT_IN_LLAMA_SERVER) !== "true") process.exit(0);
10 |
11 | // check if user is providing their own binary
12 | if (env.LLAMA_TCP_BIN != "auto") {
13 | if (!fs.existsSync(env.LLAMA_TCP_BIN)) {
14 | console.error(
15 | `❌ LLAMA_TCP_BIN is set to '${env.LLAMA_TCP_BIN}' but that file doesn't exist.`
16 | );
17 | process.exit(1);
18 | }
19 |
20 | process.exit(0);
21 | }
22 |
23 | // check if llama.cpp's tcp_server is already built
24 | if (fs.existsSync("./bin/") && fs.existsSync("./bin/main")) {
25 | console.log(chalk.green("✅ Llama.cpp's tcp_server is already built."));
26 | process.exit(0);
27 | }
28 |
29 | // clone and build the repo
30 | const { path: tempDir, delete: deleteTempDir } = createTempDir();
31 | console.log(
32 | chalk.gray(
33 | `🔃 (1/4) Created temporary directory at "${chalk.green(
34 | tempDir
35 | )}". Cloning repo...`
36 | )
37 | );
38 |
39 | // clone the repo
40 | await $`git clone --depth 1 --branch tcp_server https://github.com/ggerganov/llama.cpp ${tempDir}`;
41 | console.log(
42 | chalk.gray(`🔃 (2/4) Cloned llama.cpp's tcp_server branch to ${tempDir}`)
43 | );
44 |
45 | // build the repo
46 | console.log(chalk.gray(`🔃 (3/4) Building llama.cpp's tcp_server...`));
47 | try {
48 | await $`cd ${tempDir} && make -j`;
49 | } catch (error) {
50 | console.error(
51 | chalk.red(`❌ Failed to build llama.cpp's tcp_server! Error: ${error}`)
52 | );
53 | console.error(`🗑️ Cleaning up...`);
54 | deleteTempDir();
55 | process.exit(1);
56 | }
57 |
58 | // copy the binary to the bin folder
59 | console.log(chalk.gray(`🔃 (4/4) Copying binary to ./bin/main...`));
60 | fs.mkdirSync("./bin/", { recursive: true });
61 | fs.copyFileSync(path.join(tempDir, "main"), "./bin/main");
62 |
63 | // done
64 | console.log(chalk.green("✅ Built llama.cpp's tcp_server! Cleaning up..."));
65 | deleteTempDir();
66 |
67 | /**
68 | * Creates a temporary directory and returns that directory's path and a function to delete it.
69 | * @returns {{ path: string; delete: () => void; }}
70 | */
71 | function createTempDir() {
72 | const random = Math.random().toString(36).substring(2, 15);
73 | const tempDir = path.join(os.tmpdir(), `llama-playground-tmp-${random}`);
74 |
75 | if (!fs.existsSync(tempDir)) {
76 | fs.mkdirSync(tempDir);
77 | return {
78 | path: tempDir,
79 | delete: () => fs.removeSync(tempDir),
80 | };
81 | }
82 |
83 | return createTempDir();
84 | }
85 |
--------------------------------------------------------------------------------
/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "llama-playground",
3 | "version": "1.0.1",
4 | "private": true,
5 | "scripts": {
6 | "build": "npm-run-all build:next build:server",
7 | "build:next": "next build",
8 | "build:server": "tsup",
9 | "dev": "npm-run-all build:server dev:run",
10 | "dev:run": "cross-env DEBUG=true REACT_EDITOR=code NODE_ENV=development RECOIL_DUPLICATE_ATOM_KEY_CHECKING_ENABLED=false node --enable-source-maps dist",
11 | "lint": "next lint",
12 | "start": "node --enable-source-maps ./dist/index.js",
13 | "prestart": "node ./scripts/startup.mjs"
14 | },
15 | "dependencies": {
16 | "@emotion/react": "^11.10.6",
17 | "@emotion/server": "^11.10.0",
18 | "@fastify/static": "^6.9.0",
19 | "@fastify/websocket": "^7.1.3",
20 | "@mantine/core": "^6.0.2",
21 | "@mantine/form": "^6.0.4",
22 | "@mantine/hooks": "^6.0.2",
23 | "@mantine/next": "^6.0.2",
24 | "@mantine/tiptap": "^6.0.2",
25 | "@tabler/icons-react": "^2.11.0",
26 | "@tanstack/react-query": "^4.20.2",
27 | "@tiptap/core": ">=2.0.0-beta.209 <3.0.0",
28 | "@tiptap/extension-document": "2.0.0-beta.220",
29 | "@tiptap/extension-highlight": "2.0.0-beta.220",
30 | "@tiptap/extension-link": "2.0.0-beta.220",
31 | "@tiptap/extension-paragraph": "2.0.0-beta.220",
32 | "@tiptap/extension-placeholder": "2.0.0-beta.220",
33 | "@tiptap/extension-text": "2.0.0-beta.220",
34 | "@tiptap/pm": ">=2.0.0-beta.209 <3.0.0",
35 | "@tiptap/react": "2.0.0-beta.220",
36 | "@trpc/client": "^10.9.0",
37 | "@trpc/next": "^10.9.0",
38 | "@trpc/react-query": "^10.9.0",
39 | "@trpc/server": "^10.9.0",
40 | "@types/common-tags": "^1.8.1",
41 | "common-tags": "^1.8.2",
42 | "cross-env": "^7.0.3",
43 | "dotenv": "^16.0.3",
44 | "fastify": "^4.15.0",
45 | "fastify-plugin": "^4.5.0",
46 | "next": "^13.2.1",
47 | "npm-run-all": "^4.1.5",
48 | "pino": "^8.11.0",
49 | "react": "18.2.0",
50 | "react-dom": "18.2.0",
51 | "recoil": "^0.7.7",
52 | "state": "link:@tiptap/pm/state",
53 | "superjson": "1.9.1",
54 | "ts-node": "^10.9.1",
55 | "tsup": "^6.7.0",
56 | "zod": "^3.20.6",
57 | "zx": "^7.2.1"
58 | },
59 | "devDependencies": {
60 | "@types/eslint": "^8.21.1",
61 | "@types/node": "^18.14.0",
62 | "@types/prettier": "^2.7.2",
63 | "@types/react": "^18.0.28",
64 | "@types/react-dom": "^18.0.11",
65 | "@typescript-eslint/eslint-plugin": "^5.53.0",
66 | "@typescript-eslint/parser": "^5.53.0",
67 | "autoprefixer": "^10.4.14",
68 | "eslint": "^8.34.0",
69 | "eslint-config-next": "^13.2.1",
70 | "pino-pretty": "^10.0.0",
71 | "postcss": "^8.4.14",
72 | "prettier": "^2.8.1",
73 | "typescript": "^4.9.5"
74 | },
75 | "ct3aMetadata": {
76 | "initVersion": "7.8.0"
77 | }
78 | }
79 |
--------------------------------------------------------------------------------
/src/partials/PromptEditor.tsx:
--------------------------------------------------------------------------------
1 | import { RichTextEditor } from "@mantine/tiptap";
2 | import { useEditor, type Content } from "@tiptap/react";
3 | import { useEffect } from "react";
4 | import { useRecoilState } from "recoil";
5 | import {
6 | generatedText,
7 | generatingState,
8 | promptState,
9 | promptTemplateState,
10 | } from "~/recoil/states";
11 | import Document from "@tiptap/extension-document";
12 | import Text from "@tiptap/extension-text";
13 | import Paragraph from "@tiptap/extension-paragraph";
14 | import Placeholder from "@tiptap/extension-placeholder";
15 | import Highlight from "@tiptap/extension-highlight";
16 |
17 | export default function PromptEditor() {
18 | const [_prompt, setPrompt] = useRecoilState(promptState);
19 | const [generating] = useRecoilState(generatingState);
20 | const [generated] = useRecoilState(generatedText);
21 | const [templatePrompt, __] = useRecoilState(promptTemplateState);
22 |
23 | // build a tiptap editor
24 | const editor = useEditor({
25 | extensions: [
26 | Document,
27 | Paragraph,
28 | Text,
29 | Placeholder.configure({
30 | placeholder: "Q: What is a llama?",
31 | }),
32 | Highlight,
33 | ],
34 | onUpdate: ({ editor }) => {
35 | setPrompt(
36 | editor.getText({
37 | blockSeparator: "\n",
38 | })
39 | );
40 | },
41 | });
42 |
43 | // handle enabling/disabling the editor
44 | useEffect(() => {
45 | if (generating) {
46 | editor?.setEditable(false);
47 | editor?.chain().focus("end").setHighlight({ color: "#FFF68F" }).run();
48 | } else {
49 | editor?.commands.unsetHighlight();
50 | editor?.setEditable(true);
51 | }
52 | }, [generating]);
53 |
54 | // handle updating the generated text
55 | useEffect(() => {
56 | if (!editor) return;
57 |
58 | editor
59 | .chain()
60 | .focus("end")
61 | .setHighlight({ color: "#FFF68F" })
62 | .insertContent(
63 | generated === "\n"
64 | ? {
65 | type: "paragraph",
66 | content: [],
67 | }
68 | : generated
69 | )
70 | .run();
71 | }, [generated]);
72 |
73 | // handle prompt changes
74 | useEffect(() => {
75 | if (!editor) return;
76 |
77 | // delete all content and insert the new prompt
78 | editor
79 | .chain()
80 | .selectAll()
81 | .deleteSelection()
82 | .insertContent(templatePrompt)
83 | .run();
84 | }, [templatePrompt]);
85 |
86 | return (
87 |     <RichTextEditor
88 |       editor={editor}
89 |       sx={{
90 |         flexGrow: 1,
91 |         overflowY: "auto",
92 |       }}
93 |     >
94 |       {/* the editable prompt content */}
95 |       <RichTextEditor.Content
96 |         sx={{ height: "100%" }}
97 |       />
98 |     </RichTextEditor>
99 | );
100 | }
101 |
--------------------------------------------------------------------------------
/src/server/api/trpc.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * YOU PROBABLY DON'T NEED TO EDIT THIS FILE, UNLESS:
3 | * 1. You want to modify request context (see Part 1).
4 | * 2. You want to create a new middleware or type of procedure (see Part 3).
5 | *
6 | * TL;DR - This is where all the tRPC server stuff is created and plugged in. The pieces you will
7 | * need to use are documented accordingly near the end.
8 | */
9 |
10 | /**
11 | * 1. CONTEXT
12 | *
13 | * This section defines the "contexts" that are available in the backend API.
14 | *
15 | * These allow you to access things when processing a request, like the database, the session, etc.
16 | */
17 | import { type CreateFastifyContextOptions } from "@trpc/server/adapters/fastify";
18 |
19 | /** Replace this with an object if you want to pass things to `createContextInner`. */
20 | type CreateContextOptions = Record<string, never>;
21 |
22 | /**
23 | * This helper generates the "internals" for a tRPC context. If you need to use it, you can export
24 | * it from here.
25 | *
26 | * Examples of things you may need it for:
27 | * - testing, so we don't have to mock Next.js' req/res
28 | * - tRPC's `createSSGHelpers`, where we don't have req/res
29 | *
30 | * @see https://create.t3.gg/en/usage/trpc#-servertrpccontextts
31 | */
32 | const createInnerTRPCContext = (_opts: CreateContextOptions) => {
33 | return {};
34 | };
35 |
36 | /**
37 | * This is the actual context you will use in your router. It will be used to process every request
38 | * that goes through your tRPC endpoint.
39 | *
40 | * @see https://trpc.io/docs/context
41 | */
42 | export const createTRPCContext = (_opts: CreateFastifyContextOptions) => {
43 | return createInnerTRPCContext({});
44 | };
45 |
46 | /**
47 | * 2. INITIALIZATION
48 | *
49 | * This is where the tRPC API is initialized, connecting the context and transformer. We also parse
50 | * ZodErrors so that you get typesafety on the frontend if your procedure fails due to validation
51 | * errors on the backend.
52 | */
53 | import { initTRPC } from "@trpc/server";
54 | import superjson from "superjson";
55 | import { ZodError } from "zod";
56 |
57 | const t = initTRPC.context<typeof createTRPCContext>().create({
58 | transformer: superjson,
59 | errorFormatter({ shape, error }) {
60 | return {
61 | ...shape,
62 | data: {
63 | ...shape.data,
64 | zodError:
65 | error.cause instanceof ZodError ? error.cause.flatten() : null,
66 | },
67 | };
68 | },
69 | });
70 |
71 | /**
72 | * 3. ROUTER & PROCEDURE (THE IMPORTANT BIT)
73 | *
74 | * These are the pieces you use to build your tRPC API. You should import these a lot in the
75 | * "/src/server/api/routers" directory.
76 | */
77 |
78 | /**
79 | * This is how you create new routers and sub-routers in your tRPC API.
80 | *
81 | * @see https://trpc.io/docs/router
82 | */
83 | export const createTRPCRouter = t.router;
84 |
85 | /**
86 | * Public (unauthenticated) procedure
87 | *
88 | * This is the base piece you use to build new queries and mutations on your tRPC API. It does not
89 | * guarantee that a user querying is authorized, but you can still access user session data if they
90 | * are logged in.
91 | */
92 | export const publicProcedure = t.procedure;
93 |
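// Illustrative sketch (not part of the original file) of how `createTRPCRouter`
// and `publicProcedure` compose into a router under /src/server/api/routers:
//
//   import { z } from "zod";
//   import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
//
//   export const exampleRouter = createTRPCRouter({
//     hello: publicProcedure
//       .input(z.object({ name: z.string() }))
//       .query(({ input }) => `Hello, ${input.name}!`),
//   });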
--------------------------------------------------------------------------------
/src/utils/api.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * This is the client-side entrypoint for your tRPC API. It is used to create the `api` object which
3 | * contains the Next.js App-wrapper, as well as your type-safe React Query hooks.
4 | *
5 | * We also create a few inference helpers for input and output types.
6 | */
7 | import {
8 | createWSClient,
9 | httpBatchLink,
10 | loggerLink,
11 | wsLink,
12 | } from "@trpc/client";
13 | import { createTRPCNext } from "@trpc/next";
14 | import { type inferRouterInputs, type inferRouterOutputs } from "@trpc/server";
15 | import superjson from "superjson";
16 |
17 | import { type AppRouter } from "~/server/api/root";
18 |
19 | const getBaseUrl = () => {
20 | if (typeof window !== "undefined") return ""; // browser should use relative url
21 | if (process.env.VERCEL_URL) return `https://${process.env.VERCEL_URL}`; // SSR should use vercel url
22 | return `http://localhost:${process.env.PORT ?? 3000}`; // dev SSR should use localhost
23 | };
24 |
25 | /** A set of type-safe react-query hooks for your tRPC API. */
26 | export const api = createTRPCNext<AppRouter>({
27 | config() {
28 | const links = [
29 | loggerLink({
30 | enabled: (opts) =>
31 | process.env.NODE_ENV === "development" ||
32 | (opts.direction === "down" && opts.result instanceof Error),
33 | }),
34 | ];
35 |
36 | // If we're in the browser, we'll use a WebSocket connection to the server.
37 | if (typeof window !== "undefined") {
38 | // url should be the page url but with ws(s) instead of http(s)
39 | let url = window.location.host;
40 |
41 | // add protocol
42 | if (window.location.protocol === "https:") {
43 | url = "wss://" + url;
44 | } else {
45 | url = "ws://" + url;
46 | }
47 |
48 | links.push(
49 | wsLink({
50 | client: createWSClient({
51 | url: `${url}/api/trpc`,
52 | }),
53 | })
54 | );
55 | } else {
56 | // If we're on the server, we'll use HTTP batching.
57 | links.push(
58 | httpBatchLink({
59 | url: `${getBaseUrl()}/api/trpc`,
60 | })
61 | );
62 | }
63 |
64 | return {
65 | /**
66 | * Transformer used for data de-serialization from the server.
67 | *
68 | * @see https://trpc.io/docs/data-transformers
69 | */
70 | transformer: superjson,
71 |
72 | /**
73 | * Links used to determine request flow from client to server.
74 | *
75 | * @see https://trpc.io/docs/links
76 | */
77 | links,
78 | };
79 | },
80 |
81 | /**
82 | * Whether tRPC should await queries when server rendering pages.
83 | *
84 | * @see https://trpc.io/docs/nextjs#ssr-boolean-default-false
85 | */
86 | ssr: false,
87 | });
88 |
89 | /**
90 | * Inference helper for inputs.
91 | *
92 | * @example type HelloInput = RouterInputs['example']['hello']
93 | */
94 | export type RouterInputs = inferRouterInputs<AppRouter>;
95 |
96 | /**
97 | * Inference helper for outputs.
98 | *
99 | * @example type HelloOutput = RouterOutputs['example']['hello']
100 | */
101 | export type RouterOutputs = inferRouterOutputs<AppRouter>;
102 |
--------------------------------------------------------------------------------
/src/partials/Generate.tsx:
--------------------------------------------------------------------------------
1 | import { Flex, Button } from "@mantine/core";
2 | import { useRecoilState, useRecoilValue } from "recoil";
3 | import {
4 | promptState,
5 | generatingState,
6 | wsState,
7 | wsUUIDState,
8 | ClientWSState,
9 | generatedText,
10 | maxPredictionTokensState,
11 | topKState,
12 | topPState,
13 | repeatPenaltyState,
14 | tempState,
15 | repeatLastNState,
16 | } from "~/recoil/states";
17 | import { WSMessageType } from "~/server/api/types";
18 | import { api } from "~/utils/api";
19 |
20 | export default function Generate() {
21 | const [prompt] = useRecoilState(promptState);
22 | const [loading, setLoading] = useRecoilState(generatingState);
23 | const [state, setWSState] = useRecoilState(wsState);
24 | const [uuid, setUUID] = useRecoilState(wsUUIDState);
25 | const [_generated, setGenerated] = useRecoilState(generatedText);
26 | const nPredict = useRecoilValue(maxPredictionTokensState);
27 | const topK = useRecoilValue(topKState);
28 | const topP = useRecoilValue(topPState);
29 | const repeatPenalty = useRecoilValue(repeatPenaltyState);
30 | const temp = useRecoilValue(tempState);
31 | const repeatLastN = useRecoilValue(repeatLastNState);
32 |
33 | // subscribe to generation updates
34 | api.llama.subscription.useSubscription(undefined, {
35 | onData: (data) => {
36 | // wait for identity
37 | if (data.type === WSMessageType.IDENTITY) {
38 | setUUID(data.data);
39 | setWSState(ClientWSState.READY);
40 | return;
41 | }
42 |
43 | // loading false when generation done
44 | if (data.type === WSMessageType.REQUEST_COMPLETE) {
45 | setLoading(false);
46 | return;
47 | }
48 |
49 | // update generated text
50 | if (data.type === WSMessageType.COMPLETION) {
51 | setGenerated(data.data);
52 | return;
53 | }
54 | },
55 |
56 | onError: (error) => {
57 | setWSState(ClientWSState.CONNECTING);
58 | },
59 | });
60 |
61 | // mutations
62 | const generate = api.llama.startGeneration.useMutation();
63 | const cancel = api.llama.cancelGeneration.useMutation();
64 |
65 | return (
66 |     <Flex gap="sm">
67 |       {/* generate */}
68 |       <Button
69 |         disabled={state !== ClientWSState.READY}
70 |         loading={loading}
71 |         onClick={() => {
72 |           setLoading(true);
73 |           generate.mutate({
74 |             prompt,
75 |             uuid,
76 |             options: {
77 |               "--n_predict": nPredict,
78 |               "--top_k": topK,
79 |               "--top_p": topP,
80 |               "--repeat_penalty": repeatPenalty,
81 |               "--repeat_last_n": repeatLastN,
82 |               "--temp": temp,
83 |             },
84 |           });
85 |         }}
86 |       >
87 |         Generate
88 |       </Button>
89 |
90 |       {/* cancel */}
91 |       {loading && (
92 |         <Button
93 |           color="red"
94 |           variant="outline"
95 |           onClick={() => {
96 |             cancel.mutate({ uuid });
97 |           }}
98 |         >
99 |           Cancel
100 |         </Button>
101 |       )}
102 |     </Flex>
103 | );
104 | }
105 |
--------------------------------------------------------------------------------
/src/env.mjs:
--------------------------------------------------------------------------------
1 | import { z } from "zod";
2 |
3 | /**
4 | * Specify your server-side environment variables schema here. This way you can ensure the app isn't
5 | * built with invalid env vars.
6 | */
7 | const server = z.object({
8 | NODE_ENV: z
9 | .enum(["development", "test", "production"])
10 | .default("development"),
11 | USE_BUILT_IN_LLAMA_SERVER: z
12 | .literal("true")
13 | .or(z.literal("false"))
14 | .or(z.boolean()),
15 | LLAMA_SERVER_HOST: z.string().min(1).or(z.literal("auto")),
16 | LLAMA_SERVER_PORT: z.number().int().positive().or(z.literal("auto")),
17 | LLAMA_MODEL_PATH: z.string().min(1),
18 | LLAMA_TCP_BIN: z.string().min(1).or(z.literal("auto")),
19 | PORT: z.string(),
20 | HOST: z.string(),
21 | });
22 |
23 | /**
24 | * Specify your client-side environment variables schema here. This way you can ensure the app isn't
25 | * built with invalid env vars. To expose them to the client, prefix them with `NEXT_PUBLIC_`.
26 | */
27 | const client = z.object({
28 | // NEXT_PUBLIC_CLIENTVAR: z.string().min(1),
29 | });
30 |
31 | /**
32 | * You can't destruct `process.env` as a regular object in the Next.js edge runtimes (e.g.
33 | * middlewares) or client-side so we need to destruct manually.
34 | *
35 |  * @type {Record<keyof z.infer<typeof server> | keyof z.infer<typeof client>, string | number | undefined>}
36 | */
37 | const processEnv = {
38 | NODE_ENV: process.env.NODE_ENV,
39 | USE_BUILT_IN_LLAMA_SERVER: process.env.USE_BUILT_IN_LLAMA_SERVER,
40 | LLAMA_SERVER_HOST: process.env.LLAMA_SERVER_HOST,
41 | LLAMA_SERVER_PORT: isNaN(parseInt(process.env.LLAMA_SERVER_PORT ?? ""))
42 | ? process.env.LLAMA_SERVER_PORT
43 | : parseInt(process.env.LLAMA_SERVER_PORT ?? ""),
44 | LLAMA_MODEL_PATH: process.env.LLAMA_MODEL_PATH,
45 | LLAMA_TCP_BIN: process.env.LLAMA_TCP_BIN,
46 | PORT: process.env.PORT ?? "3000",
47 | HOST: process.env.HOST ?? "localhost",
48 | };
49 |
50 | // Don't touch the part below
51 | // --------------------------
52 |
53 | const merged = server.merge(client);
54 |
55 | /** @typedef {z.input<typeof merged>} MergedInput */
56 | /** @typedef {z.infer<typeof merged>} MergedOutput */
57 | /** @typedef {z.SafeParseReturnType<MergedInput, MergedOutput>} MergedSafeParseReturn */
58 |
59 | let env = /** @type {MergedOutput} */ (process.env);
60 |
61 | if (!!process.env.SKIP_ENV_VALIDATION == false) {
62 | const isServer = typeof window === "undefined";
63 |
64 | const parsed = /** @type {MergedSafeParseReturn} */ (
65 | isServer
66 | ? merged.safeParse(processEnv) // on server we can validate all env vars
67 | : client.safeParse(processEnv) // on client we can only validate the ones that are exposed
68 | );
69 |
70 | if (parsed.success === false) {
71 | console.error(
72 | "❌ Invalid environment variables:",
73 | parsed.error.flatten().fieldErrors
74 | );
75 | throw new Error("Invalid environment variables");
76 | }
77 |
78 | env = new Proxy(parsed.data, {
79 | get(target, prop) {
80 | if (typeof prop !== "string") return undefined;
81 | // Throw a descriptive error if a server-side env var is accessed on the client
82 | // Otherwise it would just be returning `undefined` and be annoying to debug
83 | if (!isServer && !prop.startsWith("NEXT_PUBLIC_"))
84 | throw new Error(
85 | process.env.NODE_ENV === "production"
86 | ? "❌ Attempted to access a server-side environment variable on the client"
87 | : `❌ Attempted to access server-side environment variable '${prop}' on the client`
88 | );
89 | return target[/** @type {keyof typeof target} */ (prop)];
90 | },
91 | });
92 | }
93 |
94 | export { env };
95 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🦙 LLaMA Playground 🛝
2 |
3 | A simple OpenAI-inspired interface that uses [llama.cpp#tcp_server](https://github.com/ggerganov/llama.cpp/tree/tcp_server) in the background.
4 |
5 | ![demo](./public/demo.gif)
6 |
7 | ## Difference vs. other interfaces
8 |
9 | Other interfaces use the llama.cpp CLI command to run the model. This is not ideal, since it spawns a new process for each request, which is slow and reloads the model every time. This interface instead uses the llama.cpp tcp_server to run the model in the background, which allows multiple requests to run in parallel and keeps the model cached in memory.
10 |
11 | ## Features
12 |
13 | - Simple to use UI
14 | - Able to handle multiple requests in parallel quickly
15 | - Controls to change the model parameters on the fly
16 | - Does not require restarting; changes are applied instantly
17 | - Save and load templates to keep your work
18 | - Templates are saved in the browser's local storage and are not sent to the server
19 |
20 | ## About
21 |
22 | Built on top of a modified [T3-stack](https://github.com/t3-oss/create-t3-app) application.
23 | Fastify is used instead of the regular Next.js server since websocket support is needed.
24 | [Mantine](https://mantine.dev/) is used for the UI.
25 | [tRPC](https://trpc.io/) is used for an end-to-end type-safe API.
26 |
27 | The fastify server starts a tcp_server from llama.cpp in the background.
28 | Upon each request, the server establishes a new TCP connection to the tcp_server and sends the request.
29 | Output is then forwarded to the client via websockets.
30 |
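A minimal sketch of that flow (illustrative only — the host, port, and wire format below are placeholder assumptions, not the project's actual adapter code):

```ts
import { Socket } from "net";

// One short-lived TCP connection per generation request; the model itself
// stays loaded inside the long-lived llama.cpp tcp_server process.
function complete(prompt: string, onToken: (token: string) => void) {
  const socket = new Socket();
  socket.connect(8080, "127.0.0.1"); // placeholder host/port

  // send the prompt, then stream generated tokens back as they arrive
  socket.write(prompt);
  socket.on("data", (chunk) => onToken(chunk.toString()));

  // caller forwards each token to the browser over the tRPC websocket
  // subscription, and can destroy the socket to cancel generation
  return () => socket.destroy();
}
```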
31 | ## Notice
32 |
33 | This is not meant to be used in production. There is no rate-limiting, no authentication, etc. It is just a simple interface to play with the models.
34 |
35 | ## Usage
36 |
37 | ### Getting the model
38 |
39 | This repository does not include the model weights, as they are the property of Meta. Do not share the weights in this repository.
40 |
41 | Currently, the application will not convert and quantize the model for you. You will need to do this yourself. This means you will need the llama.cpp build dependencies.
42 |
43 | - For Ubuntu: `build-essential make python3`
44 | - For Arch: `base-devel make python3`
45 |
46 | ```bash
47 | # build llama.cpp
48 | git clone https://github.com/ggerganov/llama.cpp
49 | cd llama.cpp
50 | make
51 |
52 | # obtain the original LLaMA model weights and place them in ./models
53 | ls ./models
54 | 65B 30B 13B 7B tokenizer_checklist.chk tokenizer.model
55 |
56 | # install Python dependencies
57 | python3 -m pip install torch numpy sentencepiece
58 |
59 | # convert the 7B model to ggml FP16 format
60 | python3 convert-pth-to-ggml.py models/7B/ 1
61 |
62 | # quantize the model to 4-bits
63 | python3 quantize.py 7B
64 | ```
65 |
66 | ^ (source [llama.cpp/README.md](https://github.com/ggerganov/llama.cpp/))
67 |
68 | Then you can start the server using one of the methods below:
69 |
70 | ### With Docker
71 |
72 | ```bash
73 | # Clone the repository
74 | git clone --depth 1 https://github.com/ItzDerock/llama-playground .
75 |
76 | # Edit the docker-compose.yml file to point to the correct model
77 | vim docker-compose.yml
78 |
79 | # Start the server
80 | docker-compose up -d
81 | ```
82 |
83 | ### Without Docker
84 |
85 | ```bash
86 | # Clone the repository
87 | git clone --depth 1 https://github.com/ItzDerock/llama-playground .
88 |
89 | # Install dependencies
90 | pnpm install # you will need pnpm
91 |
92 | # Create and edit the .env file to point to the correct model
93 | cp .env.example .env && vim .env
94 |
95 | # Build the server
96 | pnpm build
97 |
98 | # Start the server
99 | pnpm start
100 | ```
101 |
102 | ## Development
103 |
104 | Run `pnpm run dev` to start the development server.
105 |
106 | ## License
107 |
108 | [MIT](./LICENSE)
109 |
--------------------------------------------------------------------------------
/src/server/api/routers/llama.ts:
--------------------------------------------------------------------------------
1 | import { observable } from "@trpc/server/observable";
2 | import { createTRPCRouter, publicProcedure } from "~/server/api/trpc";
3 | import { getLLaMAClient } from "~/server/llama";
4 | import { randomUUID } from "crypto";
5 | import { generating, events, clients } from "~/server/api/shared";
6 | import { z } from "zod";
7 | import { WSMessageType, WSMessage } from "../types";
8 |
9 | export const llamaRouter = createTRPCRouter({
10 | status: publicProcedure.query(() => {
11 | const client = getLLaMAClient();
12 |
13 | return {
14 | status: client.state,
15 | };
16 | }),
17 |
18 | subscription: publicProcedure.subscription(() => {
19 |     return observable<WSMessage>((emit) => {
20 | // each client gets a unique UUID so we can keep track of them
21 | const uuid = randomUUID();
22 |
23 | // send the client their UUID
24 | emit.next({
25 | type: WSMessageType.IDENTITY,
26 | data: uuid,
27 | });
28 |
29 | // register this client
30 | clients.add(uuid);
31 |
32 |       // callback invoked for each generated token
33 |       function completion(id: string, text: string) {
34 |         if (id === uuid) {
35 |           emit.next({
36 |             type: WSMessageType.COMPLETION,
37 |             data: text,
38 |           });
39 |         }
40 |       }
41 |
42 | // done callback
43 | function done(id: string) {
44 | if (id === uuid) {
45 | emit.next({
46 | type: WSMessageType.REQUEST_COMPLETE,
47 | data: "",
48 | });
49 | }
50 | }
51 |
52 | // register callbacks
53 | events.incrementMaxListeners(2);
54 | events.on("generate", completion);
55 | events.on("generate:done", done);
56 |
57 | // handle disconnects
58 | return () => {
59 | // remove the client from the clients set
60 | clients.delete(uuid);
61 |
62 | // remove the callbacks
63 | events.off("generate", completion);
64 | events.off("generate:done", done);
65 |
66 | // decrement the max listeners
67 | events.decrementMaxListeners(2);
68 |
69 | // remove the client from the generating if they are generating
70 | if (generating.has(uuid)) {
71 | generating.delete(uuid);
72 | events.emit("generate:cancel", uuid);
73 | }
74 | };
75 | });
76 | }),
77 |
78 | startGeneration: publicProcedure
79 | .input(
80 | z.object({
81 | prompt: z.string(),
82 | options: z.object({
83 | "--seed": z.string().optional(),
84 | "--threads": z.number().optional(),
85 | "--n_predict": z.number().optional(),
86 | "--top_k": z.number().optional(),
87 | "--top_p": z.number().optional(),
88 | "--repeat_last_n": z.number().optional(),
89 | "--repeat_penalty": z.number().optional(),
90 | "--ctx_size": z.number().optional(),
91 | "--ignore-eos": z.boolean().optional(),
92 | "--memory_f16": z.boolean().optional(),
93 | "--temp": z.number().optional(),
94 | "--n_parts": z.number().optional(),
95 | "--batch_size": z.number().optional(),
96 | "--perplexity": z.number().optional(),
97 | }),
98 | uuid: z.string(),
99 | })
100 | )
101 | .mutation(async ({ input }) => {
102 | // validate uuid
103 | if (!clients.has(input.uuid)) {
104 | throw new Error("Invalid UUID");
105 | }
106 |
107 | // check if the client is already generating
108 | if (generating.has(input.uuid)) {
109 | throw new Error("Client is already generating");
110 | }
111 |
112 | // set the client's generating state to true
113 | generating.add(input.uuid);
114 |
115 | // start generation
116 | const client = getLLaMAClient();
117 | const { stream, cancel } = await client.complete(
118 | input.prompt,
119 | input.options
120 | );
121 |
122 | // forward the generated tokens to the client
123 | stream.on("data", (data) => {
124 | events.emit("generate", input.uuid, data.toString());
125 | });
126 |
127 | // handle cancel
128 | function cancelRequest(id: string) {
129 | if (id === input.uuid) {
130 | cancel();
131 | }
132 | }
133 |
134 | // increment the max listeners
135 | events.incrementMaxListeners(1);
136 | events.on("generate:cancel", cancelRequest);
137 |
138 | // when the client is done generating, set the client's generating state to false
139 | stream.on("end", () => {
140 | generating.delete(input.uuid);
141 |
142 | // send the client a message saying that generation is complete
143 | events.emit("generate:done", input.uuid);
144 |
145 | // clean up
146 | events.off("generate:cancel", cancelRequest);
147 | events.decrementMaxListeners(1);
148 | });
149 | }),
150 |
151 | cancelGeneration: publicProcedure
152 | .input(
153 | z.object({
154 | uuid: z.string(),
155 | })
156 | )
157 | .mutation(async ({ input }) => {
158 | // validate uuid
159 | if (!clients.has(input.uuid)) {
160 | throw new Error("Invalid UUID");
161 | }
162 |
163 | // check if the client is already generating
164 | if (!generating.has(input.uuid)) {
165 | throw new Error("Client is not generating");
166 | }
167 |
168 | // cancel generation
169 | events.emit("generate:cancel", input.uuid);
170 | }),
171 | });
172 |
--------------------------------------------------------------------------------
/src/partials/TemplateSelect.tsx:
--------------------------------------------------------------------------------
1 | import { Button, Flex, Modal, Select, TextInput } from "@mantine/core";
2 | import { useDisclosure } from "@mantine/hooks";
3 | import { useForm } from "@mantine/form";
4 | import { useEffect } from "react";
5 | import { useRecoilState, useRecoilValue } from "recoil";
6 | import {
7 | generatingState,
8 | maxPredictionTokensState,
9 | promptState,
10 | promptTemplateState,
11 | repeatLastNState,
12 | repeatPenaltyState,
13 | tempState,
14 | topKState,
15 | topPState,
16 | } from "~/recoil/states";
17 | import {
18 | allTemplatesState,
19 | defaultTemplates,
20 | TemplateData,
21 | } from "~/recoil/templates";
22 |
23 | function setIfExists<T>(setFunction: (arg0: T) => void, value?: T) {
24 | if (value !== undefined) setFunction(value);
25 | }
26 |
27 | export default function TemplateSelect({ fullWidth }: { fullWidth?: boolean }) {
28 | const [templates, setTemplates] = useRecoilState(allTemplatesState);
29 |
30 | // the model options
31 | const [maxTokens, setMaxTokens] = useRecoilState(maxPredictionTokensState);
32 | const [topK, setTopK] = useRecoilState(topKState);
33 | const [topP, setTopP] = useRecoilState(topPState);
34 | const [repeatPenalty, setRepeatPenalty] = useRecoilState(repeatPenaltyState);
35 | const [repeatLastN, setRepeatLastN] = useRecoilState(repeatLastNState);
36 | const [temp, setTemp] = useRecoilState(tempState);
37 | const [prompt, _setPrompt] = useRecoilState(promptState);
38 | const [_, setTemplatePrompt] = useRecoilState(promptTemplateState);
39 |
40 | // generation state
41 | const generating = useRecoilValue(generatingState);
42 |
43 | // modal state
44 | const [opened, { open, close }] = useDisclosure(false);
45 | const form = useForm({
46 | initialValues: {
47 | name: "",
48 | },
49 |
50 | validate: {
51 | name: (value) => {
52 | if (!value) return "Name is required";
53 | if (defaultTemplates.some((template) => template.name === value))
54 | return "Cannot override default templates.";
55 | },
56 | },
57 | });
58 |
59 | // read templates from local storage
60 |   useEffect(() => {
61 |     const storedTemplates = localStorage.getItem("templates");
62 |     if (storedTemplates) {
63 |       try {
64 |         // merge stored user templates on top of the defaults
65 |         const parsed = JSON.parse(storedTemplates) as TemplateData[];
66 |         setTemplates((current) => [...current, ...parsed]);
67 |       } catch (error) {
68 |         console.error(`Error parsing templates:`, error);
69 |       }
70 |     }
71 |   }, []);
73 |
74 | // save templates to local storage
75 | useEffect(() => {
76 | // only save if they are not the default templates
77 | const userCreatedTemplates = templates.filter(
78 | (template) => !defaultTemplates.some((tmp) => tmp.name === template.name)
79 | );
80 |
81 | localStorage.setItem("templates", JSON.stringify(userCreatedTemplates));
82 | }, [templates]);
83 |
84 | // save
85 | function saveTemplate(name: string) {
86 | const templateData = {
87 | name,
88 | prompt,
89 | options: {
90 | maximum_tokens: maxTokens,
91 | repeat_last_n: repeatLastN,
92 | temperature: temp,
93 | repeat_penalty: repeatPenalty,
94 | top_p: topP,
95 | top_k: topK,
96 | },
97 | } satisfies TemplateData;
98 |
99 | setTemplates((templates) => [...templates, templateData]);
100 | close();
101 | }
102 |
103 | // load
104 | function loadTemplate(name: string) {
105 | const template = templates.find((template) => template.name === name);
106 | if (!template) return;
107 |
108 | setTemplatePrompt(template.prompt);
109 | setIfExists(setMaxTokens, template.options?.maximum_tokens);
110 | setIfExists(setTopK, template.options?.top_k);
111 | setIfExists(setTopP, template.options?.top_p);
112 | setIfExists(setRepeatPenalty, template.options?.repeat_penalty);
113 | setIfExists(setRepeatLastN, template.options?.repeat_last_n);
114 | setIfExists(setTemp, template.options?.temperature);
115 | }
116 |
117 |   return (
118 |     <>
119 |       {/* NOTE: the JSX below was mangled during extraction and has been
120 |           reconstructed from the surviving fragments; exact props and labels
121 |           are best-effort assumptions. */}
122 |       {/* modal for naming and saving the current settings as a template */}
123 |       <Modal opened={opened} onClose={close} title="Save template">
124 |         <form onSubmit={form.onSubmit((values) => saveTemplate(values.name))}>
125 |           <TextInput
126 |             label="Template name"
127 |             placeholder="My template"
128 |             {...form.getInputProps("name")}
129 |           />
130 |           <Button type="submit" mt="md">
131 |             Save
132 |           </Button>
133 |         </form>
134 |       </Modal>
135 | 
136 |       {/* template selector plus a button to save the current settings */}
137 |       <Flex gap="sm" style={{ width: fullWidth ? "100%" : undefined }}>
138 |         <Select
139 |           data={templates.map((template) => template.name)}
140 |           placeholder="Load a template"
141 |           onChange={(value) => value && loadTemplate(value)}
142 |           disabled={generating}
143 |           style={{ flexGrow: 1 }}
144 |         />
145 |         <Button onClick={open} disabled={generating}>
146 |           Save
147 |         </Button>
148 |       </Flex>
149 |     </>
150 |   );
151 | }
152 | 
--------------------------------------------------------------------------------
/src/partials/ModelControls.tsx:
--------------------------------------------------------------------------------
1 | import {
2 | Flex,
3 | Slider,
4 | TextInput,
5 | Tooltip,
6 | Text,
7 | Container,
8 | NumberInput,
9 | Group,
10 | Select,
11 | } from "@mantine/core";
12 | import { useRecoilState } from "recoil";
13 | import {
14 | maxPredictionTokensState,
15 | repeatLastNState,
16 | repeatPenaltyState,
17 | tempState,
18 | topKState,
19 | topPState,
20 | } from "~/recoil/states";
21 | import { api } from "~/utils/api";
22 |
23 | export default function ModelControls() {
24 | const [temp, setTemp] = useRecoilState(tempState);
25 | const [maxTokens, setMaxTokens] = useRecoilState(maxPredictionTokensState);
26 | const [topK, setTopK] = useRecoilState(topKState);
27 | const [topP, setTopP] = useRecoilState(topPState);
28 | const [repeatPenalty, setRepeatPenalty] = useRecoilState(repeatPenaltyState);
29 | const [repeatLastN, setRepeatLastN] = useRecoilState(repeatLastNState);
30 |
31 | // status
32 | const status = api.llama.status.useQuery(undefined, {
33 | refetchInterval: (data) => {
34 | if (data && data.status === "ready") {
35 |         return false; // stop polling once the model is ready
36 | } else {
37 | return 1000;
38 | }
39 | },
40 | });
41 |
42 |   return (
43 |     <Container>
44 |       {/* NOTE: the JSX in this component was mangled during extraction and has
45 |           been reconstructed from the surviving fragments -- the exact markup,
46 |           tooltip labels, and the disabled model Select are assumptions (the
47 |           imported TextInput may have been used in the lost markup). */}
48 |       {/* future idea: allow loading multiple models, 7B, 15B, 30B, etc */}
49 |       <Select data={["7B"]} value="7B" label="Model" disabled />
50 | 
51 |       <Flex direction="column" gap="sm" mt="sm">
52 |         {/* Maximum Length */}
53 |         <Tooltip label="The maximum number of tokens to generate" withArrow>
54 |           <NumberInput
55 |             value={maxTokens}
56 |             onChange={(value) => setMaxTokens(value === "" ? 0 : value)}
57 |             min={0}
58 |             precision={0}
59 |             hideControls
60 |             label="Maximum Tokens"
61 |             required
62 |           />
63 |         </Tooltip>
64 | 
65 |         {/* Temperature */}
66 |         <Tooltip label="Higher values produce more random output" withArrow>
67 |           <div>
68 |             <Group position="apart">
69 |               <Text size="sm" weight={500}>
70 |                 Temperature
71 |               </Text>
72 |               <NumberInput
73 |                 value={temp}
74 |                 onChange={(value) => setTemp(value === "" ? 0 : value)}
75 |                 min={0}
76 |                 max={1}
77 |                 precision={2}
78 |                 hideControls
79 |                 size="xs"
80 |                 required
81 |               />
82 |             </Group>
83 |             <Slider
84 |               value={temp}
85 |               onChange={(value) => setTemp(value)}
86 |               min={0}
87 |               max={1}
88 |               step={0.01}
89 |               precision={2}
90 |               mt="sm"
91 |               label={null}
92 |             />
93 |           </div>
94 |         </Tooltip>
95 | 
96 |         {/* Top P */}
97 |         <Tooltip label="Only sample from the most probable tokens summing to P" withArrow>
98 |           <div>
99 |             <Group position="apart">
100 |               <Text size="sm" weight={500}>
101 |                 Top P
102 |               </Text>
103 |               <NumberInput
104 |                 value={topP}
105 |                 onChange={(value) => setTopP(value === "" ? 0 : value)}
106 |                 min={0}
107 |                 max={1}
108 |                 precision={2}
109 |                 hideControls
110 |                 required
111 |                 size="xs"
112 |               />
113 |             </Group>
114 |             <Slider
115 |               value={topP}
116 |               onChange={(value) => setTopP(value)}
117 |               min={0}
118 |               max={1}
119 |               step={0.01}
120 |               precision={2}
121 |               mt="sm"
122 |               label={null}
123 |             />
124 |           </div>
125 |         </Tooltip>
126 | 
127 |         {/* Top K */}
128 |         <NumberInput
129 |           value={topK}
130 |           onChange={(value) => setTopK(value === "" ? 0 : value)}
131 |           min={0}
132 |           precision={0}
133 |           hideControls
134 |           required
135 |           label="Top K"
136 |         />
137 | 
138 |         {/* Repeat Penalty */}
139 |         <NumberInput
140 |           value={repeatPenalty}
141 |           onChange={(value) => setRepeatPenalty(value === "" ? 0 : value)}
142 |           min={0}
143 |           precision={2}
144 |           hideControls
145 |           required
146 |           label="Repeat Penalty"
147 |         />
148 | 
149 |         {/* Repeat last n */}
150 |         <NumberInput
151 |           value={repeatLastN}
152 |           onChange={(value) => setRepeatLastN(value === "" ? 0 : value)}
153 |           min={0}
154 |           precision={1}
155 |           hideControls
156 |           required
157 |           label="Repeat Last N"
158 |         />
159 | 
160 |         {/* Model status */}
161 |         <Text size="sm">Model status: {status.data?.status}</Text>
162 |       </Flex>
163 |     </Container>
164 |   );
165 | }
166 | 
--------------------------------------------------------------------------------
/src/server/llama/adapters/llamacpp.ts:
--------------------------------------------------------------------------------
1 | import { ChildProcess, spawn } from "child_process";
2 | import { Socket } from "net";
3 | import { Readable } from "stream";
4 | import { findRandomOpenPort } from "~/utils/utils";
5 |
6 | type LLaMATCPClientOptions = (
7 | | {
8 | modelPath: string;
9 | binPath: string;
10 | port: number | "auto";
11 | }
12 | | {
13 | port: number;
14 | host: string;
15 | }
16 | ) & {
17 | debug?: boolean | ((...args: any[]) => void);
18 | };
19 |
20 | export default class LLaMATCPClient {
21 | // configuration options
22 | private options: LLaMATCPClientOptions;
23 |
24 | // the llama-tcp server process
25 | private _process?: ChildProcess;
26 |
27 | // state
28 | public state: "loading" | "ready" | "error" = "loading";
29 |   private _loadingPromise?: Promise<void>;
30 |
31 | // logging
32 | private _log: (...args: any[]) => void;
33 |
34 | /**
35 | * Create a new LLaMATCPClient
36 | * @param options - options for the client
37 | */
38 | constructor(options: LLaMATCPClientOptions) {
39 | // set logging
40 | if (options.debug === true) {
41 | this._log = (...args: any[]) => console.log("[llama-tcp] ", ...args);
42 | } else if (typeof options.debug === "function") {
43 | this._log = options.debug;
44 | } else {
45 | this._log = () => {};
46 | }
47 |
48 | // set options
49 | this.options = options;
50 |
51 | // if binPath is set and equals "auto", set it to the default path
52 | if ("binPath" in this.options && this.options.binPath === "auto") {
53 | this.options.binPath = "./bin/main";
54 | }
55 | }
56 |
57 | /**
58 | * Start the LLaMATCPClient
59 | * Creates a server process if necessary
60 | */
61 | async start() {
62 | // if we're already loading, return the loading promise
63 | if (this._loadingPromise) return this._loadingPromise;
64 |
65 | // create a new promise for loading
66 | let resolve: () => void = () => {};
67 | let reject: (error: Error) => void = () => {};
68 |     this._loadingPromise = new Promise<void>((res, rej) => {
69 | resolve = res;
70 | reject = rej;
71 | });
72 |
73 | // only start a server instance if binPath is set
74 | if ("binPath" in this.options) {
75 | this._log("starting server");
76 |
77 | // find a random port if port is set to "auto"
78 | if (this.options.port === "auto") {
79 | this._log("start(): port is set to auto, finding random open port");
80 | this.options.port = await findRandomOpenPort();
81 | this._log("start(): found random open port: ", this.options.port);
82 | }
83 |
84 | // start the server process
85 | this._log("start(): starting server process");
86 | this._process = spawn(
87 | this.options.binPath!,
88 | ["-l", this.options.port.toString(), "-m", this.options.modelPath],
89 | {
90 | stdio: "inherit",
91 | }
92 | );
93 |
94 | // handle errors
95 | this._process.on("error", (error) => {
96 | console.error(error);
97 | this.state = "error";
98 | reject(error);
99 | });
100 |
101 | // wait for the server to start
102 | this._log("start(): waiting for server to start");
103 | let iterations = 0;
104 |
105 | while (iterations < 50) {
106 | try {
107 | await this._createConnection();
108 | this.state = "ready";
109 | this._log("start(): server started successfully");
110 | resolve();
111 | return;
112 | } catch (error) {
113 | this._log(
114 | "start(): error: ",
115 | (error as any).message,
116 | ", retrying in 1s (",
117 | iterations,
118 | "/50)"
119 | );
120 | iterations++;
121 | await new Promise((res) => setTimeout(res, 1000));
122 | }
123 | }
124 |
125 | // if we get here, the server failed to start
126 | this.state = "error";
127 | reject(new Error("Failed to start server!"));
128 | } else {
129 | // if we don't have a binPath, we're just connecting to a remote server
130 | this.state = "ready";
131 | resolve();
132 | }
133 | }
134 |
135 | /**
136 | * Creates a new connection and asks the server to complete the given text
137 | * @param text - the text to complete
138 | * @param options - options for the completion
139 | * @returns a readable stream of the completion
140 | */
141 | async complete(
142 | text: string,
143 | options: {
144 | [key: string]: string | number | boolean;
145 | }
146 | ) {
147 | // make sure we're ready
148 | if (this.state !== "ready") {
149 | await this._loadingPromise;
150 | }
151 |
152 | // create a new connection
153 | const client = await this._createConnection();
154 |
155 | // build the tcp request
156 | let request = "";
157 | let args = 0;
158 |
159 | // the text goes in as a -p argument
160 | options["-p"] = text;
161 |
162 | // add the options
163 | for (const key in options) {
164 | // if the value is boolean and false, skip it
165 | if (typeof options[key] === "boolean" && !options[key]) continue;
166 |
167 | request += key + "\x00";
168 | args++;
169 |
170 | // if value is number, convert it to string
171 | if (typeof options[key] === "number") {
172 | options[key] = options[key]!.toString();
173 | }
174 |
175 | // if value is a non-empty string, add it
176 | if (typeof options[key] === "string" && options[key] !== "") {
177 | request += options[key] + "\x00";
178 | args++;
179 | }
180 | }
181 |
182 |     // prepend the number of args, e.g. {"-p": "Hi"} becomes "2\n-p\x00Hi\x00"
183 | request = args.toString() + "\n" + request;
184 |
185 |     // log the built tcp packet (escape NUL bytes so they are visible)
186 |     this._log(
187 |       `complete(): sending tcp packet: `,
188 |       request.replace(/\x00/g, "\\x00")
189 |     );
190 |
191 | // create a readable stream from the connection
192 | const stream = new Readable({
193 | read() {},
194 | });
195 |
196 | // send the request
197 | client.write(request);
198 |
199 | // variables to keep track of where we are in the response
200 | let gotSamplingParameters = false;
201 | let promptIndex = 0;
202 |
203 | // handle data
204 | client.on("data", (data) => {
205 | // convert the data to a string
206 | let parsedData = data.toString();
207 | this._log(`complete(): received tcp packet: `, parsedData);
208 |
209 |       // wait for the packet with sampling parameters:
210 | // this marks the first line of response
211 | if (!gotSamplingParameters) {
212 | if (parsedData.includes("sampling parameters:")) {
213 | gotSamplingParameters = true;
214 |
215 | // get the last line of data -- this will be part of the prompt.
216 | let after = parsedData.split("\n").pop()!;
217 |
218 |           // remove the leading space
219 | after = after.substring(1);
220 |
221 | // if the length is greater than the prompt length, then we need to substring
222 | if (after.length > text.length) {
223 | parsedData = after.substring(text.length - 1);
224 |             this._log("complete(): data past prompt echo:", parsedData, after);
225 | } else {
226 | // otherwise, ignore this chunk and update promptIndex
227 | promptIndex = after.length;
228 | return;
229 | }
230 | }
231 |
232 | if (parsedData === "") return;
233 | }
234 |
235 | // wait until the prompt is finished being echoed back
236 | if (promptIndex < text.length) {
237 | let requiredPromptChars = text.length - promptIndex;
238 |
239 | // if the data includes same amt or less chars, discard
240 | if (parsedData.length <= requiredPromptChars) {
241 | promptIndex += parsedData.length;
242 | return;
243 | }
244 |
245 | // otherwise, we substring what is required
246 | parsedData = parsedData.substring(requiredPromptChars);
247 | promptIndex += requiredPromptChars;
248 | }
249 |
250 | // push the data to the stream
251 | stream.push(parsedData);
252 | });
253 |
254 | // when the connection closes, end the stream
255 | client.on("close", () => {
256 | stream.push(null);
257 | });
258 |
259 | // handle errors
260 | client.on("error", (error) => {
261 | stream.emit("error", error);
262 | });
263 |
264 | // return the stream
265 | return {
266 | stream,
267 | cancel: () => {
268 | client.destroy();
269 | },
270 | };
271 | }
272 |
273 | private _createConnection() {
274 |     return new Promise<Socket>((resolve, reject) => {
275 | // create a tcp client
276 | const client = new Socket();
277 |
278 | // make sure a port has been resolved
279 | if (this.options.port === "auto") {
280 | reject(
281 | new Error("_createConnection() called before port was resolved!")
282 | );
283 | return;
284 | }
285 |
286 | // default host is localhost
287 | const host = "host" in this.options ? this.options.host : "localhost";
288 |
289 | // connect it to the server
290 | client.connect(this.options.port, host, () => {
291 | this._log("_createConnection(): connected to server");
292 | resolve(client);
293 | });
294 |
295 | // handle errors
296 | client.on("error", (error) => {
297 | this._log("_createConnection(): error: ", error.message);
298 | reject(error);
299 | });
300 | });
301 | }
302 | }
303 |
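304 | /*
305 |  * Illustrative usage of the adapter above (not part of the original file;
306 |  * the model path is an assumption -- point it at your own quantized weights):
307 |  *
308 |  *   const client = new LLaMATCPClient({
309 |  *     binPath: "auto",
310 |  *     modelPath: "./models/7B/ggml-model-q4_0.bin",
311 |  *     port: "auto",
312 |  *     debug: true,
313 |  *   });
314 |  *   await client.start();
315 |  *
316 |  *   const { stream, cancel } = await client.complete("Hello", { "--n_predict": 64 });
317 |  *   stream.on("data", (chunk) => process.stdout.write(chunk.toString()));
318 |  */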
--------------------------------------------------------------------------------