├── src ├── styles │ └── globals.css ├── pages │ ├── _document.tsx │ ├── index.tsx │ └── _app.tsx ├── server │ ├── api │ │ ├── types.ts │ │ ├── root.ts │ │ ├── shared.ts │ │ ├── trpc.ts │ │ └── routers │ │ │ └── llama.ts │ ├── llama │ │ ├── index.ts │ │ └── adapters │ │ │ └── llamacpp.ts │ ├── plugins │ │ └── next.ts │ └── index.ts ├── utils │ ├── utils.ts │ └── api.ts ├── recoil │ ├── states.ts │ └── templates.ts ├── partials │ ├── Header.tsx │ ├── PromptEditor.tsx │ ├── Generate.tsx │ ├── TemplateSelect.tsx │ └── ModelControls.tsx └── env.mjs ├── .vscode └── settings.json ├── public ├── demo.gif ├── logo.png └── favicon.ico ├── .dockerignore ├── postcss.config.cjs ├── prettier.config.cjs ├── tsup.config.ts ├── docker-compose.dev.yml ├── docker-compose.yml ├── next.config.mjs ├── .gitignore ├── tsconfig.json ├── .env.example ├── LICENSE ├── Dockerfile ├── .github └── workflows │ └── docker.yml ├── scripts └── startup.mjs ├── package.json └── README.md /src/styles/globals.css: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "editor.tabSize": 2 3 | } -------------------------------------------------------------------------------- /public/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ItzDerock/llama-playground/HEAD/public/demo.gif -------------------------------------------------------------------------------- /public/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ItzDerock/llama-playground/HEAD/public/logo.png -------------------------------------------------------------------------------- /public/favicon.ico: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ItzDerock/llama-playground/HEAD/public/favicon.ico -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | /dist 2 | /bin 3 | /.next 4 | /.vscode 5 | /node_modules 6 | /docker-compose.yml 7 | /docker-compose.dev.yml 8 | -------------------------------------------------------------------------------- /postcss.config.cjs: -------------------------------------------------------------------------------- 1 | const config = { 2 | plugins: { 3 | autoprefixer: {}, 4 | }, 5 | }; 6 | 7 | module.exports = config; 8 | -------------------------------------------------------------------------------- /prettier.config.cjs: -------------------------------------------------------------------------------- 1 | /** @type {import("prettier").Config} */ 2 | const config = { 3 | plugins: [], 4 | }; 5 | 6 | module.exports = config; 7 | -------------------------------------------------------------------------------- /src/pages/_document.tsx: -------------------------------------------------------------------------------- 1 | import Document from "next/document"; 2 | import { createGetInitialProps } from "@mantine/next"; 3 | 4 | const getInitialProps = createGetInitialProps(); 5 | 6 | export default class _Document extends Document { 7 | static getInitialProps = getInitialProps; 8 | } 9 | -------------------------------------------------------------------------------- /tsup.config.ts: -------------------------------------------------------------------------------- 1 | import { defineConfig, Options } from "tsup"; 2 | 3 | const opts: Options = { 4 | platform: "node", 5 | format: ["cjs"], 6 | treeshake: true, 7 | clean: true, 8 | sourcemap: true, 9 | }; 10 | 11 | export default defineConfig([ 12 | { 13 | entryPoints: ["src/server/index.ts"], 14 | ...opts, 
15 | }, 16 | ]); 17 | -------------------------------------------------------------------------------- /src/server/api/types.ts: -------------------------------------------------------------------------------- 1 | export enum WSMessageType { 2 | // Identity -> server gives client a unique UUID 3 | IDENTITY, 4 | // Completion -> generated tokens are sent to the client 5 | COMPLETION, 6 | // Request Complete -> done generating tokens 7 | REQUEST_COMPLETE, 8 | } 9 | 10 | export type WSMessage = { 11 | type: WSMessageType; 12 | data: string; 13 | }; 14 | -------------------------------------------------------------------------------- /docker-compose.dev.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | 3 | services: 4 | app: 5 | build: . 6 | ports: 7 | - 3000:3000 8 | environment: 9 | - USE_BUILT_IN_LLAMA_SERVER=true 10 | - LLAMA_TCP_BIN=auto 11 | - LLAMA_SERVER_HOST=auto 12 | - LLAMA_SERVER_PORT=auto 13 | - LLAMA_MODEL_PATH=/app/models/ggml-model-q4_0.bin 14 | - NODE_ENV=production 15 | volumes: 16 | - ./path/to/model/:/app/models -------------------------------------------------------------------------------- /src/server/api/root.ts: -------------------------------------------------------------------------------- 1 | import { createTRPCRouter } from "~/server/api/trpc"; 2 | import { llamaRouter } from "./routers/llama"; 3 | 4 | /** 5 | * This is the primary router for your server. 6 | * 7 | * All routers added in /api/routers should be manually added here. 
8 | */ 9 | export const appRouter = createTRPCRouter({ 10 | llama: llamaRouter 11 | }); 12 | 13 | // export type definition of API 14 | export type AppRouter = typeof appRouter; 15 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | 3 | services: 4 | app: 5 | image: ghcr.io/itzderock/llama-playground:latest 6 | ports: 7 | - 3000:3000 8 | - 3001:3001 9 | environment: 10 | - USE_BUILT_IN_LLAMA_SERVER=true 11 | - LLAMA_TCP_BIN=auto 12 | - LLAMA_SERVER_HOST=auto 13 | - LLAMA_SERVER_PORT=auto 14 | - LLAMA_MODEL_PATH=/app/models/ggml-model-q4_0.bin 15 | volumes: 16 | - ./path/to/7B/:/app/models -------------------------------------------------------------------------------- /next.config.mjs: -------------------------------------------------------------------------------- 1 | /** 2 | * Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation. 3 | * This is especially useful for Docker builds. 4 | */ 5 | !process.env.SKIP_ENV_VALIDATION && (await import("./src/env.mjs")); 6 | 7 | /** @type {import("next").NextConfig} */ 8 | const config = { 9 | reactStrictMode: true, 10 | 11 | /** 12 | * If you have the "experimental: { appDir: true }" setting enabled, then you 13 | * must comment the below `i18n` config out. 14 | * 15 | * @see https://github.com/vercel/next.js/issues/41980 16 | */ 17 | i18n: { 18 | locales: ["en"], 19 | defaultLocale: "en", 20 | }, 21 | }; 22 | export default config; 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 
2 | 3 | # dependencies 4 | /node_modules 5 | /.pnp 6 | .pnp.js 7 | 8 | # testing 9 | /coverage 10 | 11 | # database 12 | /prisma/db.sqlite 13 | /prisma/db.sqlite-journal 14 | 15 | # next.js 16 | /.next/ 17 | /out/ 18 | next-env.d.ts 19 | 20 | # production 21 | /build 22 | /dist 23 | 24 | # misc 25 | .DS_Store 26 | *.pem 27 | 28 | # debug 29 | npm-debug.log* 30 | yarn-debug.log* 31 | yarn-error.log* 32 | .pnpm-debug.log* 33 | 34 | # local env files 35 | # do not commit any .env files to git, except for the .env.example file. https://create.t3.gg/en/usage/env-variables#using-environment-variables 36 | .env 37 | .env*.local 38 | 39 | # vercel 40 | .vercel 41 | 42 | # typescript 43 | *.tsbuildinfo 44 | 45 | # built binaries 46 | /bin/ -------------------------------------------------------------------------------- /tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "es2017", 4 | "lib": ["dom", "dom.iterable", "esnext"], 5 | "allowJs": true, 6 | "checkJs": true, 7 | "skipLibCheck": true, 8 | "strict": true, 9 | "forceConsistentCasingInFileNames": true, 10 | "noEmit": true, 11 | "esModuleInterop": true, 12 | "module": "esnext", 13 | "moduleResolution": "node", 14 | "resolveJsonModule": true, 15 | "isolatedModules": true, 16 | "jsx": "preserve", 17 | "incremental": true, 18 | "noUncheckedIndexedAccess": true, 19 | "baseUrl": ".", 20 | "paths": { 21 | "~/*": ["./src/*"] 22 | } 23 | }, 24 | "include": [ 25 | ".eslintrc.cjs", 26 | "next-env.d.ts", 27 | "**/*.ts", 28 | "**/*.tsx", 29 | "**/*.cjs", 30 | "**/*.mjs" 31 | ], 32 | "exclude": ["node_modules"] 33 | } 34 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # The web-server's HOST and PORT 2 | HOST=127.0.0.1 3 | PORT=3000 4 | 5 | # Disable this if you want to run llama.cpp#tcp_server on your own 6 | # 
Uses the binary path set below 7 | USE_BUILT_IN_LLAMA_SERVER=true 8 | 9 | # Binary location for llama.cpp#tcp_server 10 | # Auto will automatically pull and build the latest version 11 | # Requires build-essential (or equivalent) and make 12 | # This does nothing if USE_BUILT_IN_LLAMA_SERVER is disabled 13 | LLAMA_TCP_BIN=auto 14 | 15 | # If USE_BUILT_IN_LLAMA_SERVER is disabled, enter the llama.cpp#tcp_server tcp details here 16 | # Otherwise, this app will automatically start a llama.cpp#tcp_server server 17 | # If port is set to auto, it will listen on a random open port 18 | LLAMA_SERVER_HOST=auto 19 | LLAMA_SERVER_PORT=auto 20 | 21 | # The path to a model's .bin file 22 | LLAMA_MODEL_PATH=/path/to/llama/ggml-model-q4_0.bin -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Derock 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/pages/index.tsx: -------------------------------------------------------------------------------- 1 | import { Flex, LoadingOverlay } from "@mantine/core"; 2 | import { type NextPage } from "next"; 3 | import Head from "next/head"; 4 | import { useRecoilValue } from "recoil"; 5 | import Generate from "~/partials/Generate"; 6 | import Header from "~/partials/Header"; 7 | import ModelControls from "~/partials/ModelControls"; 8 | import PromptEditor from "~/partials/PromptEditor"; 9 | import { ClientWSState, wsState } from "~/recoil/states"; 10 | 11 | const Home: NextPage = () => { 12 | const state = useRecoilValue(wsState); 13 | 14 | return ( 15 | <> 16 |
17 | 21 |
22 | 30 | 31 | 32 | 33 | 34 | 35 | 36 |
37 | 38 | ); 39 | }; 40 | 41 | export default Home; 42 | -------------------------------------------------------------------------------- /src/server/llama/index.ts: -------------------------------------------------------------------------------- 1 | import { env } from "~/env.mjs"; 2 | import LLaMATCPClient from "./adapters/llamacpp"; 3 | 4 | // https://stackoverflow.com/questions/64093560/can-you-keep-a-postgresql-connection-alive-from-within-a-next-js-api 5 | export function getLLaMAClient() { 6 | // TODO: in the future, allow multiple adapters like llama.cpp and llama-rs and other similar models. 7 | if (!("llamaClient" in global)) { 8 | if ( 9 | env.USE_BUILT_IN_LLAMA_SERVER === "false" && 10 | env.LLAMA_SERVER_PORT === "auto" 11 | ) { 12 | throw new Error( 13 | "LLAMA_SERVER_PORT must be set to a port number when USE_BUILT_IN_LLAMA_SERVER is false." 14 | ); 15 | } 16 | 17 | const LLaMAClient = new LLaMATCPClient( 18 | env.USE_BUILT_IN_LLAMA_SERVER === "true" 19 | ? { 20 | port: env.LLAMA_SERVER_PORT, 21 | modelPath: env.LLAMA_MODEL_PATH, 22 | binPath: env.LLAMA_TCP_BIN, 23 | debug: env.NODE_ENV === "development", 24 | } 25 | : { 26 | host: env.LLAMA_SERVER_HOST, 27 | port: env.LLAMA_SERVER_PORT as number, 28 | debug: env.NODE_ENV === "development", 29 | } 30 | ); 31 | 32 | LLaMAClient.start(); 33 | 34 | return ((global as any).llamaClient = LLaMAClient); 35 | } 36 | 37 | return (global as any).llamaClient as LLaMATCPClient; 38 | } 39 | -------------------------------------------------------------------------------- /src/utils/utils.ts: -------------------------------------------------------------------------------- 1 | import { createServer, AddressInfo } from "net"; 2 | 3 | /** 4 | * Searches randomly for an open port 5 | * @param min The minimum port number to search for 6 | * @param max The maximum port number to search for 7 | * @param _i The current iteration 8 | */ 9 | export function findRandomOpenPort(min: number = 1000, max: number = 65535, _i: 
number = 0) { 10 | // if _i is a multiple of 10, log a warning 11 | if (_i > 0 && _i % 10 === 0) { 12 | console.warn(`[warn] findRandomOpenPort: ${_i} iterations, no open port found`); 13 | } 14 | 15 | // generate random port number 16 | const randomNumber = Math.floor(Math.random() * (max - min + 1)) + min; 17 | 18 | // check if port is open 19 | return new Promise((resolve, reject) => { 20 | const server = createServer(); 21 | server.unref(); 22 | server.on("error", (error) => { 23 | // check the error 24 | if (error.message.includes("EADDRINUSE")) { 25 | // port is in use, try again 26 | resolve(findRandomOpenPort(min, max, _i + 1)); 27 | return; 28 | } 29 | 30 | // some other error, throw it 31 | reject(error); 32 | }); 33 | 34 | // attempt to listen on the port 35 | server.listen(randomNumber, () => { 36 | // success! close the server and return the port 37 | const { port } = server.address() as AddressInfo; 38 | server.close(() => { 39 | resolve(port); 40 | }); 41 | }); 42 | }); 43 | } -------------------------------------------------------------------------------- /src/recoil/states.ts: -------------------------------------------------------------------------------- 1 | import { atom } from "recoil"; 2 | 3 | // websocket state 4 | export const wsUUIDState = atom({ 5 | key: "ws_uuid", 6 | default: "", 7 | }); 8 | 9 | export enum ClientWSState { 10 | CONNECTING, 11 | READY, 12 | } 13 | 14 | export const wsState = atom({ 15 | key: "ws_state", 16 | default: ClientWSState.CONNECTING, 17 | }); 18 | 19 | // model options 20 | export const maxPredictionTokensState = atom({ 21 | key: "num_predict", 22 | default: 2048, 23 | }); 24 | 25 | export const nBatchState = atom({ 26 | key: "n_batch", 27 | default: 8, 28 | }); 29 | 30 | export const topKState = atom({ 31 | key: "top_k", 32 | default: 40, 33 | }); 34 | 35 | export const topPState = atom({ 36 | key: "top_p", 37 | default: 0.95, 38 | }); 39 | 40 | export const repeatPenaltyState = atom({ 41 | key: 
"repeat_penalty", 42 | default: 1.3, 43 | }); 44 | 45 | export const repeatLastNState = atom({ 46 | key: "repeat_last_n", 47 | default: 64, 48 | }); 49 | 50 | export const tempState = atom({ 51 | key: "temp", 52 | default: 0.8, 53 | }); 54 | 55 | export const promptState = atom({ 56 | key: "prompt", 57 | default: "asdf", 58 | }); 59 | 60 | export const promptTemplateState = atom({ 61 | key: "prompt_template", 62 | default: "", 63 | }); 64 | 65 | export const generatingState = atom({ 66 | key: "generating", 67 | default: false, 68 | }); 69 | 70 | export const generatedText = atom({ 71 | key: "generated_text", 72 | default: "", 73 | }); 74 | -------------------------------------------------------------------------------- /src/recoil/templates.ts: -------------------------------------------------------------------------------- 1 | import { stripIndents } from "common-tags"; 2 | import { atom } from "recoil"; 3 | 4 | export type TemplateData = { 5 | name: string; 6 | prompt: string; 7 | options?: Partial<{ 8 | maximum_tokens: number; 9 | temperature: number; 10 | top_p: number; 11 | top_k: number; 12 | repeat_penalty: number; 13 | repeat_last_n: number; 14 | }>; 15 | }; 16 | 17 | export const defaultTemplates: TemplateData[] = [ 18 | { 19 | name: "Question Answering", 20 | prompt: stripIndents` 21 | Below is an instruction that describes a task. Write a response that appropriately completes the task. The response must be accurate, concise, coherent, and evidence-based whenever possible. 22 | 23 | **Task:** Describe Machine Learning in your own words. 24 | `, 25 | }, 26 | { 27 | name: "Chatbot", 28 | prompt: stripIndents` 29 | Below is a conversation between two people. Write a response that appropriately completes a response for the AI. The response must be accurate, concise, coherent, and evidence-based whenever possible. Do not complete the User's part. 30 | 31 | User: What is Machine Learning? 
32 | AI: 33 | `, 34 | }, 35 | { 36 | name: "Story", 37 | prompt: stripIndents` 38 | Below is a prompt for a story. The story must relate to the prompt and be creative and interesting. 39 | 40 | Prompt: It was a dark and stormy night... 41 | `, 42 | options: { 43 | temperature: 0.5, 44 | }, 45 | }, 46 | ]; 47 | 48 | export const allTemplatesState = atom({ 49 | key: "all_templates", 50 | default: defaultTemplates, 51 | }); 52 | -------------------------------------------------------------------------------- /src/server/api/shared.ts: -------------------------------------------------------------------------------- 1 | import EventEmitter from "events"; 2 | 3 | // this is a set of client IDs that are currently generating text 4 | export const generating = new Set(); 5 | export const clients = new Set(); 6 | 7 | // EventEmitter but with types 8 | export class TypedEventEmitter< 9 | TEvents extends Record 10 | > extends EventEmitter { 11 | emit( 12 | eventName: TEventName, 13 | ...eventArg: TEvents[TEventName] 14 | ) { 15 | return super.emit(eventName, ...(eventArg as [])); 16 | } 17 | 18 | on( 19 | eventName: TEventName, 20 | handler: (...eventArg: TEvents[TEventName]) => void 21 | ) { 22 | return super.on(eventName, handler as any); 23 | } 24 | 25 | off( 26 | eventName: TEventName, 27 | handler: (...eventArg: TEvents[TEventName]) => void 28 | ) { 29 | return super.off(eventName, handler as any); 30 | } 31 | 32 | incrementMaxListeners(amount = 1) { 33 | return this.setMaxListeners(this.getMaxListeners() + amount); 34 | } 35 | 36 | decrementMaxListeners(amount = 1) { 37 | return this.setMaxListeners(this.getMaxListeners() - amount); 38 | } 39 | } 40 | 41 | // the events that can be emitted 42 | type Events = { 43 | // client id, message 44 | generate: [string, string]; 45 | 46 | // client id 47 | "generate:done": [string]; 48 | 49 | // client id, error 50 | "generate:error": [string, Error]; 51 | 52 | // client id 53 | "generate:cancel": [string]; 54 | }; 55 | 56 | // 
global event emitter 57 | // can be replaced with a pubsub system if you want to scale 58 | export const events = new TypedEventEmitter(); 59 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ## Base image 2 | FROM node:18-alpine AS base 3 | 4 | # Install package manager 5 | RUN npm i --global --no-update-notifier --no-fund pnpm 6 | 7 | # Set the working directory 8 | WORKDIR /app 9 | 10 | ## Builder image 11 | FROM base AS web-builder 12 | 13 | # Copy the package.json and package-lock.json 14 | COPY package*.json ./ 15 | COPY pnpm-lock.yaml ./ 16 | 17 | # Install the dependencies 18 | RUN pnpm install --frozen-lockfile 19 | 20 | # Copy the source code 21 | COPY . . 22 | 23 | # Build the app 24 | RUN SKIP_ENV_VALIDATION=true npm run build 25 | 26 | ## Builder for llama-cpp 27 | FROM alpine:3.14 AS llama-cpp-builder 28 | 29 | # Install build dependencies 30 | RUN apk add --no-cache \ 31 | build-base \ 32 | make \ 33 | git 34 | 35 | # Set the working directory 36 | WORKDIR /app 37 | 38 | # Clone the repository 39 | RUN git clone --depth 1 --branch tcp_server \ 40 | https://github.com/ggerganov/llama.cpp.git . 
41 | 42 | # Build 43 | RUN make 44 | 45 | # Production image based on alpine node:18 46 | FROM base AS production 47 | 48 | # Set the working directory 49 | WORKDIR /app 50 | 51 | # Copy the package.json and package-lock.json 52 | COPY package*.json ./ 53 | COPY pnpm-lock.yaml ./ 54 | 55 | # Install production dependencies 56 | RUN pnpm install --frozen-lockfile --production 57 | 58 | # Copy all the built files 59 | COPY --from=llama-cpp-builder /app/main ./bin/main 60 | COPY --from=web-builder /app/ ./ 61 | # COPY --from=web-builder /app/dist ./dist 62 | # COPY --from=web-builder /app/.next ./.next 63 | # COPY --from=web-builder /app/public ./public 64 | # COPY --from=web-builder /app/scripts ./scripts 65 | # COPY --from=web-builder /app/src/env.mjs ./src/env.mjs 66 | 67 | # Default environment variables 68 | ENV NODE_ENV=production 69 | ENV HOST=0.0.0.0 70 | 71 | # Done 72 | CMD ["pnpm", "run", "start"] 73 | -------------------------------------------------------------------------------- /src/server/plugins/next.ts: -------------------------------------------------------------------------------- 1 | import { FastifyInstance, FastifyReply, FastifyRequest } from "fastify"; 2 | import fastifyPlugin from "fastify-plugin"; 3 | import { IncomingMessage, OutgoingMessage, ServerResponse } from "http"; 4 | import next from "next"; 5 | import { NextServerOptions } from "next/dist/server/next"; 6 | 7 | async function nextPlugin( 8 | fastify: FastifyInstance, 9 | options: NextServerOptions 10 | ) { 11 | const nextServer = next(options); 12 | const handle = nextServer.getRequestHandler(); 13 | 14 | fastify 15 | .decorate("nextServer", nextServer) 16 | .decorate("nextHandle", handle) 17 | .decorate("next", route.bind(fastify)); 18 | 19 | return nextServer.prepare(); 20 | 21 | function route( 22 | path: string, 23 | opts: { method: string | string[] } = { method: "GET" } 24 | ) { 25 | if (typeof opts.method === "string") 26 | // @ts-ignore 27 | 
this[opts.method.toLowerCase()](path, opts, handler); 28 | else if (Array.isArray(opts.method)) { 29 | for (const method of opts.method) { 30 | // @ts-ignore 31 | this[method.toLowerCase()](path, opts, handler); 32 | } 33 | } 34 | 35 | async function handler(req: FastifyRequest, reply: FastifyReply) { 36 | for (const [key, value] of Object.entries(reply.getHeaders())) { 37 | if (value) reply.raw.setHeader(key, value); 38 | } 39 | 40 | await handle(req.raw, reply.raw); 41 | 42 | reply.hijack(); 43 | } 44 | } 45 | } 46 | 47 | export default fastifyPlugin(nextPlugin, { 48 | name: "next", 49 | fastify: "4.x", 50 | }); 51 | 52 | declare module "fastify" { 53 | interface FastifyInstance { 54 | nextServer: ReturnType; 55 | next: (path: string, opts?: { method: string | string[] }) => void; 56 | nextHandle: ( 57 | req: IncomingMessage, 58 | res: OutgoingMessage | ServerResponse 59 | ) => Promise; 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /.github/workflows/docker.yml: -------------------------------------------------------------------------------- 1 | name: Release to Github Packages. # https://ghcr.io 2 | on: # https://docs.github.com/en/actions/reference/workflow-syntax-for-github-actions#on 3 | workflow_dispatch: 4 | push: 5 | tags: # This builds for all branches with semantically versioned tags (v0.12.3). 6 | - 'v*' # https://semver.org will fail, if there are any other tags 7 | 8 | #release: # Builds only releases. Includes draft and pre-releases. 9 | #types: [created] 10 | 11 | #pull_request: # Run 'tests' for any PRs. Default is to not run for first-time contributors: see /settings/actions 12 | 13 | env: 14 | TAG_LATEST: true # Encourage users to use a major version (foobar:1) instead of :latest. 15 | # By semantic versioning standards, major changes are changes 'backwards incompatible'. Major upgrades are often rare and prehaps, need attention from the user. 16 | jobs: 17 | # Push image to GitHub Packages. 
18 | push: 19 | runs-on: ubuntu-latest 20 | permissions: 21 | packages: write 22 | contents: read 23 | 24 | steps: 25 | - uses: actions/checkout@v2 26 | 27 | - 28 | name: Set up QEMU 29 | uses: docker/setup-qemu-action@v2 30 | - 31 | name: Set up Docker Buildx 32 | uses: docker/setup-buildx-action@v2 33 | - 34 | name: Login to GHCR 35 | if: github.event_name != 'pull_request' 36 | uses: docker/login-action@v2 37 | with: 38 | registry: ghcr.io 39 | username: ${{ github.repository_owner }} 40 | password: ${{ secrets.PAT }} 41 | - 42 | name: Extract Builder meta 43 | id: builder-meta 44 | uses: docker/metadata-action@v4 45 | with: 46 | images: ghcr.io/itzderock/llama-playground 47 | tags: | 48 | type=semver,pattern={{version}} 49 | type=semver,pattern={{major}}.{{minor}} 50 | - 51 | name: Build and push 52 | uses: docker/build-push-action@v3 53 | with: 54 | context: . 55 | platforms: linux/amd64,linux/arm64 56 | push: ${{ github.event_name != 'pull_request' }} 57 | tags: ${{ steps.builder-meta.outputs.tags }} 58 | labels: ${{ steps.builder-meta.outputs.labels }} 59 | -------------------------------------------------------------------------------- /src/pages/_app.tsx: -------------------------------------------------------------------------------- 1 | import { type AppType } from "next/app"; 2 | import { 3 | ColorScheme, 4 | ColorSchemeProvider, 5 | MantineProvider, 6 | } from "@mantine/core"; 7 | import { api } from "~/utils/api"; 8 | import { RecoilRoot } from "recoil"; 9 | 10 | import "~/styles/globals.css"; 11 | import { useColorScheme, useLocalStorage } from "@mantine/hooks"; 12 | import Head from "next/head"; 13 | 14 | const App: AppType = ({ Component, pageProps }) => { 15 | // load color scheme from local storage 16 | const preferredColorScheme = useColorScheme(); 17 | const [colorScheme, setColorScheme] = useLocalStorage({ 18 | key: "colorScheme", 19 | defaultValue: preferredColorScheme, 20 | }); 21 | 22 | const toggleColorScheme = (value?: ColorScheme) => 23 
| setColorScheme(value ?? (colorScheme === "light" ? "dark" : "light")); 24 | 25 | return ( 26 | <> 27 | 28 | 32 | 33 | 34 | 35 | LLaMA Playground 36 | 37 | 38 | 42 | 43 | 44 | 45 | 46 | {/* thumbnail (large) */} 47 | 48 | 49 | 50 | 51 | 55 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | ); 67 | }; 68 | 69 | export default api.withTRPC(App); 70 | -------------------------------------------------------------------------------- /src/partials/Header.tsx: -------------------------------------------------------------------------------- 1 | import { 2 | createStyles, 3 | Header, 4 | Group, 5 | Button, 6 | Box, 7 | Text, 8 | useMantineColorScheme, 9 | Flex, 10 | } from "@mantine/core"; 11 | import { IconBrandGithub, IconMoon, IconSun } from "@tabler/icons-react"; 12 | import TemplateSelect from "./TemplateSelect"; 13 | 14 | const useStyles = createStyles((theme) => ({ 15 | text: { 16 | color: theme.colorScheme === "dark" ? theme.colors.dark[0] : theme.black, 17 | fontWeight: 500, 18 | }, 19 | 20 | hiddenMobile: { 21 | [theme.fn.smallerThan("sm")]: { 22 | display: "none", 23 | }, 24 | }, 25 | 26 | hiddenDesktop: { 27 | [theme.fn.largerThan("sm")]: { 28 | display: "none", 29 | }, 30 | }, 31 | })); 32 | 33 | export default function FullHeader() { 34 | const { colorScheme, toggleColorScheme } = useMantineColorScheme(); 35 | const { classes } = useStyles(); 36 | 37 | return ( 38 | 39 |
40 | 41 | LLaMA Playground 42 | 43 | 44 | 45 | 46 | 47 | 48 | 60 | 61 | 69 | 70 | 71 |
72 | 73 | 80 | 81 | 82 |
83 | ); 84 | } 85 | -------------------------------------------------------------------------------- /src/server/index.ts: -------------------------------------------------------------------------------- 1 | import fastify, { FastifyLoggerOptions } from "fastify"; 2 | import { LoggerOptions } from "pino"; 3 | import fastifyNextJS from "./plugins/next"; 4 | import { fastifyTRPCPlugin } from "@trpc/server/adapters/fastify"; 5 | import "dotenv/config"; 6 | import { env } from "~/env.mjs"; 7 | import { appRouter } from "./api/root"; 8 | import { createTRPCContext } from "./api/trpc"; 9 | import ws from "@fastify/websocket"; 10 | import path from "path"; 11 | import { fastifyStatic } from "@fastify/static"; 12 | 13 | // logger 14 | const envToLogger = { 15 | development: { 16 | transport: { 17 | target: "pino-pretty", 18 | options: { 19 | translateTime: "HH:MM:ss Z", 20 | ignore: "pid,hostname", 21 | }, 22 | }, 23 | }, 24 | production: true, 25 | test: false, 26 | } satisfies { 27 | [key: string]: (FastifyLoggerOptions & LoggerOptions) | boolean; 28 | }; 29 | 30 | // create fastify app 31 | export const app = fastify({ 32 | // prevent errors during large batch requests 33 | maxParamLength: 5000, 34 | 35 | // logger 36 | logger: envToLogger[env.NODE_ENV] ?? 
true, 37 | }); 38 | 39 | // register the next.js plugin 40 | app.register(fastifyNextJS, { 41 | dev: env.NODE_ENV !== "production", 42 | dir: ".", 43 | }); 44 | 45 | // serve static files in /public 46 | app.register(fastifyStatic, { 47 | root: path.join(__dirname, "..", "public"), 48 | prefix: "/public/", 49 | }); 50 | 51 | // register ws for websocket support 52 | // needed for subscriptions 53 | app.register(ws); 54 | 55 | // register the trpc plugin 56 | app.register(fastifyTRPCPlugin, { 57 | prefix: "/api/trpc", 58 | useWSS: true, 59 | trpcOptions: { 60 | router: appRouter, 61 | createContext: createTRPCContext, 62 | }, 63 | }); 64 | 65 | // pass to next.js if no route is defined 66 | app.addHook("onRequest", async (req, rep) => { 67 | // if a route is defined, skip 68 | if (req.routerPath) { 69 | return; 70 | } 71 | 72 | // pass along headers 73 | for (const [key, value] of Object.entries(rep.getHeaders())) { 74 | if (value) rep.raw.setHeader(key, value); 75 | } 76 | 77 | // otherwise, pass to next.js 78 | await app.nextHandle(req.raw, rep.raw); 79 | rep.hijack(); 80 | }); 81 | 82 | // start the server 83 | app.listen({ port: parseInt(env.PORT), host: env.HOST }).then((address) => { 84 | console.log(`🚀 Server ready at ${address}`); 85 | }); 86 | -------------------------------------------------------------------------------- /scripts/startup.mjs: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env zx 2 | import "dotenv/config"; 3 | import "zx/globals"; 4 | import os from "os"; 5 | import path from "path"; 6 | import { env } from "../src/env.mjs"; 7 | 8 | // don't run if user is providing their own server. 
9 | if (!env.USE_BUILT_IN_LLAMA_SERVER) process.exit(0); 10 | 11 | // check if user is providing their own binary 12 | if (env.LLAMA_TCP_BIN != "auto") { 13 | if (!fs.existsSync(env.LLAMA_TCP_BIN)) { 14 | console.error( 15 | `❌ LLAMA_TCP_BIN is set to '${env.LLAMA_TCP_BIN}' but that file doesn't exist.` 16 | ); 17 | process.exit(1); 18 | } 19 | 20 | process.exit(0); 21 | } 22 | 23 | // check if llama.cpp's tcp_server is already built 24 | if (fs.existsSync("./bin/") && fs.existsSync("./bin/main")) { 25 | console.log(chalk.green("✅ Llama.cpp's tcp_server is already built.")); 26 | process.exit(0); 27 | } 28 | 29 | // clone and build the repo 30 | const { path: tempDir, delete: deleteTempDir } = createTempDir(); 31 | console.log( 32 | chalk.gray( 33 | `🔃 (1/4) Created temporary directory at "${chalk.green( 34 | tempDir 35 | )}". Cloning repo...` 36 | ) 37 | ); 38 | 39 | // clone the repo 40 | await $`git clone --depth 1 --branch tcp_server https://github.com/ggerganov/llama.cpp ${tempDir}`; 41 | console.log( 42 | chalk.gray(`🔃 (2/4) Cloned llama.cpp's tcp_server branch to ${tempDir}`) 43 | ); 44 | 45 | // build the repo 46 | console.log(chalk.gray(`🔃 (3/4) Building llama.cpp's tcp_server...`)); 47 | try { 48 | await $`cd ${tempDir} && make -j`; 49 | } catch (error) { 50 | console.error( 51 | chalk.red(`❌ Failed to build llama.cpp's tcp_server! Error: ${error}`) 52 | ); 53 | console.error(`🗑️ Cleaning up...`); 54 | deleteTempDir(); 55 | process.exit(1); 56 | } 57 | 58 | // copy the binary to the bin folder 59 | console.log(chalk.gray(`🔃 (4/4) Copying binary to ./bin/main...`)); 60 | fs.mkdirSync("./bin/"); 61 | fs.copyFileSync(path.join(tempDir, "main"), "./bin/main"); 62 | 63 | // done 64 | console.log(chalk.green("✅ Built llama.cpp's tcp_server! Cleaning up...")); 65 | deleteTempDir(); 66 | 67 | /** 68 | * Creates a temporary directory and returns that directory's path and a function to delete it. 
69 | * @returns {{ path: string; delete: () => void; }} 70 | */ 71 | function createTempDir() { 72 | const random = Math.random().toString(36).substring(2, 15); 73 | const tempDir = path.join(os.tmpdir(), `llama-playground-tmp-${random}`); 74 | 75 | if (!fs.existsSync(tempDir)) { 76 | fs.mkdirSync(tempDir); 77 | return { 78 | path: tempDir, 79 | delete: () => fs.removeSync(tempDir), 80 | }; 81 | } 82 | 83 | return createTempDir(); 84 | } 85 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "llama-playground", 3 | "version": "1.0.1", 4 | "private": true, 5 | "scripts": { 6 | "build": "npm-run-all build:next build:server", 7 | "build:next": "next build", 8 | "build:server": "tsup", 9 | "dev": "npm-run-all build:server dev:run", 10 | "dev:run": "cross-env DEBUG=true REACT_EDITOR=code NODE_ENV=development RECOIL_DUPLICATE_ATOM_KEY_CHECKING_ENABLED=false node --enable-source-maps dist", 11 | "lint": "next lint", 12 | "start": "node --enable-source-maps ./dist/index.js", 13 | "prestart": "node ./scripts/startup.mjs" 14 | }, 15 | "dependencies": { 16 | "@emotion/react": "^11.10.6", 17 | "@emotion/server": "^11.10.0", 18 | "@fastify/static": "^6.9.0", 19 | "@fastify/websocket": "^7.1.3", 20 | "@mantine/core": "^6.0.2", 21 | "@mantine/form": "^6.0.4", 22 | "@mantine/hooks": "^6.0.2", 23 | "@mantine/next": "^6.0.2", 24 | "@mantine/tiptap": "^6.0.2", 25 | "@tabler/icons-react": "^2.11.0", 26 | "@tanstack/react-query": "^4.20.2", 27 | "@tiptap/core": ">=2.0.0-beta.209 <3.0.0", 28 | "@tiptap/extension-document": "2.0.0-beta.220", 29 | "@tiptap/extension-highlight": "2.0.0-beta.220", 30 | "@tiptap/extension-link": "2.0.0-beta.220", 31 | "@tiptap/extension-paragraph": "2.0.0-beta.220", 32 | "@tiptap/extension-placeholder": "2.0.0-beta.220", 33 | "@tiptap/extension-text": "2.0.0-beta.220", 34 | "@tiptap/pm": ">=2.0.0-beta.209 <3.0.0", 
35 | "@tiptap/react": "2.0.0-beta.220", 36 | "@trpc/client": "^10.9.0", 37 | "@trpc/next": "^10.9.0", 38 | "@trpc/react-query": "^10.9.0", 39 | "@trpc/server": "^10.9.0", 40 | "@types/common-tags": "^1.8.1", 41 | "common-tags": "^1.8.2", 42 | "cross-env": "^7.0.3", 43 | "dotenv": "^16.0.3", 44 | "fastify": "^4.15.0", 45 | "fastify-plugin": "^4.5.0", 46 | "next": "^13.2.1", 47 | "npm-run-all": "^4.1.5", 48 | "pino": "^8.11.0", 49 | "react": "18.2.0", 50 | "react-dom": "18.2.0", 51 | "recoil": "^0.7.7", 52 | "state": "link:@tiptap/pm/state", 53 | "superjson": "1.9.1", 54 | "ts-node": "^10.9.1", 55 | "tsup": "^6.7.0", 56 | "zod": "^3.20.6", 57 | "zx": "^7.2.1" 58 | }, 59 | "devDependencies": { 60 | "@types/eslint": "^8.21.1", 61 | "@types/node": "^18.14.0", 62 | "@types/prettier": "^2.7.2", 63 | "@types/react": "^18.0.28", 64 | "@types/react-dom": "^18.0.11", 65 | "@typescript-eslint/eslint-plugin": "^5.53.0", 66 | "@typescript-eslint/parser": "^5.53.0", 67 | "autoprefixer": "^10.4.14", 68 | "eslint": "^8.34.0", 69 | "eslint-config-next": "^13.2.1", 70 | "pino-pretty": "^10.0.0", 71 | "postcss": "^8.4.14", 72 | "prettier": "^2.8.1", 73 | "typescript": "^4.9.5" 74 | }, 75 | "ct3aMetadata": { 76 | "initVersion": "7.8.0" 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /src/partials/PromptEditor.tsx: -------------------------------------------------------------------------------- 1 | import { RichTextEditor } from "@mantine/tiptap"; 2 | import { useEditor, type Content } from "@tiptap/react"; 3 | import { useEffect } from "react"; 4 | import { useRecoilState } from "recoil"; 5 | import { 6 | generatedText, 7 | generatingState, 8 | promptState, 9 | promptTemplateState, 10 | } from "~/recoil/states"; 11 | import Document from "@tiptap/extension-document"; 12 | import Text from "@tiptap/extension-text"; 13 | import Paragraph from "@tiptap/extension-paragraph"; 14 | import Placeholder from "@tiptap/extension-placeholder"; 15 
| import Highlight from "@tiptap/extension-highlight"; 16 | 17 | export default function PromptEditor() { 18 | const [_prompt, setPrompt] = useRecoilState(promptState); 19 | const [generating] = useRecoilState(generatingState); 20 | const [generated] = useRecoilState(generatedText); 21 | const [templatePrompt, __] = useRecoilState(promptTemplateState); 22 | 23 | // build a tiptap editor 24 | const editor = useEditor({ 25 | extensions: [ 26 | Document, 27 | Paragraph, 28 | Text, 29 | Placeholder.configure({ 30 | placeholder: "Q: What is a llama?", 31 | }), 32 | Highlight, 33 | ], 34 | onUpdate: ({ editor }) => { 35 | setPrompt( 36 | editor.getText({ 37 | blockSeparator: "\n", 38 | }) 39 | ); 40 | }, 41 | }); 42 | 43 | // handle enabling/disabling the editor 44 | useEffect(() => { 45 | if (generating) { 46 | editor?.setEditable(false); 47 | editor?.chain().focus("end").setHighlight({ color: "#FFF68F" }).run(); 48 | } else { 49 | editor?.commands.unsetHighlight(); 50 | editor?.setEditable(true); 51 | } 52 | }, [generating]); 53 | 54 | // handle updating the generated text 55 | useEffect(() => { 56 | if (!editor) return; 57 | 58 | editor 59 | .chain() 60 | .focus("end") 61 | .setHighlight({ color: "#FFF68F" }) 62 | .insertContent( 63 | generated === "\n" 64 | ? { 65 | type: "paragraph", 66 | content: [], 67 | } 68 | : generated 69 | ) 70 | .run(); 71 | }, [generated]); 72 | 73 | // handle prompt changes 74 | useEffect(() => { 75 | if (!editor) return; 76 | 77 | // delete all content and insert the new prompt 78 | editor 79 | .chain() 80 | .selectAll() 81 | .deleteSelection() 82 | .insertContent(templatePrompt) 83 | .run(); 84 | }, [templatePrompt]); 85 | 86 | return ( 87 | 97 | 98 | 99 | ); 100 | } 101 | -------------------------------------------------------------------------------- /src/server/api/trpc.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * YOU PROBABLY DON'T NEED TO EDIT THIS FILE, UNLESS: 3 | * 1. 
You want to modify request context (see Part 1). 4 | * 2. You want to create a new middleware or type of procedure (see Part 3). 5 | * 6 | * TL;DR - This is where all the tRPC server stuff is created and plugged in. The pieces you will 7 | * need to use are documented accordingly near the end. 8 | */ 9 | 10 | /** 11 | * 1. CONTEXT 12 | * 13 | * This section defines the "contexts" that are available in the backend API. 14 | * 15 | * These allow you to access things when processing a request, like the database, the session, etc. 16 | */ 17 | import { type CreateFastifyContextOptions } from "@trpc/server/adapters/fastify"; 18 | 19 | /** Replace this with an object if you want to pass things to `createContextInner`. */ 20 | type CreateContextOptions = Record; 21 | 22 | /** 23 | * This helper generates the "internals" for a tRPC context. If you need to use it, you can export 24 | * it from here. 25 | * 26 | * Examples of things you may need it for: 27 | * - testing, so we don't have to mock Next.js' req/res 28 | * - tRPC's `createSSGHelpers`, where we don't have req/res 29 | * 30 | * @see https://create.t3.gg/en/usage/trpc#-servertrpccontextts 31 | */ 32 | const createInnerTRPCContext = (_opts: CreateContextOptions) => { 33 | return {}; 34 | }; 35 | 36 | /** 37 | * This is the actual context you will use in your router. It will be used to process every request 38 | * that goes through your tRPC endpoint. 39 | * 40 | * @see https://trpc.io/docs/context 41 | */ 42 | export const createTRPCContext = (_opts: CreateFastifyContextOptions) => { 43 | return createInnerTRPCContext({}); 44 | }; 45 | 46 | /** 47 | * 2. INITIALIZATION 48 | * 49 | * This is where the tRPC API is initialized, connecting the context and transformer. We also parse 50 | * ZodErrors so that you get typesafety on the frontend if your procedure fails due to validation 51 | * errors on the backend. 
52 | */ 53 | import { initTRPC } from "@trpc/server"; 54 | import superjson from "superjson"; 55 | import { ZodError } from "zod"; 56 | 57 | const t = initTRPC.context().create({ 58 | transformer: superjson, 59 | errorFormatter({ shape, error }) { 60 | return { 61 | ...shape, 62 | data: { 63 | ...shape.data, 64 | zodError: 65 | error.cause instanceof ZodError ? error.cause.flatten() : null, 66 | }, 67 | }; 68 | }, 69 | }); 70 | 71 | /** 72 | * 3. ROUTER & PROCEDURE (THE IMPORTANT BIT) 73 | * 74 | * These are the pieces you use to build your tRPC API. You should import these a lot in the 75 | * "/src/server/api/routers" directory. 76 | */ 77 | 78 | /** 79 | * This is how you create new routers and sub-routers in your tRPC API. 80 | * 81 | * @see https://trpc.io/docs/router 82 | */ 83 | export const createTRPCRouter = t.router; 84 | 85 | /** 86 | * Public (unauthenticated) procedure 87 | * 88 | * This is the base piece you use to build new queries and mutations on your tRPC API. It does not 89 | * guarantee that a user querying is authorized, but you can still access user session data if they 90 | * are logged in. 91 | */ 92 | export const publicProcedure = t.procedure; 93 | -------------------------------------------------------------------------------- /src/utils/api.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * This is the client-side entrypoint for your tRPC API. It is used to create the `api` object which 3 | * contains the Next.js App-wrapper, as well as your type-safe React Query hooks. 4 | * 5 | * We also create a few inference helpers for input and output types. 
6 | */ 7 | import { 8 | createWSClient, 9 | httpBatchLink, 10 | loggerLink, 11 | wsLink, 12 | } from "@trpc/client"; 13 | import { createTRPCNext } from "@trpc/next"; 14 | import { type inferRouterInputs, type inferRouterOutputs } from "@trpc/server"; 15 | import superjson from "superjson"; 16 | 17 | import { type AppRouter } from "~/server/api/root"; 18 | 19 | const getBaseUrl = () => { 20 | if (typeof window !== "undefined") return ""; // browser should use relative url 21 | if (process.env.VERCEL_URL) return `https://${process.env.VERCEL_URL}`; // SSR should use vercel url 22 | return `http://localhost:${process.env.PORT ?? 3000}`; // dev SSR should use localhost 23 | }; 24 | 25 | /** A set of type-safe react-query hooks for your tRPC API. */ 26 | export const api = createTRPCNext({ 27 | config() { 28 | const links = [ 29 | loggerLink({ 30 | enabled: (opts) => 31 | process.env.NODE_ENV === "development" || 32 | (opts.direction === "down" && opts.result instanceof Error), 33 | }), 34 | ]; 35 | 36 | // If we're in the browser, we'll use a WebSocket connection to the server. 37 | if (typeof window !== "undefined") { 38 | // url should be the page url but with ws(s) instead of http(s) 39 | let url = window.location.host; 40 | 41 | // add protocol 42 | if (window.location.protocol === "https:") { 43 | url = "wss://" + url; 44 | } else { 45 | url = "ws://" + url; 46 | } 47 | 48 | links.push( 49 | wsLink({ 50 | client: createWSClient({ 51 | url: `${url}/api/trpc`, 52 | }), 53 | }) 54 | ); 55 | } else { 56 | // If we're on the server, we'll use HTTP batching. 57 | links.push( 58 | httpBatchLink({ 59 | url: `${getBaseUrl()}/api/trpc`, 60 | }) 61 | ); 62 | } 63 | 64 | return { 65 | /** 66 | * Transformer used for data de-serialization from the server. 67 | * 68 | * @see https://trpc.io/docs/data-transformers 69 | */ 70 | transformer: superjson, 71 | 72 | /** 73 | * Links used to determine request flow from client to server. 
74 | * 75 | * @see https://trpc.io/docs/links 76 | */ 77 | links, 78 | }; 79 | }, 80 | 81 | /** 82 | * Whether tRPC should await queries when server rendering pages. 83 | * 84 | * @see https://trpc.io/docs/nextjs#ssr-boolean-default-false 85 | */ 86 | ssr: false, 87 | }); 88 | 89 | /** 90 | * Inference helper for inputs. 91 | * 92 | * @example type HelloInput = RouterInputs['example']['hello'] 93 | */ 94 | export type RouterInputs = inferRouterInputs; 95 | 96 | /** 97 | * Inference helper for outputs. 98 | * 99 | * @example type HelloOutput = RouterOutputs['example']['hello'] 100 | */ 101 | export type RouterOutputs = inferRouterOutputs; 102 | -------------------------------------------------------------------------------- /src/partials/Generate.tsx: -------------------------------------------------------------------------------- 1 | import { Flex, Button } from "@mantine/core"; 2 | import { useRecoilState, useRecoilValue } from "recoil"; 3 | import { 4 | promptState, 5 | generatingState, 6 | wsState, 7 | wsUUIDState, 8 | ClientWSState, 9 | generatedText, 10 | maxPredictionTokensState, 11 | topKState, 12 | topPState, 13 | repeatPenaltyState, 14 | tempState, 15 | repeatLastNState, 16 | } from "~/recoil/states"; 17 | import { WSMessageType } from "~/server/api/types"; 18 | import { api } from "~/utils/api"; 19 | 20 | export default function Generate() { 21 | const [prompt] = useRecoilState(promptState); 22 | const [loading, setLoading] = useRecoilState(generatingState); 23 | const [state, setWSState] = useRecoilState(wsState); 24 | const [uuid, setUUID] = useRecoilState(wsUUIDState); 25 | const [_generated, setGenerated] = useRecoilState(generatedText); 26 | const nPredict = useRecoilValue(maxPredictionTokensState); 27 | const topK = useRecoilValue(topKState); 28 | const topP = useRecoilValue(topPState); 29 | const repeatPenalty = useRecoilValue(repeatPenaltyState); 30 | const temp = useRecoilValue(tempState); 31 | const repeatLastN = 
useRecoilValue(repeatLastNState); 32 | 33 | // subscribe to generation updates 34 | api.llama.subscription.useSubscription(undefined, { 35 | onData: (data) => { 36 | // wait for identity 37 | if (data.type === WSMessageType.IDENTITY) { 38 | setUUID(data.data); 39 | setWSState(ClientWSState.READY); 40 | return; 41 | } 42 | 43 | // loading false when generation done 44 | if (data.type === WSMessageType.REQUEST_COMPLETE) { 45 | setLoading(false); 46 | return; 47 | } 48 | 49 | // update generated text 50 | if (data.type === WSMessageType.COMPLETION) { 51 | setGenerated(data.data); 52 | return; 53 | } 54 | }, 55 | 56 | onError: (error) => { 57 | setWSState(ClientWSState.CONNECTING); 58 | }, 59 | }); 60 | 61 | // mutations 62 | const generate = api.llama.startGeneration.useMutation(); 63 | const cancel = api.llama.cancelGeneration.useMutation(); 64 | 65 | return ( 66 | 67 | 89 | 90 | {/* cancel */} 91 | {loading && ( 92 | 101 | )} 102 | 103 | ); 104 | } 105 | -------------------------------------------------------------------------------- /src/env.mjs: -------------------------------------------------------------------------------- 1 | import { z } from "zod"; 2 | 3 | /** 4 | * Specify your server-side environment variables schema here. This way you can ensure the app isn't 5 | * built with invalid env vars. 6 | */ 7 | const server = z.object({ 8 | NODE_ENV: z 9 | .enum(["development", "test", "production"]) 10 | .default("development"), 11 | USE_BUILT_IN_LLAMA_SERVER: z 12 | .literal("true") 13 | .or(z.literal("false")) 14 | .or(z.boolean()), 15 | LLAMA_SERVER_HOST: z.string().min(1).or(z.literal("auto")), 16 | LLAMA_SERVER_PORT: z.number().int().positive().or(z.literal("auto")), 17 | LLAMA_MODEL_PATH: z.string().min(1), 18 | LLAMA_TCP_BIN: z.string().min(1).or(z.literal("auto")), 19 | PORT: z.string(), 20 | HOST: z.string(), 21 | }); 22 | 23 | /** 24 | * Specify your client-side environment variables schema here. 
This way you can ensure the app isn't 25 | * built with invalid env vars. To expose them to the client, prefix them with `NEXT_PUBLIC_`. 26 | */ 27 | const client = z.object({ 28 | // NEXT_PUBLIC_CLIENTVAR: z.string().min(1), 29 | }); 30 | 31 | /** 32 | * You can't destruct `process.env` as a regular object in the Next.js edge runtimes (e.g. 33 | * middlewares) or client-side so we need to destruct manually. 34 | * 35 | * @type {Record | keyof z.infer, string | number | undefined>} 36 | */ 37 | const processEnv = { 38 | NODE_ENV: process.env.NODE_ENV, 39 | USE_BUILT_IN_LLAMA_SERVER: process.env.USE_BUILT_IN_LLAMA_SERVER, 40 | LLAMA_SERVER_HOST: process.env.LLAMA_SERVER_HOST, 41 | LLAMA_SERVER_PORT: isNaN(parseInt(process.env.LLAMA_SERVER_PORT ?? "")) 42 | ? process.env.LLAMA_SERVER_PORT 43 | : parseInt(process.env.LLAMA_SERVER_PORT ?? ""), 44 | LLAMA_MODEL_PATH: process.env.LLAMA_MODEL_PATH, 45 | LLAMA_TCP_BIN: process.env.LLAMA_TCP_BIN, 46 | PORT: process.env.PORT ?? "3000", 47 | HOST: process.env.HOST ?? "localhost", 48 | }; 49 | 50 | // Don't touch the part below 51 | // -------------------------- 52 | 53 | const merged = server.merge(client); 54 | 55 | /** @typedef {z.input} MergedInput */ 56 | /** @typedef {z.infer} MergedOutput */ 57 | /** @typedef {z.SafeParseReturnType} MergedSafeParseReturn */ 58 | 59 | let env = /** @type {MergedOutput} */ (process.env); 60 | 61 | if (!!process.env.SKIP_ENV_VALIDATION == false) { 62 | const isServer = typeof window === "undefined"; 63 | 64 | const parsed = /** @type {MergedSafeParseReturn} */ ( 65 | isServer 66 | ? 
merged.safeParse(processEnv) // on server we can validate all env vars 67 | : client.safeParse(processEnv) // on client we can only validate the ones that are exposed 68 | ); 69 | 70 | if (parsed.success === false) { 71 | console.error( 72 | "❌ Invalid environment variables:", 73 | parsed.error.flatten().fieldErrors 74 | ); 75 | throw new Error("Invalid environment variables"); 76 | } 77 | 78 | env = new Proxy(parsed.data, { 79 | get(target, prop) { 80 | if (typeof prop !== "string") return undefined; 81 | // Throw a descriptive error if a server-side env var is accessed on the client 82 | // Otherwise it would just be returning `undefined` and be annoying to debug 83 | if (!isServer && !prop.startsWith("NEXT_PUBLIC_")) 84 | throw new Error( 85 | process.env.NODE_ENV === "production" 86 | ? "❌ Attempted to access a server-side environment variable on the client" 87 | : `❌ Attempted to access server-side environment variable '${prop}' on the client` 88 | ); 89 | return target[/** @type {keyof typeof target} */ (prop)]; 90 | }, 91 | }); 92 | } 93 | 94 | export { env }; 95 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🦙 LLaMA Playground 🛝 2 | 3 | A simple Open-AI inspired interface that uses [llama.cpp#tcp_server](https://github.com/ggerganov/llama.cpp/tree/tcp_server) in the background. 4 | 5 | ![demo](./public/demo.gif) 6 | 7 | ## Difference vs. other interfaces 8 | 9 | Other interfaces use the llama.cpp cli command to run the model. This is not ideal since it requires to spawn a new process for each request. This is not only slow but also requires to load the model each time. This interface uses the llama.cpp tcp_server to run the model in the background. This allows to run multiple requests in parallel and also to cache the model in memory. 

## Features

- Simple-to-use UI
- Able to handle multiple requests in parallel quickly
- Controls to change the model parameters on the fly
- Does not require rebooting, changes are applied instantly
- Save and load templates to save your work
- Templates are saved in the browser's local storage and are not sent to the server

## About

Built on top of a modified [T3-stack](https://github.com/t3-oss/create-t3-app) application.
Fastify is used instead of the regular Next.js server since websocket support is needed.
[Mantine](https://mantine.dev/) is used for the UI.
[tRPC](https://trpc.io/) is used for an end-to-end type-safe API.

The Fastify server starts a tcp_server from llama.cpp in the background.
Upon each request, the server establishes a new TCP connection to the tcp_server and sends the request.
Output is then forwarded to the client via websockets.

## Notice

This is not meant to be used in production. There is no rate-limiting, no authentication, etc. It is just a simple interface to play with the models.

## Usage

### Getting the model

This repository will not include the model weights as these are the property of Meta. Do not share the weights in this repository.

Currently, the application will not convert and quantize the model for you. You will need to do this yourself. This means you will need the llama.cpp build dependencies.
42 | 43 | - For ubuntu: `build-essentail make python3` 44 | - For arch: `base-devel make python3` 45 | 46 | ```bash 47 | # build this repo 48 | git clone https://github.com/ggerganov/llama.cpp 49 | cd llama.cpp 50 | make 51 | 52 | # obtain the original LLaMA model weights and place them in ./models 53 | ls ./models 54 | 65B 30B 13B 7B tokenizer_checklist.chk tokenizer.model 55 | 56 | # install Python dependencies 57 | python3 -m pip install torch numpy sentencepiece 58 | 59 | # convert the 7B model to ggml FP16 format 60 | python3 convert-pth-to-ggml.py models/7B/ 1 61 | 62 | # quantize the model to 4-bits 63 | python3 quantize.py 7B 64 | ``` 65 | 66 | ^ (source [llama.cpp/README.md](https://github.com/ggerganov/llama.cpp/)) 67 | 68 | Then you can start the server using one of the below methods: 69 | 70 | ### With Docker 71 | 72 | ```bash 73 | # Clone the repository 74 | git clone --depth 1 https://github.com/ItzDerock/llama-playground . 75 | 76 | # Edit the docker-compose.yml file to point to the correct model 77 | vim docker-compose.yml 78 | 79 | # Start the server 80 | docker-compose up -d 81 | ``` 82 | 83 | ### Without Docker 84 | 85 | ```bash 86 | # Clone the repository 87 | git clone --depth 1 https://github.com/ItzDerock/llama-playground . 88 | 89 | # Install dependencies 90 | pnpm install # you will need pnpm 91 | 92 | # Edit the .env file to point to the correct model 93 | vim .env 94 | 95 | # Build the server 96 | pnpm build 97 | 98 | # Start the server 99 | pnpm start 100 | ``` 101 | 102 | ## Development 103 | 104 | Run `pnpm run dev` to start the development server. 
105 | 106 | ## License 107 | 108 | [MIT](./LICENSE) 109 | -------------------------------------------------------------------------------- /src/server/api/routers/llama.ts: -------------------------------------------------------------------------------- 1 | import { observable } from "@trpc/server/observable"; 2 | import { createTRPCRouter, publicProcedure } from "~/server/api/trpc"; 3 | import { getLLaMAClient } from "~/server/llama"; 4 | import { randomUUID } from "crypto"; 5 | import { generating, events, clients } from "~/server/api/shared"; 6 | import { z } from "zod"; 7 | import { WSMessageType, WSMessage } from "../types"; 8 | 9 | export const llamaRouter = createTRPCRouter({ 10 | status: publicProcedure.query(() => { 11 | const client = getLLaMAClient(); 12 | 13 | return { 14 | status: client.state, 15 | }; 16 | }), 17 | 18 | subscription: publicProcedure.subscription(() => { 19 | return observable((emit) => { 20 | // each client gets a unique UUID so we can keep track of them 21 | const uuid = randomUUID(); 22 | 23 | // send the client their UUID 24 | emit.next({ 25 | type: WSMessageType.IDENTITY, 26 | data: uuid, 27 | }); 28 | 29 | // register this client 30 | clients.add(uuid); 31 | 32 | // create a callback function 33 | function completion(id: string, completion: string) { 34 | if (id === uuid) { 35 | emit.next({ 36 | type: WSMessageType.COMPLETION, 37 | data: completion, 38 | }); 39 | } 40 | } 41 | 42 | // done callback 43 | function done(id: string) { 44 | if (id === uuid) { 45 | emit.next({ 46 | type: WSMessageType.REQUEST_COMPLETE, 47 | data: "", 48 | }); 49 | } 50 | } 51 | 52 | // register callbacks 53 | events.incrementMaxListeners(2); 54 | events.on("generate", completion); 55 | events.on("generate:done", done); 56 | 57 | // handle disconnects 58 | return () => { 59 | // remove the client from the clients set 60 | clients.delete(uuid); 61 | 62 | // remove the callbacks 63 | events.off("generate", completion); 64 | events.off("generate:done", 
done); 65 | 66 | // decrement the max listeners 67 | events.decrementMaxListeners(2); 68 | 69 | // remove the client from the generating if they are generating 70 | if (generating.has(uuid)) { 71 | generating.delete(uuid); 72 | events.emit("generate:cancel", uuid); 73 | } 74 | }; 75 | }); 76 | }), 77 | 78 | startGeneration: publicProcedure 79 | .input( 80 | z.object({ 81 | prompt: z.string(), 82 | options: z.object({ 83 | "--seed": z.string().optional(), 84 | "--threads": z.number().optional(), 85 | "--n_predict": z.number().optional(), 86 | "--top_k": z.number().optional(), 87 | "--top_p": z.number().optional(), 88 | "--repeat_last_n": z.number().optional(), 89 | "--repeat_penalty": z.number().optional(), 90 | "--ctx_size": z.number().optional(), 91 | "--ignore-eos": z.boolean().optional(), 92 | "--memory_f16": z.boolean().optional(), 93 | "--temp": z.number().optional(), 94 | "--n_parts": z.number().optional(), 95 | "--batch_size": z.number().optional(), 96 | "--perplexity": z.number().optional(), 97 | }), 98 | uuid: z.string(), 99 | }) 100 | ) 101 | .mutation(async ({ input }) => { 102 | // validate uuid 103 | if (!clients.has(input.uuid)) { 104 | throw new Error("Invalid UUID"); 105 | } 106 | 107 | // check if the client is already generating 108 | if (generating.has(input.uuid)) { 109 | throw new Error("Client is already generating"); 110 | } 111 | 112 | // set the client's generating state to true 113 | generating.add(input.uuid); 114 | 115 | // start generation 116 | const client = getLLaMAClient(); 117 | const { stream, cancel } = await client.complete( 118 | input.prompt, 119 | input.options 120 | ); 121 | 122 | // forward the generated tokens to the client 123 | stream.on("data", (data) => { 124 | events.emit("generate", input.uuid, data.toString()); 125 | }); 126 | 127 | // handle cancel 128 | function cancelRequest(id: string) { 129 | if (id === input.uuid) { 130 | cancel(); 131 | } 132 | } 133 | 134 | // increment the max listeners 135 | 
events.incrementMaxListeners(1); 136 | events.on("generate:cancel", cancelRequest); 137 | 138 | // when the client is done generating, set the client's generating state to false 139 | stream.on("end", () => { 140 | generating.delete(input.uuid); 141 | 142 | // send the client a message saying that generation is complete 143 | events.emit("generate:done", input.uuid); 144 | 145 | // clean up 146 | events.off("generate:cancel", cancelRequest); 147 | events.decrementMaxListeners(1); 148 | }); 149 | }), 150 | 151 | cancelGeneration: publicProcedure 152 | .input( 153 | z.object({ 154 | uuid: z.string(), 155 | }) 156 | ) 157 | .mutation(async ({ input }) => { 158 | // validate uuid 159 | if (!clients.has(input.uuid)) { 160 | throw new Error("Invalid UUID"); 161 | } 162 | 163 | // check if the client is already generating 164 | if (!generating.has(input.uuid)) { 165 | throw new Error("Client is not generating"); 166 | } 167 | 168 | // cancel generation 169 | events.emit("generate:cancel", input.uuid); 170 | }), 171 | }); 172 | -------------------------------------------------------------------------------- /src/partials/TemplateSelect.tsx: -------------------------------------------------------------------------------- 1 | import { Button, Flex, Modal, Select, TextInput } from "@mantine/core"; 2 | import { useDisclosure } from "@mantine/hooks"; 3 | import { useForm } from "@mantine/form"; 4 | import { useEffect } from "react"; 5 | import { useRecoilState, useRecoilValue } from "recoil"; 6 | import { 7 | generatingState, 8 | maxPredictionTokensState, 9 | promptState, 10 | promptTemplateState, 11 | repeatLastNState, 12 | repeatPenaltyState, 13 | tempState, 14 | topKState, 15 | topPState, 16 | } from "~/recoil/states"; 17 | import { 18 | allTemplatesState, 19 | defaultTemplates, 20 | TemplateData, 21 | } from "~/recoil/templates"; 22 | 23 | function setIfExists(setFunction: (arg0: T) => void, value?: T) { 24 | if (value !== undefined) setFunction(value); 25 | } 26 | 27 | 
export default function TemplateSelect({ fullWidth }: { fullWidth?: boolean }) { 28 | const [templates, setTemplates] = useRecoilState(allTemplatesState); 29 | 30 | // the model options 31 | const [maxTokens, setMaxTokens] = useRecoilState(maxPredictionTokensState); 32 | const [topK, setTopK] = useRecoilState(topKState); 33 | const [topP, setTopP] = useRecoilState(topPState); 34 | const [repeatPenalty, setRepeatPenalty] = useRecoilState(repeatPenaltyState); 35 | const [repeatLastN, setRepeatLastN] = useRecoilState(repeatLastNState); 36 | const [temp, setTemp] = useRecoilState(tempState); 37 | const [prompt, _setPrompt] = useRecoilState(promptState); 38 | const [_, setTemplatePrompt] = useRecoilState(promptTemplateState); 39 | 40 | // generation state 41 | const generating = useRecoilValue(generatingState); 42 | 43 | // modal state 44 | const [opened, { open, close }] = useDisclosure(false); 45 | const form = useForm({ 46 | initialValues: { 47 | name: "", 48 | }, 49 | 50 | validate: { 51 | name: (value) => { 52 | if (!value) return "Name is required"; 53 | if (defaultTemplates.some((template) => template.name === value)) 54 | return "Cannot override default templates."; 55 | }, 56 | }, 57 | }); 58 | 59 | // read templates from local storage 60 | useEffect(() => { 61 | const storedTemplates = localStorage.getItem("templates"); 62 | if (storedTemplates) { 63 | try { 64 | setTemplates((defaultTemplates) => [ 65 | ...defaultTemplates, 66 | ...JSON.parse(storedTemplates ?? 
"[]"), 67 | ]); 68 | } catch (error) { 69 | console.error(`Error parsing templates:`, error); 70 | } 71 | } 72 | }, []); 73 | 74 | // save templates to local storage 75 | useEffect(() => { 76 | // only save if they are not the default templates 77 | const userCreatedTemplates = templates.filter( 78 | (template) => !defaultTemplates.some((tmp) => tmp.name === template.name) 79 | ); 80 | 81 | localStorage.setItem("templates", JSON.stringify(userCreatedTemplates)); 82 | }, [templates]); 83 | 84 | // save 85 | function saveTemplate(name: string) { 86 | const templateData = { 87 | name, 88 | prompt, 89 | options: { 90 | maximum_tokens: maxTokens, 91 | repeat_last_n: repeatLastN, 92 | temperature: temp, 93 | repeat_penalty: repeatPenalty, 94 | top_p: topP, 95 | top_k: topK, 96 | }, 97 | } satisfies TemplateData; 98 | 99 | setTemplates((templates) => [...templates, templateData]); 100 | close(); 101 | } 102 | 103 | // load 104 | function loadTemplate(name: string) { 105 | const template = templates.find((template) => template.name === name); 106 | if (!template) return; 107 | 108 | setTemplatePrompt(template.prompt); 109 | setIfExists(setMaxTokens, template.options?.maximum_tokens); 110 | setIfExists(setTopK, template.options?.top_k); 111 | setIfExists(setTopP, template.options?.top_p); 112 | setIfExists(setRepeatPenalty, template.options?.repeat_penalty); 113 | setIfExists(setRepeatLastN, template.options?.repeat_last_n); 114 | setIfExists(setTemp, template.options?.temperature); 115 | } 116 | 117 | return ( 118 | <> 119 | {/* modal */} 120 | 121 |
saveTemplate(data.name))}> 122 | 128 | 131 | 132 |
133 | 134 | 143 | 46 | 47 | {/* Maximum Length */} 48 | 54 | setMaxTokens(value === "" ? 0 : value)} 57 | min={0} 58 | precision={0} 59 | hideControls 60 | label="Maximum Tokens" 61 | required 62 | /> 63 | 64 | 65 | {/* Temperature */} 66 | 72 | 73 | 74 | 75 | 76 | Temperature 77 | 78 | 79 | 80 | setTemp(value === "" ? 0 : value)} 83 | min={0} 84 | max={1} 85 | precision={2} 86 | hideControls 87 | size="xs" 88 | required 89 | /> 90 | 91 | 92 | setTemp(value)} 95 | min={0} 96 | max={1} 97 | step={0.01} 98 | precision={2} 99 | mt="sm" 100 | label={null} 101 | /> 102 | 103 | 104 | 105 | {/* Top P */} 106 | 112 | 113 | 114 | 115 | 116 | Top P 117 | 118 | 119 | 120 | setTopP(value === "" ? 0 : value)} 123 | min={0} 124 | max={1} 125 | precision={2} 126 | hideControls 127 | required 128 | size="xs" 129 | /> 130 | 131 | 132 | setTopP(value)} 135 | min={0} 136 | max={1} 137 | step={0.01} 138 | precision={2} 139 | mt="sm" 140 | label={null} 141 | /> 142 | 143 | 144 | 145 | {/* Top K */} 146 | 152 | setTopK(value === "" ? 0 : value)} 155 | min={0} 156 | precision={0} 157 | hideControls 158 | required 159 | label="Top K" 160 | /> 161 | 162 | 163 | {/* Repeat Penalty */} 164 | 170 | setRepeatPenalty(value === "" ? 0 : value)} 173 | min={0} 174 | precision={2} 175 | hideControls 176 | required 177 | label="Repeat Penalty" 178 | /> 179 | 180 | 181 | {/* Repeat last n */} 182 | 188 | setRepeatLastN(value === "" ? 
0 : value)} 191 | min={0} 192 | precision={1} 193 | hideControls 194 | required 195 | label="Repeat Last N" 196 | /> 197 | 198 | 199 | {/* Model status */} 200 | 201 | {status.data?.status} 202 | 203 | 204 | ); 205 | } 206 | -------------------------------------------------------------------------------- /src/server/llama/adapters/llamacpp.ts: -------------------------------------------------------------------------------- 1 | import { ChildProcess, spawn } from "child_process"; 2 | import { Socket } from "net"; 3 | import { Readable } from "stream"; 4 | import { findRandomOpenPort } from "~/utils/utils"; 5 | 6 | type LLaMATCPClientOptions = ( 7 | | { 8 | modelPath: string; 9 | binPath: string; 10 | port: number | "auto"; 11 | } 12 | | { 13 | port: number; 14 | host: string; 15 | } 16 | ) & { 17 | debug?: boolean | ((...args: any[]) => void); 18 | }; 19 | 20 | export default class LLaMATCPClient { 21 | // configuration options 22 | private options: LLaMATCPClientOptions; 23 | 24 | // the llama-tcp server process 25 | private _process?: ChildProcess; 26 | 27 | // state 28 | public state: "loading" | "ready" | "error" = "loading"; 29 | private _loadingPromise?: Promise; 30 | 31 | // logging 32 | private _log: (...args: any[]) => void; 33 | 34 | /** 35 | * Create a new LLaMATCPClient 36 | * @param options - options for the client 37 | */ 38 | constructor(options: LLaMATCPClientOptions) { 39 | // set logging 40 | if (options.debug === true) { 41 | this._log = (...args: any[]) => console.log("[llama-tcp] ", ...args); 42 | } else if (typeof options.debug === "function") { 43 | this._log = options.debug; 44 | } else { 45 | this._log = () => {}; 46 | } 47 | 48 | // set options 49 | this.options = options; 50 | 51 | // if binPath is set and equals "auto", set it to the default path 52 | if ("binPath" in this.options && this.options.binPath === "auto") { 53 | this.options.binPath = "./bin/main"; 54 | } 55 | } 56 | 57 | /** 58 | * Start the LLaMATCPClient 59 | * Creates a 
server process if necessary 60 | */ 61 | async start() { 62 | // if we're already loading, return the loading promise 63 | if (this._loadingPromise) return this._loadingPromise; 64 | 65 | // create a new promise for loading 66 | let resolve: () => void = () => {}; 67 | let reject: (error: Error) => void = () => {}; 68 | this._loadingPromise = new Promise(async (res, rej) => { 69 | resolve = res; 70 | reject = rej; 71 | }); 72 | 73 | // only start a server instance if binPath is set 74 | if ("binPath" in this.options) { 75 | this._log("starting server"); 76 | 77 | // find a random port if port is set to "auto" 78 | if (this.options.port === "auto") { 79 | this._log("start(): port is set to auto, finding random open port"); 80 | this.options.port = await findRandomOpenPort(); 81 | this._log("start(): found random open port: ", this.options.port); 82 | } 83 | 84 | // start the server process 85 | this._log("start(): starting server process"); 86 | this._process = spawn( 87 | this.options.binPath!, 88 | ["-l", this.options.port.toString(), "-m", this.options.modelPath], 89 | { 90 | stdio: "inherit", 91 | } 92 | ); 93 | 94 | // handle errors 95 | this._process.on("error", (error) => { 96 | console.error(error); 97 | this.state = "error"; 98 | reject(error); 99 | }); 100 | 101 | // wait for the server to start 102 | this._log("start(): waiting for server to start"); 103 | let iterations = 0; 104 | 105 | while (iterations < 50) { 106 | try { 107 | await this._createConnection(); 108 | this.state = "ready"; 109 | this._log("start(): server started successfully"); 110 | resolve(); 111 | return; 112 | } catch (error) { 113 | this._log( 114 | "start(): error: ", 115 | (error as any).message, 116 | ", retrying in 1s (", 117 | iterations, 118 | "/50)" 119 | ); 120 | iterations++; 121 | await new Promise((res) => setTimeout(res, 1000)); 122 | } 123 | } 124 | 125 | // if we get here, the server failed to start 126 | this.state = "error"; 127 | reject(new Error("Failed to start 
server!")); 128 | } else { 129 | // if we don't have a binPath, we're just connecting to a remote server 130 | this.state = "ready"; 131 | resolve(); 132 | } 133 | } 134 | 135 | /** 136 | * Creates a new connection and asks the server to complete the given text 137 | * @param text - the text to complete 138 | * @param options - options for the completion 139 | * @returns a readable stream of the completion 140 | */ 141 | async complete( 142 | text: string, 143 | options: { 144 | [key: string]: string | number | boolean; 145 | } 146 | ) { 147 | // make sure we're ready 148 | if (this.state !== "ready") { 149 | await this._loadingPromise; 150 | } 151 | 152 | // create a new connection 153 | const client = await this._createConnection(); 154 | 155 | // build the tcp request 156 | let request = ""; 157 | let args = 0; 158 | 159 | // the text goes in as a -p argument 160 | options["-p"] = text; 161 | 162 | // add the options 163 | for (const key in options) { 164 | // if the value is boolean and false, skip it 165 | if (typeof options[key] === "boolean" && !options[key]) continue; 166 | 167 | request += key + "\x00"; 168 | args++; 169 | 170 | // if value is number, convert it to string 171 | if (typeof options[key] === "number") { 172 | options[key] = options[key]!.toString(); 173 | } 174 | 175 | // if value is a non-empty string, add it 176 | if (typeof options[key] === "string" && options[key] !== "") { 177 | request += options[key] + "\x00"; 178 | args++; 179 | } 180 | } 181 | 182 | // append # of args to the start of the request 183 | request = args.toString() + "\n" + request; 184 | 185 | // log the built tcp packet 186 | this._log( 187 | `complete(): sending tcp packet: `, 188 | request.replace("\x00", "\\x00") 189 | ); 190 | 191 | // create a readable stream from the connection 192 | const stream = new Readable({ 193 | read() {}, 194 | }); 195 | 196 | // send the request 197 | client.write(request); 198 | 199 | // variables to keep track of where we are in the 
response 200 | let gotSamplingParameters = false; 201 | let promptIndex = 0; 202 | 203 | // handle data 204 | client.on("data", (data) => { 205 | // convert the data to a string 206 | let parsedData = data.toString(); 207 | this._log(`complete(): received tcp packet: `, parsedData); 208 | 209 | // wait for the packet with samping parameters: 210 | // this marks the first line of response 211 | if (!gotSamplingParameters) { 212 | if (parsedData.includes("sampling parameters:")) { 213 | gotSamplingParameters = true; 214 | 215 | // get the last line of data -- this will be part of the prompt. 216 | let after = parsedData.split("\n").pop()!; 217 | 218 | // remove the trailing space 219 | after = after.substring(1); 220 | 221 | // if the length is greater than the prompt length, then we need to substring 222 | if (after.length > text.length) { 223 | parsedData = after.substring(text.length - 1); 224 | console.log(parsedData, after); 225 | } else { 226 | // otherwise, ignore this chunk and update promptIndex 227 | promptIndex = after.length; 228 | return; 229 | } 230 | } 231 | 232 | if (parsedData === "") return; 233 | } 234 | 235 | // wait until the prompt is finished being echoed back 236 | if (promptIndex < text.length) { 237 | let requiredPromptChars = text.length - promptIndex; 238 | 239 | // if the data includes same amt or less chars, discard 240 | if (parsedData.length <= requiredPromptChars) { 241 | promptIndex += parsedData.length; 242 | return; 243 | } 244 | 245 | // otherwise, we substring what is required 246 | parsedData = parsedData.substring(requiredPromptChars); 247 | promptIndex += requiredPromptChars; 248 | } 249 | 250 | // push the data to the stream 251 | stream.push(parsedData); 252 | }); 253 | 254 | // when the connection closes, end the stream 255 | client.on("close", () => { 256 | stream.push(null); 257 | }); 258 | 259 | // handle errors 260 | client.on("error", (error) => { 261 | stream.emit("error", error); 262 | }); 263 | 264 | // return the 
stream 265 | return { 266 | stream, 267 | cancel: () => { 268 | client.destroy(); 269 | }, 270 | }; 271 | } 272 | 273 | private _createConnection() { 274 | return new Promise((resolve, reject) => { 275 | // create a tcp client 276 | const client = new Socket(); 277 | 278 | // make sure a port has been resolved 279 | if (this.options.port === "auto") { 280 | reject( 281 | new Error("_createConnection() called before port was resolved!") 282 | ); 283 | return; 284 | } 285 | 286 | // default host is localhost 287 | const host = "host" in this.options ? this.options.host : "localhost"; 288 | 289 | // connect it to the server 290 | client.connect(this.options.port, host, () => { 291 | this._log("_createConnection(): connected to server"); 292 | resolve(client); 293 | }); 294 | 295 | // handle errors 296 | client.on("error", (error) => { 297 | this._log("_createConnection(): error: ", error.message); 298 | reject(error); 299 | }); 300 | }); 301 | } 302 | } 303 | --------------------------------------------------------------------------------