├── .eslintrc.json ├── .github └── workflows │ ├── js.yaml │ └── lint.yaml ├── .gitignore ├── .npmrc ├── .pre-commit-config.yaml ├── .prettierrc ├── LICENSE ├── README.md ├── apis ├── cloudflare │ ├── .gitignore │ ├── README.md │ ├── package.json │ ├── src │ │ ├── env.ts │ │ ├── index.ts │ │ ├── lib.ts │ │ ├── metric-aggregator.ts │ │ ├── proxy.ts │ │ ├── realtime-logger.ts │ │ ├── realtime.ts │ │ └── tracing.ts │ ├── tsconfig.json │ ├── tsup.config.ts │ ├── worker-configuration.d.ts │ └── wrangler-template.toml ├── node │ ├── README.md │ ├── package.json │ ├── src │ │ ├── anthropic.ts │ │ ├── cache.ts │ │ ├── env.ts │ │ ├── index.js │ │ ├── local.ts │ │ ├── login.ts │ │ └── node-proxy.ts │ └── tsconfig.json └── vercel │ ├── .eslintrc.json │ ├── .gitignore │ ├── .npmrc │ ├── README.md │ ├── app │ ├── 404.html │ └── layout.tsx │ ├── components │ └── headers.tsx │ ├── next-env.d.ts │ ├── next.config.js │ ├── package.json │ ├── pages │ ├── _app.tsx │ ├── api │ │ ├── ping.ts │ │ └── v1 │ │ │ └── [...slug].ts │ └── index.tsx │ ├── postcss.config.js │ ├── public │ └── favicon.ico │ ├── tailwind.config.js │ └── tsconfig.json ├── package.json ├── packages └── proxy │ ├── .eslintrc.cjs │ ├── .gitignore │ ├── edge │ ├── deps.test.ts │ ├── exporter.ts │ └── index.ts │ ├── package.json │ ├── schema │ ├── audio.ts │ ├── deps.test.ts │ ├── index.test.ts │ ├── index.ts │ ├── model_list.json │ ├── models.test.ts │ ├── models.ts │ ├── openai-realtime.ts │ └── secrets.ts │ ├── scripts │ └── sync_models.ts │ ├── src │ ├── PrometheusSerializer.ts │ ├── constants.ts │ ├── deps.test.ts │ ├── index.ts │ ├── metrics.ts │ ├── providers │ │ ├── anthropic.test.ts │ │ ├── anthropic.ts │ │ ├── azure.ts │ │ ├── bedrock.ts │ │ ├── databricks.ts │ │ ├── google.test.ts │ │ ├── google.ts │ │ ├── openai.test.ts │ │ ├── openai.ts │ │ └── util.ts │ ├── proxy.ts │ ├── util.test.ts │ └── util.ts │ ├── tsconfig.json │ ├── tsup.config.ts │ ├── turbo.json │ ├── types │ ├── anthropic.ts │ ├── index.ts │ └── 
openai.ts │ ├── utils │ ├── audioEncoder.ts │ ├── deps.test.ts │ ├── encrypt.ts │ ├── index.ts │ ├── openai.ts │ ├── tempCredentials.test.ts │ ├── tempCredentials.ts │ └── tests.ts │ └── vitest.config.js ├── pnpm-lock.yaml ├── pnpm-workspace.yaml ├── turbo.json └── vitest.config.js /.eslintrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "root": true, 3 | "extends": ["next/core-web-vitals", "turbo"] 4 | } 5 | -------------------------------------------------------------------------------- /.github/workflows/js.yaml: -------------------------------------------------------------------------------- 1 | name: js 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | 12 | strategy: 13 | matrix: 14 | node-version: 15 | - 20 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | - uses: actions/setup-node@v4 20 | with: 21 | node-version: ${{ matrix.node-version }} 22 | registry-url: "https://registry.npmjs.org" 23 | - uses: pnpm/action-setup@v4 24 | - run: | 25 | pnpm install 26 | pnpm run test 27 | pnpm run build 28 | env: 29 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} 30 | ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} 31 | GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} 32 | -------------------------------------------------------------------------------- /.github/workflows/lint.yaml: -------------------------------------------------------------------------------- 1 | # Source: https://github.com/marketplace/actions/pre-commit 2 | name: lint 3 | 4 | on: 5 | pull_request: 6 | push: 7 | branches: [main] 8 | 9 | jobs: 10 | lint: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v3 14 | - uses: actions/setup-python@v3 15 | - uses: pre-commit/action@v3.0.0 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 
| # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 2 | 3 | # Dependencies 4 | node_modules 5 | .pnp 6 | .pnp.js 7 | 8 | # Testing 9 | /coverage 10 | 11 | # Next.js 12 | .next 13 | out 14 | 15 | # Production 16 | build 17 | dist 18 | 19 | # Misc 20 | .DS_Store 21 | *.pem 22 | tsconfig.tsbuildinfo 23 | 24 | # Debug 25 | npm-debug.log* 26 | yarn-debug.log* 27 | yarn-error.log* 28 | 29 | # Local ENV files 30 | .env.local 31 | .env.development.local 32 | .env.test.local 33 | .env.production.local 34 | 35 | # Vercel 36 | .vercel 37 | 38 | # Turborepo 39 | .turbo 40 | 41 | *.swp 42 | -------------------------------------------------------------------------------- /.npmrc: -------------------------------------------------------------------------------- 1 | enable-pre-post-scripts=true 2 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: "https://github.com/pre-commit/pre-commit-hooks" 3 | rev: v4.4.0 4 | hooks: 5 | - id: end-of-file-fixer 6 | - id: trailing-whitespace 7 | - repo: https://github.com/codespell-project/codespell 8 | rev: v2.2.5 9 | hooks: 10 | - id: codespell 11 | exclude: > 12 | (?x)^( 13 | .*\.(json|prisma|svg)| 14 | .*pnpm-lock.yaml 15 | )$ 16 | args: ["-L rouge,coo,couldn,unsecure,afterall"] 17 | - repo: https://github.com/rbubley/mirrors-prettier 18 | rev: v3.3.2 19 | hooks: 20 | - id: prettier 21 | exclude: ^(extension/|.*\.json|.*pnpm-lock.yaml$) 22 | -------------------------------------------------------------------------------- /.prettierrc: -------------------------------------------------------------------------------- 1 | { 2 | "singleQuote": false 3 | } 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 
2023-2024 Braintrust Data, Inc. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining 4 | a copy of this software and associated documentation files (the 5 | "Software"), to deal in the Software without restriction, including 6 | without limitation the rights to use, copy, modify, merge, publish, 7 | distribute, sublicense, and/or sell copies of the Software, and to 8 | permit persons to whom the Software is furnished to do so, subject to 9 | the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 17 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Braintrust AI Proxy 2 | 3 | The Braintrust AI proxy offers a unified way to access the world's leading AI models through a single API, including 4 | models from [OpenAI](https://platform.openai.com/docs/models), [Anthropic](https://docs.anthropic.com/claude/reference/getting-started-with-the-api), [LLaMa 2](https://ai.meta.com/llama/), 5 | [Mistral](https://mistral.ai/), and more. The benefits of using the proxy include: 6 | 7 | - **Code Simplification**: Use a consistent API across different AI providers. 8 | - **Cost Reduction**: The proxy automatically caches results, reusing them when possible. 
9 | - **Enhanced Observability**: Log requests automatically for better tracking and debugging. \[Coming soon\] 10 | 11 | See the full list of supported models [here](https://www.braintrust.dev/docs/guides/proxy#supported-models). 12 | To read more about why we launched the AI proxy, check out our [announcement blog post](https://braintrust.dev/blog/ai-proxy). 13 | 14 | This repository contains the code for the proxy — both the underlying implementation and wrappers that allow you to 15 | deploy it on [Vercel](https://vercel.com), [Cloudflare](https://developers.cloudflare.com/workers/), 16 | [AWS Lambda](https://aws.amazon.com/lambda/), or an [Express](https://expressjs.com/) server. 17 | 18 | ## Just let me try it! 19 | 20 | You can communicate with the proxy via the standard OpenAI drivers/API, and simply set the base url to 21 | `https://api.braintrust.dev/v1/proxy`. Try running the following script in your favorite language, twice. 22 | 23 | ### TypeScript 24 | 25 | ```javascript copy 26 | import { OpenAI } from "openai"; 27 | const client = new OpenAI({ 28 | baseURL: "https://api.braintrust.dev/v1/proxy", 29 | apiKey: process.env.OPENAI_API_KEY, // Can use Braintrust, Anthropic, etc. keys here too 30 | }); 31 | 32 | async function main() { 33 | const start = performance.now(); 34 | const response = await client.chat.completions.create({ 35 | model: "gpt-3.5-turbo", // Can use claude-2, llama-2-13b-chat here too 36 | messages: [{ role: "user", content: "What is a proxy?" }], 37 | seed: 1, // A seed activates the proxy's cache 38 | }); 39 | console.log(response.choices[0].message.content); 40 | console.log(`Took ${(performance.now() - start) / 1000}s`); 41 | } 42 | 43 | main(); 44 | ``` 45 | 46 | ### Python 47 | 48 | ```python copy 49 | from openai import OpenAI 50 | import os 51 | import time 52 | 53 | client = OpenAI( 54 | base_url="https://api.braintrust.dev/v1/proxy", 55 | api_key=os.environ["OPENAI_API_KEY"], # Can use Braintrust, Anthropic, etc.
keys here too 56 | ) 57 | 58 | start = time.time() 59 | response = client.chat.completions.create( 60 | model="gpt-3.5-turbo", # Can use claude-2, llama-2-13b-chat here too 61 | messages=[{"role": "user", "content": "What is a proxy?"}], 62 | seed=1, # A seed activates the proxy's cache 63 | ) 64 | print(response.choices[0].message.content) 65 | print(f"Took {time.time()-start}s") 66 | ``` 67 | 68 | ### cURL 69 | 70 | ```bash copy 71 | time curl -i https://api.braintrust.dev/v1/proxy/chat/completions \ 72 | -H "Content-Type: application/json" \ 73 | -d '{ 74 | "model": "gpt-3.5-turbo", 75 | "messages": [ 76 | { 77 | "role": "user", 78 | "content": "What is a proxy?" 79 | } 80 | ], 81 | "seed": 1 82 | }' \ 83 | -H "Authorization: Bearer $OPENAI_API_KEY" # Can use Braintrust, Anthropic, etc. keys here too 84 | ``` 85 | 86 | ## Deploying 87 | 88 | You can find the full documentation for using the proxy [here](https://www.braintrust.dev/docs/guides/proxy). 89 | The proxy is hosted for you, with end-to-end encryption, at `https://api.braintrust.dev/v1/proxy`. However, you 90 | can also deploy it yourself and customize its behavior. 
91 | 92 | To see docs for how to deploy on various platforms, see the READMEs in the corresponding folders: 93 | 94 | - [Vercel](./apis/vercel) 95 | - [Cloudflare](./apis/cloudflare) 96 | - [AWS Lambda](./apis/node) 97 | - [Express](./apis/node) 98 | 99 | ## Developing 100 | 101 | To build the proxy, install [pnpm](https://pnpm.io/installation) and run: 102 | 103 | ```bash 104 | pnpm install 105 | pnpm build 106 | ``` 107 | -------------------------------------------------------------------------------- /apis/cloudflare/.gitignore: -------------------------------------------------------------------------------- 1 | # Logs 2 | 3 | logs 4 | _.log 5 | npm-debug.log_ 6 | yarn-debug.log* 7 | yarn-error.log* 8 | lerna-debug.log* 9 | .pnpm-debug.log* 10 | 11 | # Diagnostic reports (https://nodejs.org/api/report.html) 12 | 13 | report.[0-9]_.[0-9]_.[0-9]_.[0-9]_.json 14 | 15 | # Runtime data 16 | 17 | pids 18 | _.pid 19 | _.seed 20 | \*.pid.lock 21 | 22 | # Directory for instrumented libs generated by jscoverage/JSCover 23 | 24 | lib-cov 25 | 26 | # Coverage directory used by tools like istanbul 27 | 28 | coverage 29 | \*.lcov 30 | 31 | # nyc test coverage 32 | 33 | .nyc_output 34 | 35 | # Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files) 36 | 37 | .grunt 38 | 39 | # Bower dependency directory (https://bower.io/) 40 | 41 | bower_components 42 | 43 | # node-waf configuration 44 | 45 | .lock-wscript 46 | 47 | # Compiled binary addons (https://nodejs.org/api/addons.html) 48 | 49 | build/Release 50 | 51 | # Dependency directories 52 | 53 | node_modules/ 54 | jspm_packages/ 55 | 56 | # Snowpack dependency directory (https://snowpack.dev/) 57 | 58 | web_modules/ 59 | 60 | # TypeScript cache 61 | 62 | \*.tsbuildinfo 63 | 64 | # Optional npm cache directory 65 | 66 | .npm 67 | 68 | # Optional eslint cache 69 | 70 | .eslintcache 71 | 72 | # Optional stylelint cache 73 | 74 | .stylelintcache 75 | 76 | # Microbundle cache 77 | 78 | .rpt2_cache/ 79 | 
.rts2_cache_cjs/ 80 | .rts2_cache_es/ 81 | .rts2_cache_umd/ 82 | 83 | # Optional REPL history 84 | 85 | .node_repl_history 86 | 87 | # Output of 'npm pack' 88 | 89 | \*.tgz 90 | 91 | # Yarn Integrity file 92 | 93 | .yarn-integrity 94 | 95 | # dotenv environment variable files 96 | 97 | .env 98 | .env.development.local 99 | .env.test.local 100 | .env.production.local 101 | .env.local 102 | 103 | # parcel-bundler cache (https://parceljs.org/) 104 | 105 | .cache 106 | .parcel-cache 107 | 108 | # Next.js build output 109 | 110 | .next 111 | out 112 | 113 | # Nuxt.js build / generate output 114 | 115 | .nuxt 116 | dist 117 | 118 | # Gatsby files 119 | 120 | .cache/ 121 | 122 | # Comment in the public line in if your project uses Gatsby and not Next.js 123 | 124 | # https://nextjs.org/blog/next-9-1#public-directory-support 125 | 126 | # public 127 | 128 | # vuepress build output 129 | 130 | .vuepress/dist 131 | 132 | # vuepress v2.x temp and cache directory 133 | 134 | .temp 135 | .cache 136 | 137 | # Docusaurus cache and generated files 138 | 139 | .docusaurus 140 | 141 | # Serverless directories 142 | 143 | .serverless/ 144 | 145 | # FuseBox cache 146 | 147 | .fusebox/ 148 | 149 | # DynamoDB Local files 150 | 151 | .dynamodb/ 152 | 153 | # TernJS port file 154 | 155 | .tern-port 156 | 157 | # Stores VSCode versions used for testing VSCode extensions 158 | 159 | .vscode-test 160 | 161 | # yarn v2 162 | 163 | .yarn/cache 164 | .yarn/unplugged 165 | .yarn/build-state.yml 166 | .yarn/install-state.gz 167 | .pnp.\* 168 | 169 | # wrangler project 170 | 171 | .wrangler/ 172 | 173 | .dev.vars 174 | wrangler.toml 175 | -------------------------------------------------------------------------------- /apis/cloudflare/README.md: -------------------------------------------------------------------------------- 1 | # Braintrust AI Proxy (Cloudflare) 2 | 3 | This directory contains an implementation of the Braintrust AI Proxy that runs on 4 | [Cloudflare 
Workers](https://workers.cloudflare.com/). Because of their global network, 5 | you get the benefit of low latency and can scale up to millions of users. 6 | 7 | ## Deploying 8 | 9 | You'll need the following prerequisites: 10 | 11 | - A [Cloudflare account](https://www.cloudflare.com/) 12 | - [pnpm](https://pnpm.io/installation) 13 | 14 | By default, the worker uses the local `@braintrust/proxy` package, which you need to build. From the 15 | [repository's root](../..), run: 16 | 17 | ```bash copy 18 | pnpm install 19 | pnpm build 20 | ``` 21 | 22 | Then, return to this directory and set up a KV namespace for the proxy: 23 | 24 | ```bash copy 25 | wrangler kv:namespace create ai_proxy 26 | ``` 27 | 28 | Record the ID of the namespace that you just created. Then, copy `wrangler-template.toml` to 29 | `wrangler.toml` and replace `<YOUR_KV_ID>` with the ID of the namespace. 30 | 31 | Finally, you can run the worker locally with 32 | 33 | ```bash copy 34 | npx wrangler dev 35 | ``` 36 | 37 | or deploy it to Cloudflare with 38 | 39 | ```bash copy 40 | npx wrangler deploy 41 | ``` 42 | 43 | ## Integrating into your own project 44 | 45 | If you'd like to use the proxy in your own project, that's fine too! Simply install the 46 | `@braintrust/proxy` package with your favorite package manager, and follow/customize the 47 | implementation in [`proxy.ts`](./src/proxy.ts).
48 | -------------------------------------------------------------------------------- /apis/cloudflare/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "@braintrust/ai-proxy-wrangler", 3 | "version": "0.0.0", 4 | "private": true, 5 | "main": "./dist/lib.mjs", 6 | "scripts": { 7 | "deploy": "wrangler deploy", 8 | "dev": "wrangler dev --port 8787 --inspector-port 9299", 9 | "start": "wrangler dev", 10 | "watch": "tsup --watch --dts", 11 | "build": "tsup --clean --dts" 12 | }, 13 | "devDependencies": { 14 | "@cloudflare/workers-types": "^4.20241022.0", 15 | "itty-router": "^3.0.12", 16 | "tsup": "^8.4.0", 17 | "typescript": "^5.0.4", 18 | "wrangler": "^3.107.3" 19 | }, 20 | "dependencies": { 21 | "@braintrust/core": "^0.0.85", 22 | "braintrust": "^0.0.197", 23 | "@braintrust/proxy": "workspace:*", 24 | "@openai/realtime-api-beta": "github:openai/openai-realtime-api-beta#cd8a9251dcfb0cba0d7b0501e9ff36c915f5090f", 25 | "@opentelemetry/resources": "^1.18.1", 26 | "@opentelemetry/sdk-metrics": "^1.18.1", 27 | "dotenv": "^16.3.1", 28 | "zod": "3.22.4" 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /apis/cloudflare/src/env.ts: -------------------------------------------------------------------------------- 1 | declare global { 2 | interface Env { 3 | ai_proxy: KVNamespace; 4 | BRAINTRUST_APP_URL: string; 5 | DISABLE_METRICS?: boolean; 6 | PROMETHEUS_SCRAPE_USER?: string; 7 | PROMETHEUS_SCRAPE_PASSWORD?: string; 8 | WHITELISTED_ORIGINS?: string; 9 | } 10 | } 11 | 12 | export function braintrustAppUrl(env: Env) { 13 | return new URL(env.BRAINTRUST_APP_URL || "https://www.braintrust.dev"); 14 | } 15 | -------------------------------------------------------------------------------- /apis/cloudflare/src/index.ts: -------------------------------------------------------------------------------- 1 | import { 2 | proxyV1Prefixes, 3 | handleProxyV1, 4 | 
handlePrometheusScrape, 5 | originWhitelist, 6 | } from "./proxy"; 7 | import { getCorsHeaders } from "@braintrust/proxy/edge"; 8 | export { PrometheusMetricAggregator } from "./metric-aggregator"; 9 | 10 | // The fetch handler is invoked when this worker receives a HTTP(S) request 11 | // and should return a Response (optionally wrapped in a Promise) 12 | // eslint-disable-next-line import/no-anonymous-default-export 13 | export default { 14 | async fetch( 15 | request: Request, 16 | env: Env, 17 | ctx: ExecutionContext, 18 | ): Promise { 19 | const url = new URL(request.url); 20 | console.log("URL =", url.pathname); 21 | if (["/", "/v1/proxy", "/v1/proxy/"].includes(url.pathname)) { 22 | return new Response("Hello World!", { 23 | status: 200, 24 | headers: getCorsHeaders(request, originWhitelist(env)), 25 | }); 26 | } 27 | if (url.pathname === "/metrics") { 28 | return handlePrometheusScrape(request, env, ctx); 29 | } 30 | for (const prefix of proxyV1Prefixes) { 31 | if (url.pathname.startsWith(prefix)) { 32 | return handleProxyV1(request, prefix, env, ctx); 33 | } 34 | } 35 | return new Response("Not found", { 36 | status: 404, 37 | headers: { "Content-Type": "text/plain" }, 38 | }); 39 | }, 40 | }; 41 | -------------------------------------------------------------------------------- /apis/cloudflare/src/lib.ts: -------------------------------------------------------------------------------- 1 | export * from "./proxy"; 2 | export * from "./metric-aggregator"; 3 | -------------------------------------------------------------------------------- /apis/cloudflare/src/metric-aggregator.ts: -------------------------------------------------------------------------------- 1 | import { MetricData, ResourceMetrics } from "@opentelemetry/sdk-metrics"; 2 | import { Resource } from "@opentelemetry/resources"; 3 | import { PrometheusSerializer } from "@braintrust/proxy/src/PrometheusSerializer"; 4 | import { aggregateMetrics, prometheusSerialize } from "@braintrust/proxy"; 5 
| 6 | declare global { 7 | interface Env { 8 | METRICS_AGGREGATOR: DurableObjectNamespace; 9 | // The number of durable objects to use for metrics aggregation. Each shard (times the number 10 | // of other distinct sets of labels) works out to one Prometheus timeseries. Shards allow us to 11 | // essentially aggregate _across_ workers. 12 | METRICS_SHARDS?: number; 13 | // TODO: If a metric doesn't show up for this many seconds, it'll be deleted from the store. We detect 14 | // this at read time. 15 | METRICS_TTL?: number; 16 | } 17 | } 18 | 19 | export class PrometheusMetricAggregator { 20 | state: DurableObjectState; 21 | constructor(state: DurableObjectState, env: Env) { 22 | this.state = state; 23 | } 24 | 25 | async fetch(request: Request): Promise { 26 | if (request.method !== "POST") { 27 | return new Response("Only POST is supported", { status: 405 }); 28 | } 29 | const url = new URL(request.url); 30 | if (url.pathname === "/push") { 31 | return await this.handlePush(request); 32 | } else if (url.pathname === "/metrics") { 33 | return await this.handlePromScrape(request); 34 | } else { 35 | return new Response("Not found", { 36 | status: 404, 37 | }); 38 | } 39 | } 40 | 41 | async handlePush(request: Request): Promise { 42 | const metrics = (await request.json()) as ResourceMetrics; 43 | try { 44 | await aggregateMetrics( 45 | metrics, 46 | async (key: string) => 47 | (await this.state.storage.get(key)) || null, 48 | (key: string, value: MetricData) => this.state.storage.put(key, value), 49 | ); 50 | } catch (e) { 51 | console.error("Error aggregating metrics", e); 52 | return new Response("Error aggregating metrics", { status: 500 }); 53 | } 54 | 55 | return new Response(null, { status: 204 }); 56 | } 57 | 58 | async handlePromScrape(request: Request): Promise { 59 | const resource = Resource.default(); 60 | resource.attributes["service"] = "braintrust-proxy-cloudflare"; 61 | 62 | const metrics = await this.state.storage.list({ 63 | prefix: 
"otel_metric_", 64 | }); 65 | 66 | const resourceMetrics: ResourceMetrics = { 67 | resource, 68 | scopeMetrics: [ 69 | { 70 | scope: { 71 | name: "cloudflare-metric-aggregator", 72 | }, 73 | // metrics is a map. can you create a list of its values 74 | metrics: Array.from(metrics.values()).map((m) => ({ 75 | ...m, 76 | dataPoints: m.dataPoints.map((dp) => ({ 77 | ...dp, 78 | attributes: { 79 | ...dp.attributes, 80 | metric_shard: this.state.id.toString(), 81 | }, 82 | })), 83 | })) as MetricData[], 84 | }, 85 | ], 86 | }; 87 | 88 | return new Response(prometheusSerialize(resourceMetrics), { 89 | headers: { 90 | "Content-Type": "text/plain", 91 | }, 92 | status: 200, 93 | }); 94 | } 95 | 96 | static numShards(env: Env): number { 97 | return env.METRICS_SHARDS ?? 2; 98 | } 99 | 100 | static metricsTTL(env: Env): number { 101 | return env.METRICS_TTL ?? 24 * 7 * 3600; 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /apis/cloudflare/src/proxy.ts: -------------------------------------------------------------------------------- 1 | import { 2 | EdgeProxyV1, 3 | FlushingExporter, 4 | ProxyOpts, 5 | makeFetchApiSecrets, 6 | encryptedGet, 7 | } from "@braintrust/proxy/edge"; 8 | import { 9 | NOOP_METER_PROVIDER, 10 | SpanLogger, 11 | initMetrics, 12 | } from "@braintrust/proxy"; 13 | import { PrometheusMetricAggregator } from "./metric-aggregator"; 14 | import { handleRealtimeProxy } from "./realtime"; 15 | import { braintrustAppUrl } from "./env"; 16 | import { Span, startSpan } from "braintrust"; 17 | import { BT_PARENT, resolveParentHeader } from "@braintrust/core"; 18 | import { cachedLogin, makeProxySpanLogger } from "./tracing"; 19 | 20 | export const proxyV1Prefixes = ["/v1/proxy", "/v1"]; 21 | 22 | function apiCacheKey(key: string) { 23 | return `http://apikey.cache/${encodeURIComponent(key)}.jpg`; 24 | } 25 | 26 | export function originWhitelist(env: Env) { 27 | return env.WHITELISTED_ORIGINS && 
env.WHITELISTED_ORIGINS.length > 0 28 | ? env.WHITELISTED_ORIGINS.split(",") 29 | .map((x) => x.trim()) 30 | .filter((x) => x) 31 | : undefined; 32 | } 33 | 34 | export async function handleProxyV1( 35 | request: Request, 36 | proxyV1Prefix: string, 37 | env: Env, 38 | ctx: ExecutionContext, 39 | ): Promise { 40 | let meterProvider = undefined; 41 | if (!env.DISABLE_METRICS) { 42 | const metricShard = Math.floor( 43 | Math.random() * PrometheusMetricAggregator.numShards(env), 44 | ); 45 | const aggregator = env.METRICS_AGGREGATOR.get( 46 | env.METRICS_AGGREGATOR.idFromName(metricShard.toString()), 47 | ); 48 | const metricAggURL = new URL(request.url); 49 | metricAggURL.pathname = "/push"; 50 | 51 | meterProvider = initMetrics( 52 | new FlushingExporter((resourceMetrics) => 53 | aggregator.fetch(metricAggURL, { 54 | method: "POST", 55 | headers: { 56 | "Content-Type": "application/json", 57 | }, 58 | body: JSON.stringify(resourceMetrics), 59 | }), 60 | ), 61 | ); 62 | } 63 | 64 | const meter = (meterProvider || NOOP_METER_PROVIDER).getMeter( 65 | "cloudflare-metrics", 66 | ); 67 | 68 | const whitelist = originWhitelist(env); 69 | 70 | const cacheGetLatency = meter.createHistogram("results_cache_get_latency"); 71 | const cacheSetLatency = meter.createHistogram("results_cache_set_latency"); 72 | 73 | const cache = await caches.open("apikey:cache"); 74 | 75 | const credentialsCache = { 76 | async get(key: string): Promise { 77 | const response = await cache.match(apiCacheKey(key)); 78 | if (response) { 79 | return (await response.json()) as T; 80 | } else { 81 | return null; 82 | } 83 | }, 84 | async set(key: string, value: T, { ttl }: { ttl?: number }) { 85 | await cache.put( 86 | apiCacheKey(key), 87 | new Response(JSON.stringify(value), { 88 | headers: { 89 | "Cache-Control": `public${ttl ? 
`, max-age=${ttl}` : ""}`, 90 | }, 91 | }), 92 | ); 93 | }, 94 | }; 95 | 96 | let spanLogger: SpanLogger | undefined; 97 | let span: Span | undefined; 98 | const parentHeader = request.headers.get(BT_PARENT); 99 | if (parentHeader) { 100 | let parent; 101 | try { 102 | parent = resolveParentHeader(parentHeader); 103 | } catch (e) { 104 | return new Response( 105 | `Invalid parent header '${parentHeader}': ${e instanceof Error ? e.message : String(e)}`, 106 | { status: 400 }, 107 | ); 108 | } 109 | span = startSpan({ 110 | state: await cachedLogin({ 111 | appUrl: braintrustAppUrl(env).toString(), 112 | headers: request.headers, 113 | cache: credentialsCache, 114 | }), 115 | type: "llm", 116 | name: "LLM", 117 | parent: parent.toStr(), 118 | }); 119 | spanLogger = makeProxySpanLogger(span, ctx.waitUntil.bind(ctx)); 120 | } 121 | 122 | const opts: ProxyOpts = { 123 | getRelativeURL(request: Request): string { 124 | return new URL(request.url).pathname.slice(proxyV1Prefix.length); 125 | }, 126 | cors: true, 127 | credentialsCache, 128 | completionsCache: { 129 | get: async (key) => { 130 | const start = performance.now(); 131 | const ret = await env.ai_proxy.get(key); 132 | const end = performance.now(); 133 | cacheGetLatency.record(end - start); 134 | if (ret) { 135 | return JSON.parse(ret); 136 | } else { 137 | return null; 138 | } 139 | }, 140 | set: async (key, value, { ttl }: { ttl?: number }) => { 141 | const start = performance.now(); 142 | await env.ai_proxy.put(key, JSON.stringify(value), { 143 | expirationTtl: ttl, 144 | }); 145 | const end = performance.now(); 146 | cacheSetLatency.record(end - start); 147 | }, 148 | }, 149 | braintrustApiUrl: braintrustAppUrl(env).toString(), 150 | meterProvider, 151 | whitelist, 152 | spanLogger, 153 | }; 154 | 155 | const url = new URL(request.url); 156 | if (url.pathname === `${proxyV1Prefix}/realtime`) { 157 | return await handleRealtimeProxy({ 158 | request, 159 | env, 160 | ctx, 161 | cacheGet: async (encryptionKey:
string, key: string) => { 162 | if (!opts.completionsCache) { 163 | return null; 164 | } 165 | return ( 166 | (await encryptedGet(opts.completionsCache, encryptionKey, key)) ?? 167 | null 168 | ); 169 | }, 170 | getApiSecrets: makeFetchApiSecrets({ ctx, opts }), 171 | }); 172 | } 173 | 174 | return EdgeProxyV1(opts)(request, ctx); 175 | } 176 | 177 | export async function handlePrometheusScrape( 178 | request: Request, 179 | env: Env, 180 | ctx: ExecutionContext, 181 | ): Promise { 182 | if (env.DISABLE_METRICS) { 183 | return new Response("Metrics disabled", { status: 403 }); 184 | } 185 | if ( 186 | env.PROMETHEUS_SCRAPE_USER !== undefined || 187 | env.PROMETHEUS_SCRAPE_PASSWORD !== undefined 188 | ) { 189 | const unauthorized = new Response("Unauthorized", { 190 | status: 401, 191 | headers: { 192 | "WWW-Authenticate": 'Basic realm="Braintrust Proxy Metrics"', 193 | }, 194 | }); 195 | 196 | const auth = request.headers.get("Authorization"); 197 | if (!auth || auth.indexOf("Basic ") !== 0) { 198 | return unauthorized; 199 | } 200 | 201 | const userPass = atob(auth.slice("Basic ".length)).split(":"); 202 | if ( 203 | userPass[0] !== env.PROMETHEUS_SCRAPE_USER || 204 | userPass[1] !== env.PROMETHEUS_SCRAPE_PASSWORD 205 | ) { 206 | return unauthorized; 207 | } 208 | } 209 | // Array from 0 ... 
numShards
  const shards = await Promise.all(
    Array.from(
      { length: PrometheusMetricAggregator.numShards(env) },
      async (_, i) => {
        const aggregator = env.METRICS_AGGREGATOR.get(
          env.METRICS_AGGREGATOR.idFromName(i.toString()),
        );
        const url = new URL(request.url);
        url.pathname = "/metrics";
        const resp = await aggregator.fetch(url, {
          method: "POST",
        });
        if (resp.status !== 200) {
          throw new Error(
            `Unexpected status code ${resp.status} ${
              resp.statusText
            }: ${await resp.text()}`,
          );
        } else {
          return await resp.text();
        }
      },
    ),
  );
  return new Response(shards.join("\n"), {
    headers: {
      "Content-Type": "text/plain",
    },
  });
}

--------------------------------------------------------------------------------
/apis/cloudflare/src/realtime.ts:
--------------------------------------------------------------------------------
import { RealtimeAPI } from "@openai/realtime-api-beta";
import { APISecret, ProxyLoggingParam } from "@braintrust/proxy/schema";
import { ORG_NAME_HEADER } from "@braintrust/proxy";
import {
  isTempCredential,
  verifyTempCredentials,
} from "@braintrust/proxy/utils";
import { OpenAiRealtimeLogger } from "./realtime-logger";
import { braintrustAppUrl } from "./env";

// Fallback model when neither the `model` query parameter nor the temp
// credential's JWT payload specifies one.
const MODEL = "gpt-4o-realtime-preview-2024-10-01";

/**
 * Proxies a client WebSocket connection to the OpenAI Realtime API.
 *
 * Accepts the incoming `Upgrade: websocket` request, authenticates the caller
 * (a raw API key or a Braintrust temp credential smuggled through the
 * `Sec-WebSocket-Protocol` header), resolves provider secrets, opens a
 * server-side connection to OpenAI, and relays events in both directions.
 * When the temp credential carries logging parameters, every relayed event is
 * also recorded via `OpenAiRealtimeLogger`.
 *
 * @returns a 101 response carrying the client half of the WebSocketPair, or
 *   an error response (426 non-websocket, 401 auth failure, 502 upstream
 *   failure).
 */
export async function handleRealtimeProxy({
  request,
  env,
  ctx,
  cacheGet,
  getApiSecrets,
}: {
  request: Request;
  env: Env;
  ctx: ExecutionContext;
  cacheGet: (encryptionKey: string, key: string) => Promise<string | null>;
  getApiSecrets: (
    useCache: boolean,
    authToken: string,
    model: string | null,
    org_name?: string,
  ) => Promise<APISecret[]>;
}): Promise<Response> {
  const upgradeHeader = request.headers.get("Upgrade");
  if (!upgradeHeader || upgradeHeader !== "websocket") {
    return new Response("Expected Upgrade: websocket", { status: 426 });
  }

  const webSocketPair = new WebSocketPair();
  const [client, server] = Object.values(webSocketPair);

  let realtimeApi: RealtimeAPI | null = null;

  server.accept();

  const responseHeaders = new Headers();
  const protocolHeader = request.headers.get("Sec-WebSocket-Protocol");
  let apiKey: string | undefined;
  if (protocolHeader) {
    const requestedProtocols = protocolHeader.split(",").map((p) => p.trim());
    if (requestedProtocols.includes("realtime")) {
      // Not exactly sure why this protocol needs to be accepted.
      responseHeaders.set("Sec-WebSocket-Protocol", "realtime");
    }

    // Browser clients cannot set Authorization headers on WebSockets, so the
    // OpenAI client smuggles the key in a fake subprotocol.
    for (const protocol of requestedProtocols) {
      if (protocol.startsWith("openai-insecure-api-key.")) {
        const parsedApiKey = protocol
          .slice("openai-insecure-api-key.".length)
          .trim();
        if (parsedApiKey.length > 0 && parsedApiKey !== "null") {
          apiKey = parsedApiKey;
        }
      }
    }
  }

  const url = new URL(request.url);
  let model = url.searchParams.get("model") ?? MODEL;

  if (!apiKey) {
    return new Response("Missing API key", { status: 401 });
  }

  let loggingParams: ProxyLoggingParam | undefined;
  let secrets: APISecret[] = [];

  // First, try to use temp credentials, because then we'll get access to the
  // project name for logging.
  if (isTempCredential(apiKey)) {
    const { credentialCacheValue, jwtPayload } = await verifyTempCredentials({
      jwt: apiKey,
      cacheGet,
    });
    // Unwrap the API key here to avoid a duplicate call to
    // `verifyTempCredentials` inside `getApiSecrets`. That call will use Redis
    // which is not available in Cloudflare.
    apiKey = credentialCacheValue.authToken;
    loggingParams = jwtPayload.bt.logging ?? undefined;
    model = jwtPayload.bt.model ?? MODEL;
  }

  const orgName = request.headers.get(ORG_NAME_HEADER) ?? undefined;

  secrets = await getApiSecrets(true, apiKey, model, orgName);

  if (secrets.length === 0) {
    return new Response("No secrets found", { status: 401 });
  }

  // Only construct a logger when the temp credential asked for logging.
  const realtimeLogger: OpenAiRealtimeLogger | undefined =
    loggingParams &&
    new OpenAiRealtimeLogger({
      apiKey,
      appUrl: braintrustAppUrl(env).toString(),
      loggingParams,
    });

  // Create RealtimeClient
  try {
    console.log("Creating RealtimeApi");
    realtimeApi = new RealtimeAPI({ apiKey: secrets[0].secret });
  } catch (e) {
    console.error(`Error connecting to OpenAI: ${e}`);
    server.close();
    return new Response("Error connecting to OpenAI", { status: 502 });
  }

  // Relay: OpenAI Realtime API Event -> Client
  realtimeApi.on("server.*", (event: { type: string }) => {
    server.send(JSON.stringify(event));
    try {
      realtimeLogger?.handleMessageServer(event);
    } catch (e) {
      // Logging must never break the relay.
      console.warn(`Error logging server event: ${e} ${event.type}`);
    }
  });

  realtimeApi.on("close", () => {
    console.log("Closing server-side because I received a close event");
    server.close();
    if (realtimeLogger) {
      ctx.waitUntil(realtimeLogger.close());
    }
  });

  // Relay: Client -> OpenAI Realtime API Event.
  //
  // Messages that arrive before the upstream connection is established are
  // buffered in `messageQueue` and relayed once `connect()` resolves.
  const messageQueue: string[] = [];

  // Parses a raw client frame and forwards it to OpenAI (and the session
  // logger). Hoisted to function scope so that both the `message` listener
  // and the post-connect queue flush use the same path. Bug fix: the flush
  // loop previously called `server.send(message)`, which echoed queued client
  // messages back to the client instead of relaying them to OpenAI.
  const relayClientEvent = (data: string) => {
    try {
      const parsedEvent = JSON.parse(data);
      realtimeApi.send(parsedEvent.type, parsedEvent);
      try {
        realtimeLogger?.handleMessageClient(parsedEvent);
      } catch (e) {
        console.warn(`Error logging client event: ${e} ${parsedEvent.type}`);
      }
    } catch (e) {
      console.error(`Error parsing event from client: ${data}`);
    }
  };

  server.addEventListener("message", (event: MessageEvent) => {
    const data =
      typeof event.data === "string" ? event.data : event.data.toString();
    if (!realtimeApi.isConnected()) {
      messageQueue.push(data);
    } else {
      relayClientEvent(data);
    }
  });

  server.addEventListener("close", () => {
    console.log("Closing server-side because the client closed the connection");
    realtimeApi.disconnect();
    if (realtimeLogger) {
      ctx.waitUntil(realtimeLogger.close());
    }
  });

  // Connect to OpenAI Realtime API.
  try {
    console.log(`Connecting to OpenAI...`);
    await realtimeApi.connect();
    console.log(`Connected to OpenAI successfully!`);
    // Relay any client messages that were buffered while connecting.
    while (messageQueue.length) {
      const message = messageQueue.shift();
      if (message) {
        relayClientEvent(message);
      }
    }
  } catch (e) {
    if (e instanceof Error) {
      console.error(`Error connecting to OpenAI: ${e.message}`);
    } else {
      console.error(`Error connecting to OpenAI: ${e}`);
    }
    return new Response("Error connecting to OpenAI", { status: 502 });
  }

  return new Response(null, {
    status: 101,
    headers: responseHeaders,
    webSocket: client,
  });
}

--------------------------------------------------------------------------------
/apis/cloudflare/src/tracing.ts:
--------------------------------------------------------------------------------
import {
  ORG_NAME_HEADER,
  SpanLogger,
  isObject,
  parseAuthHeader,
} from "@braintrust/proxy";
import { Attachment, BraintrustState, loginToState, Span } from "braintrust";
import { isArray, SpanComponentsV3, SpanObjectTypeV3 } from "@braintrust/core";
import { base64ToArrayBuffer } from "@braintrust/proxy/utils";
import {
  digestMessage,
  encryptedGet,
  encryptedPut,
  type Cache as EdgeCache,
} from "@braintrust/proxy/edge";

export function makeProxySpanLogger(
  span: Span,
  waitUntil: (promise: Promise) => void,
): SpanLogger {
  // Adapts a Braintrust span to the proxy's SpanLogger interface. Every
  // log() call also schedules a flush via waitUntil so data is exported
  // before the worker is torn down.
  // NOTE(review): `Promise` above is missing its type argument — generics
  // appear stripped in this dump; confirm against the original source.
  return {
    log: (args) => {
      // Replace inline base64 image payloads with attachments before logging.
      span.log(replacePayloadWithAttachments(args, span.state()));
      waitUntil(span.flush());
    },
    end: span.end.bind(span),
    setName(name) {
      span.setAttributes({ name });
    },
    reportProgress() {
      // Progress reporting is a no-op for the proxy span logger.
      return;
    },
  };
}

/**
 * Recursively replaces base64 image data-URL strings inside `data` with
 * Braintrust `Attachment` objects; all other values are returned unchanged.
 * The cast preserves the caller's static type.
 */
export function replacePayloadWithAttachments(
  data: T,
  state: BraintrustState | undefined,
): T {
  // NOTE(review): the type parameter declaration for `T` appears stripped
  // from this dump; confirm against the original source.
  return replacePayloadWithAttachmentsInner(data, state) as T;
}

// Depth-first walk: arrays and plain objects are rebuilt with transformed
// children, base64 image strings become Attachments, everything else passes
// through untouched.
function replacePayloadWithAttachmentsInner(
  data: unknown,
  state: BraintrustState | undefined,
): unknown {
  if (isArray(data)) {
    return data.map((item) => replacePayloadWithAttachmentsInner(item, state));
  } else if (isObject(data)) {
    return Object.fromEntries(
      Object.entries(data).map(([key, value]) => [
        key,
        replacePayloadWithAttachmentsInner(value, state),
      ]),
    );
  } else if (typeof data === "string") {
    if (isBase64Image(data)) {
      const { mimeType, data: arrayBuffer } = getBase64Parts(data);
      // Derive a filename extension from the mime subtype (e.g. "png").
      const filename = `file.${mimeType.split("/")[1]}`;
      return new Attachment({
        data: arrayBuffer,
        contentType: mimeType,
        filename,
        state,
      });
    } else {
      return data;
    }
  } else {
    return data;
  }
}

// Full-string match for `data:image/<subtype>;base64,<payload>` URLs.
const base64ImagePattern =
  /^data:image\/[a-zA-Z]+;base64,[A-Za-z0-9+/]+={0,2}$/;

/** Returns true if `s` is a base64-encoded `data:image/...` URL. */
export function isBase64Image(s: string): boolean {
  // Avoid unnecessary (slower) pattern matching
  if (!s.startsWith("data:")) {
    return false;
  }

  return base64ImagePattern.test(s);
}

// Being as specific as possible about allowable characters and avoiding greedy matching
// helps avoid catastrophic backtracking: https://github.com/braintrustdata/braintrust/pull/4831
const base64ContentTypePattern =
  /^data:([a-zA-Z0-9]+\/[a-zA-Z0-9+.-]+);base64,/;

/**
 * Splits a base64 data URL into its mime type and decoded bytes.
 *
 * @throws if `s` does not match the `data:<mime>;base64,` prefix.
 */
export function getBase64Parts(s: string): {
  mimeType: string;
  data: ArrayBuffer;
} {
  const parts = s.match(base64ContentTypePattern);
  if (!parts) {
    throw new Error("Invalid base64 image");
  }
  const mimeType = parts[1];
  // Strip the `data:<mime>;base64,` prefix; the remainder is the payload.
  const data = s.slice(`data:${mimeType};base64,`.length);
  return { mimeType, data: base64ToArrayBuffer(data) };
}

/**
 * Logs into Braintrust using the request's Authorization header and org-name
 * header, caching the serialized login state to avoid repeated logins. The
 * digest of (token, orgName) doubles as both the cache key and the
 * encryption key.
 */
export async function cachedLogin({
  appUrl,
  headers,
  cache,
}: {
  headers: Headers;
  appUrl: string;
  cache: EdgeCache;
}) {
  const orgName = headers.get(ORG_NAME_HEADER) ?? undefined;
  const token =
    parseAuthHeader({
      authorization: headers.get("authorization") ?? undefined,
    }) ?? undefined;

  const encryptionKey = await digestMessage(
    JSON.stringify({ token: token ?? "anon", orgName }),
  );

  let state: BraintrustState;
  const stateResp = await encryptedGet(cache, encryptionKey, encryptionKey);
  if (stateResp) {
    // Cache hit: rehydrate the previously serialized login state.
    state = BraintrustState.deserialize(JSON.parse(stateResp), {
      noExitFlush: true,
    });
  } else {
    state = await loginToState({
      // NOTE(review): this re-parses the same Authorization header as
      // `token` above; the two values should always agree.
      apiKey:
        parseAuthHeader({
          authorization: headers.get("authorization") ?? undefined,
        }) ?? undefined,
      // If the app URL is explicitly set to an env var, it's meant to override
      // the origin.
133 | appUrl: appUrl, 134 | orgName, 135 | noExitFlush: true, 136 | }); 137 | 138 | encryptedPut( 139 | cache, 140 | encryptionKey, 141 | encryptionKey, 142 | JSON.stringify(state.serialize()), 143 | { 144 | ttl: 60, 145 | }, 146 | ).catch((e) => { 147 | console.error("Error while caching login credentials", e); 148 | }); 149 | } 150 | 151 | return state; 152 | } 153 | -------------------------------------------------------------------------------- /apis/cloudflare/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | /* Visit https://aka.ms/tsconfig.json to read more about this file */ 4 | 5 | /* Projects */ 6 | // "incremental": true, /* Enable incremental compilation */ 7 | // "composite": true, /* Enable constraints that allow a TypeScript project to be used with project references. */ 8 | // "tsBuildInfoFile": "./", /* Specify the folder for .tsbuildinfo incremental compilation files. */ 9 | // "disableSourceOfProjectReferenceRedirect": true, /* Disable preferring source files instead of declaration files when referencing composite projects */ 10 | // "disableSolutionSearching": true, /* Opt a project out of multi-project reference checking when editing. */ 11 | // "disableReferencedProjectLoad": true, /* Reduce the number of projects loaded automatically by TypeScript. */ 12 | 13 | /* Language and Environment */ 14 | "target": "es2021" /* Set the JavaScript language version for emitted JavaScript and include compatible library declarations. */, 15 | "lib": ["es2021"] /* Specify a set of bundled library declaration files that describe the target runtime environment. */, 16 | "jsx": "react" /* Specify what JSX code is generated. */, 17 | // "experimentalDecorators": true, /* Enable experimental support for TC39 stage 2 draft decorators. */ 18 | // "emitDecoratorMetadata": true, /* Emit design-type metadata for decorated declarations in source files. 
*/ 19 | // "jsxFactory": "", /* Specify the JSX factory function used when targeting React JSX emit, e.g. 'React.createElement' or 'h' */ 20 | // "jsxFragmentFactory": "", /* Specify the JSX Fragment reference used for fragments when targeting React JSX emit e.g. 'React.Fragment' or 'Fragment'. */ 21 | // "jsxImportSource": "", /* Specify module specifier used to import the JSX factory functions when using `jsx: react-jsx*`.` */ 22 | // "reactNamespace": "", /* Specify the object invoked for `createElement`. This only applies when targeting `react` JSX emit. */ 23 | // "noLib": true, /* Disable including any library files, including the default lib.d.ts. */ 24 | // "useDefineForClassFields": true, /* Emit ECMAScript-standard-compliant class fields. */ 25 | 26 | /* Modules */ 27 | "module": "es2022" /* Specify what module code is generated. */, 28 | // "rootDir": "./", /* Specify the root folder within your source files. */ 29 | "moduleResolution": "node" /* Specify how TypeScript looks up a file from a given module specifier. */, 30 | // "baseUrl": "./", /* Specify the base directory to resolve non-relative module names. */ 31 | // "paths": {}, /* Specify a set of entries that re-map imports to additional lookup locations. */ 32 | // "rootDirs": [], /* Allow multiple folders to be treated as one when resolving modules. */ 33 | // "typeRoots": [], /* Specify multiple folders that act like `./node_modules/@types`. */ 34 | "types": ["@cloudflare/workers-types"] /* Specify type package names to be included without being referenced in a source file. */, 35 | // "allowUmdGlobalAccess": true, /* Allow accessing UMD globals from modules. */ 36 | "resolveJsonModule": true /* Enable importing .json files */, 37 | // "noResolve": true, /* Disallow `import`s, `require`s or ``s from expanding the number of files TypeScript should add to a project. */ 38 | 39 | /* JavaScript Support */ 40 | "allowJs": true /* Allow JavaScript files to be a part of your program. 
Use the `checkJS` option to get errors from these files. */, 41 | "checkJs": false /* Enable error reporting in type-checked JavaScript files. */, 42 | // "maxNodeModuleJsDepth": 1, /* Specify the maximum folder depth used for checking JavaScript files from `node_modules`. Only applicable with `allowJs`. */ 43 | 44 | /* Emit */ 45 | // "declaration": true, /* Generate .d.ts files from TypeScript and JavaScript files in your project. */ 46 | // "declarationMap": true, /* Create sourcemaps for d.ts files. */ 47 | // "emitDeclarationOnly": true, /* Only output d.ts files and not JavaScript files. */ 48 | // "sourceMap": true, /* Create source map files for emitted JavaScript files. */ 49 | // "outFile": "./", /* Specify a file that bundles all outputs into one JavaScript file. If `declaration` is true, also designates a file that bundles all .d.ts output. */ 50 | // "outDir": "./", /* Specify an output folder for all emitted files. */ 51 | // "removeComments": true, /* Disable emitting comments. */ 52 | "noEmit": true /* Disable emitting files from a compilation. */, 53 | // "importHelpers": true, /* Allow importing helper functions from tslib once per project, instead of including them per-file. */ 54 | // "importsNotUsedAsValues": "remove", /* Specify emit/checking behavior for imports that are only used for types */ 55 | // "downlevelIteration": true, /* Emit more compliant, but verbose and less performant JavaScript for iteration. */ 56 | // "sourceRoot": "", /* Specify the root path for debuggers to find the reference source code. */ 57 | // "mapRoot": "", /* Specify the location where debugger should locate map files instead of generated locations. */ 58 | // "inlineSourceMap": true, /* Include sourcemap files inside the emitted JavaScript. */ 59 | // "inlineSources": true, /* Include source code in the sourcemaps inside the emitted JavaScript. */ 60 | // "emitBOM": true, /* Emit a UTF-8 Byte Order Mark (BOM) in the beginning of output files. 
*/ 61 | // "newLine": "crlf", /* Set the newline character for emitting files. */ 62 | // "stripInternal": true, /* Disable emitting declarations that have `@internal` in their JSDoc comments. */ 63 | // "noEmitHelpers": true, /* Disable generating custom helper functions like `__extends` in compiled output. */ 64 | // "noEmitOnError": true, /* Disable emitting files if any type checking errors are reported. */ 65 | // "preserveConstEnums": true, /* Disable erasing `const enum` declarations in generated code. */ 66 | // "declarationDir": "./", /* Specify the output directory for generated declaration files. */ 67 | // "preserveValueImports": true, /* Preserve unused imported values in the JavaScript output that would otherwise be removed. */ 68 | 69 | /* Interop Constraints */ 70 | "isolatedModules": true /* Ensure that each file can be safely transpiled without relying on other imports. */, 71 | "allowSyntheticDefaultImports": true /* Allow 'import x from y' when a module doesn't have a default export. */, 72 | // "esModuleInterop": true /* Emit additional JavaScript to ease support for importing CommonJS modules. This enables `allowSyntheticDefaultImports` for type compatibility. */, 73 | // "preserveSymlinks": true, /* Disable resolving symlinks to their realpath. This correlates to the same flag in node. */ 74 | "forceConsistentCasingInFileNames": true /* Ensure that casing is correct in imports. */, 75 | 76 | /* Type Checking */ 77 | "strict": true /* Enable all strict type-checking options. */, 78 | // "noImplicitAny": true, /* Enable error reporting for expressions and declarations with an implied `any` type.. */ 79 | // "strictNullChecks": true, /* When type checking, take into account `null` and `undefined`. */ 80 | // "strictFunctionTypes": true, /* When assigning functions, check to ensure parameters and the return values are subtype-compatible. 
*/ 81 | // "strictBindCallApply": true, /* Check that the arguments for `bind`, `call`, and `apply` methods match the original function. */ 82 | // "strictPropertyInitialization": true, /* Check for class properties that are declared but not set in the constructor. */ 83 | // "noImplicitThis": true, /* Enable error reporting when `this` is given the type `any`. */ 84 | // "useUnknownInCatchVariables": true, /* Type catch clause variables as 'unknown' instead of 'any'. */ 85 | // "alwaysStrict": true, /* Ensure 'use strict' is always emitted. */ 86 | // "noUnusedLocals": true, /* Enable error reporting when a local variables aren't read. */ 87 | // "noUnusedParameters": true, /* Raise an error when a function parameter isn't read */ 88 | // "exactOptionalPropertyTypes": true, /* Interpret optional property types as written, rather than adding 'undefined'. */ 89 | // "noImplicitReturns": true, /* Enable error reporting for codepaths that do not explicitly return in a function. */ 90 | // "noFallthroughCasesInSwitch": true, /* Enable error reporting for fallthrough cases in switch statements. */ 91 | // "noUncheckedIndexedAccess": true, /* Include 'undefined' in index signature results */ 92 | // "noImplicitOverride": true, /* Ensure overriding members in derived classes are marked with an override modifier. */ 93 | // "noPropertyAccessFromIndexSignature": true, /* Enforces using indexed accessors for keys declared using an indexed type */ 94 | // "allowUnusedLabels": true, /* Disable error reporting for unused labels. */ 95 | // "allowUnreachableCode": true, /* Disable error reporting for unreachable code. */ 96 | 97 | /* Completeness */ 98 | // "skipDefaultLibCheck": true, /* Skip type checking .d.ts files that are included with TypeScript. */ 99 | "skipLibCheck": true /* Skip type checking all .d.ts files. 
*/
  }
}

--------------------------------------------------------------------------------
/apis/cloudflare/tsup.config.ts:
--------------------------------------------------------------------------------
import { defineConfig } from "tsup";

// https://github.com/egoist/tsup/issues/840 discusses how there can
// be an infinite loop bug with --watch, and we work around that by
// calling build with --dts.
// Builds the library entry point (src/lib.ts) as a single ESM bundle in dist/.
export default defineConfig([
  {
    entry: ["src/lib.ts"],
    format: ["esm"],
    outDir: "dist",
  },
]);

--------------------------------------------------------------------------------
/apis/cloudflare/worker-configuration.d.ts:
--------------------------------------------------------------------------------
// Ambient declaration of the Worker's environment bindings; the real
// bindings (ai_proxy KV, METRICS_AGGREGATOR, vars) come from wrangler.toml.
interface Env {
  // Example binding to KV. Learn more at https://developers.cloudflare.com/workers/runtime-apis/kv/
  // MY_KV_NAMESPACE: KVNamespace;
  //
  // Example binding to Durable Object. Learn more at https://developers.cloudflare.com/workers/runtime-apis/durable-objects/
  // MY_DURABLE_OBJECT: DurableObjectNamespace;
  //
  // Example binding to R2. Learn more at https://developers.cloudflare.com/workers/runtime-apis/r2/
  // MY_BUCKET: R2Bucket;
  //
  // Example binding to a Service. Learn more at https://developers.cloudflare.com/workers/runtime-apis/service-bindings/
  // MY_SERVICE: Fetcher;
  //
  // Example binding to a Queue.
Learn more at https://developers.cloudflare.com/queues/javascript-apis/ 15 | // MY_QUEUE: Queue; 16 | } 17 | -------------------------------------------------------------------------------- /apis/cloudflare/wrangler-template.toml: -------------------------------------------------------------------------------- 1 | name = "proxy" 2 | main = "src/index.ts" 3 | compatibility_date = "2024-09-23" 4 | compatibility_flags = ["nodejs_compat_v2"] 5 | 6 | kv_namespaces = [ 7 | # Configure this id to map to the id returned from 8 | # wrangler kv:namespace create ai-proxy 9 | { binding = "ai_proxy", id = "" }, 10 | ] 11 | 12 | [durable_objects] 13 | bindings = [ 14 | { name = "METRICS_AGGREGATOR", class_name = "PrometheusMetricAggregator" }, 15 | ] 16 | 17 | [[migrations]] 18 | tag = "v1" # Should be unique for each entry 19 | new_classes = ["PrometheusMetricAggregator"] # Array of new classes 20 | 21 | # Variable bindings. These are arbitrary, plaintext strings (similar to environment variables) 22 | # Note: Use secrets to store sensitive data. 23 | # Docs: https://developers.cloudflare.com/workers/platform/environment-variables 24 | [vars] 25 | # You should not need to edit this 26 | BRAINTRUST_APP_URL = "https://www.braintrust.dev" 27 | PROMETHEUS_SCRAPE_USER="admin" 28 | PROMETHEUS_SCRAPE_PASSWORD="" 29 | 30 | [env.staging.vars] 31 | BRAINTRUST_APP_URL = "https://www.braintrust.dev" 32 | # These are not real credentials, just populated to suppress a wrangler warning. 33 | PROMETHEUS_SCRAPE_USER="admin" 34 | PROMETHEUS_SCRAPE_PASSWORD="password" 35 | 36 | [env.staging] 37 | kv_namespaces = [ 38 | # Configure this id to map to the id returned from 39 | # wrangler kv:namespace create ai-proxy 40 | { binding = "ai_proxy", id = "" }, 41 | ] 42 | 43 | [env.staging.observability] 44 | enabled = true 45 | head_sampling_rate = 1.0 # Sample 100% of staging logs. 
46 | 47 | [env.staging.durable_objects] 48 | bindings = [ 49 | { name = "METRICS_AGGREGATOR", class_name = "PrometheusMetricAggregator" }, 50 | ] 51 | -------------------------------------------------------------------------------- /apis/node/README.md: -------------------------------------------------------------------------------- 1 | # Braintrust AI Proxy (Node, AWS Lambda) 2 | 3 | This directory contains an implementation of the Braintrust AI Proxy that runs on 4 | [Node.js](https://nodejs.org/) runtimes and can be bundled as an [Express server](https://expressjs.com/) 5 | or [AWS Lambda function](https://aws.amazon.com/blogs/compute/introducing-aws-lambda-response-streaming/). 6 | 7 | ## Building 8 | 9 | To build the proxy, you'll need to install [pnpm](https://pnpm.io/installation), and then from the 10 | [repository's root](../..), run: 11 | 12 | ```bash copy 13 | pnpm install 14 | pnpm build 15 | ``` 16 | 17 | ## Running locally (Express server) 18 | 19 | To run the proxy locally, you need to connect to a [Redis](https://redis.io) instance. The easiest way to 20 | run Redis locally is with [Docker](https://www.docker.com/). Once you have Docker installed, you can run 21 | ([full instructions](https://hub.docker.com/_/redis)) 22 | 23 | ```bash copy 24 | docker run --name some-redis -d redis 25 | ``` 26 | 27 | to run Redis on port 6379. 
Then, create a file named `.env.local` with the following contents: 28 | 29 | ```bash copy 30 | REDIS_HOST=127.0.0.1 31 | REDIS_PORT=6379 32 | ``` 33 | 34 | Finally, you can run the proxy with 35 | 36 | ```bash copy 37 | pnpm dev 38 | ``` 39 | 40 | ## Running on AWS Lambda 41 | 42 | To run on AWS, you'll need 43 | 44 | - The [AWS CLI](https://aws.amazon.com/cli/) 45 | - A [Lambda function](https://aws.amazon.com/pm/lambda) with a Node.js runtime 46 | - A [function URL](https://docs.aws.amazon.com/lambda/latest/dg/lambda-urls.html) for your Lambda function 47 | 48 | Once you've created and configured a Lambda function, you can deploy the proxy with 49 | 50 | ```bash copy 51 | aws lambda update-function-code --function-name --zip-file fileb://$PWD/dist/index.zip 52 | ``` 53 | 54 | ### CORS 55 | 56 | If you're using the proxy to access Braintrust AI from a browser, you'll need to enable CORS on your Lambda 57 | function. This is a tricky process, but the following function URL CORS settings should work: 58 | 59 | - `Allow origin`: `*` 60 | - `Expose headers`: `content-type, keep-alive, access-control-allow-credentials, access-control-allow-origin, access-control-allow-methods` 61 | - `Allow headers`: `authorization` 62 | - `Allow methods`: `POST, GET` 63 | - `Max age`: `86400` 64 | - `Allow credentials`: `true` 65 | -------------------------------------------------------------------------------- /apis/node/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "ai-proxy-lambda", 3 | "version": "1.0.0", 4 | "description": "", 5 | "main": "./dist/index.js", 6 | "scripts": { 7 | "test": "echo \"Error: no test specified\" && exit 1", 8 | "build": "run-p build:*", 9 | "build:typecheck": "tsc --noEmit", 10 | "build:local": "esbuild --platform=node --bundle src/local.ts --outfile=dist/local.js --minify --sourcemap --target=es2020", 11 | "build:lambda": "esbuild --platform=node --external:@aws-sdk --bundle src/index.js 
--outfile=dist/index.js --minify --sourcemap --target=es2020", 12 | "watch:local": "esbuild --platform=node --bundle src/local.ts --outfile=dist/local.js --sourcemap --target=es2020 --watch", 13 | "watch:lambda": "esbuild --platform=node --external:@aws-sdk --bundle src/index.js --outfile=dist/index.js --sourcemap --target=es2020 --watch", 14 | "dev": "run-p dev:serve watch:*", 15 | "dev:serve": "nodemon dist/local.js", 16 | "postbuild": "cd dist && zip -r index.zip index.js" 17 | }, 18 | "author": "", 19 | "license": "ISC", 20 | "dependencies": { 21 | "@braintrust/proxy": "workspace:*", 22 | "@supabase/supabase-js": "^2.32.0", 23 | "ai": "2.2.22", 24 | "aws-lambda": "^1.0.7", 25 | "axios": "^1.9.0", 26 | "binary-search": "^1.3.6", 27 | "combined-stream": "^1.0.8", 28 | "cors": "^2.8.5", 29 | "dotenv": "^16.3.1", 30 | "esbuild": "^0.19.9", 31 | "eventsource-parser": "^1.1.1", 32 | "express": "^4.19.2", 33 | "openai": "^4.42.0", 34 | "redis": "^4.6.8" 35 | }, 36 | "devDependencies": { 37 | "@types/aws-lambda": "^8.10.119", 38 | "@types/combined-stream": "^1.0.3", 39 | "@types/cors": "^2.8.13", 40 | "@types/dotenv": "^8.2.0", 41 | "@types/express": "^4.17.17", 42 | "@types/node": "^20.5.0", 43 | "nodemon": "^3.0.1", 44 | "npm-run-all": "^4.1.5", 45 | "typescript": "^5.0.4" 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /apis/node/src/anthropic.ts: -------------------------------------------------------------------------------- 1 | import { AIStreamCallbacksAndOptions } from "ai"; 2 | import { AIStream } from "ai"; 3 | 4 | // https://github.com/anthropics/anthropic-sdk-typescript/blob/0fc31f4f1ae2976afd0af3236e82d9e2c84c43c9/src/resources/completions.ts#L28-L49 5 | interface CompletionChunk { 6 | /** 7 | * The resulting completion up to and excluding the stop sequences. 8 | */ 9 | completion: string; 10 | 11 | /** 12 | * The model that performed the completion. 
   */
  model: string;

  /**
   * The reason that we stopped sampling.
   *
   * This may be one of the following values:
   *
   * - `"stop_sequence"`: we reached a stop sequence — either provided by you via the
   *   `stop_sequences` parameter, or a stop sequence built into the model
   * - `"max_tokens"`: we exceeded `max_tokens_to_sample` or the model's maximum
   */
  stop_reason: string;
}

// Error payload emitted by the Anthropic stream.
interface StreamError {
  error: {
    type: string;
    message: string;
  };
}

// Keep-alive ping event; carries no data.
interface StreamPing {}

type StreamData = CompletionChunk | StreamError | StreamPing;

/**
 * Returns a stateful parser for Anthropic SSE data frames: throws on error
 * events, returns "" for ping events, and otherwise returns the completion
 * text. Only the very first completion chunk is left-trimmed.
 */
function parseAnthropicStream(): (data: string) => string | void {
  let isFirst = true;
  return (data) => {
    const json = JSON.parse(data as string) as StreamData;

    // error event
    if ("error" in json) {
      throw new Error(`${json.error.type}: ${json.error.message}`);
    }

    // ping event
    if (!("completion" in json)) {
      return "";
    }

    let text = json.completion;
    if (isFirst) {
      text = text.trimStart();
      isFirst = false;
    }
    return text;
  };
}

/**
 * Adapts an Anthropic completion HTTP response into a text ReadableStream
 * using the `ai` package's AIStream helper.
 */
export function AnthropicStream(
  res: Response,
  cb?: AIStreamCallbacksAndOptions,
): ReadableStream {
  return AIStream(res, parseAnthropicStream(), cb);
}

--------------------------------------------------------------------------------
/apis/node/src/cache.ts:
--------------------------------------------------------------------------------
import { createClient, RedisClientType } from "redis";

import { Env } from "./env";

// Lazily-created singleton Redis client, reused across calls.
let redisClient: RedisClientType | null = null;

/**
 * Returns a connected Redis client when Redis is configured (via REDIS_URL or
 * REDIS_HOST/REDIS_PORT), otherwise null. The client is created and connected
 * at most once.
 */
export async function getRedis() {
  if (
    redisClient === null &&
    ((Env.redisHost && Env.redisPort) || Env.redisUrl)
  ) {
    if (Env.redisUrl) {
      // URL form takes precedence over host/port.
      redisClient = createClient({
        url: Env.redisUrl,
      });
    } else {
      redisClient =
createClient({
        socket: {
          host: Env.redisHost,
          port: Env.redisPort,
        },
      });
    }
    await redisClient.connect();
  }
  return redisClient;
}

--------------------------------------------------------------------------------
/apis/node/src/env.ts:
--------------------------------------------------------------------------------
// Reads proxy configuration from process.env, applying defaults.
function reloadEnv() {
  return {
    braintrustApiUrl:
      process.env.BRAINTRUST_APP_URL || "https://www.braintrust.dev",
    orgName: process.env.ORG_NAME || "*",
    redisHost: process.env.REDIS_HOST,
    redisPort: parseInt(process.env.REDIS_PORT || "6379"),
    redisUrl: process.env.REDIS_URL,
  };
}

// Snapshot of the environment, taken at module load time.
export let Env = reloadEnv();

// Re-reads process.env (e.g. after dotenv loads a local env file).
export function resetEnv() {
  Env = reloadEnv();
}

--------------------------------------------------------------------------------
/apis/node/src/index.js:
--------------------------------------------------------------------------------
import stream from "stream";
import util from "util";

import { nodeProxyV1 } from "./node-proxy";

const pipeline = util.promisify(stream.pipeline);

// Writes a "!" sentinel followed by the error text to the response stream.
// NOTE(review): not referenced anywhere in this file — confirm before removal.
function processError(res, err) {
  res.write("!");
  res.write(`${err}`);
}

// AWS Lambda streaming entry point: routes "/", "/empty", and "/proxy/v1/*".
export const handler = awslambda.streamifyResponse(
  async (event, responseStream, context) => {
    // This flag allows the function to instantly return after the responseStream finishes, without waiting
    // for sockets (namely, Redis) to close.
    // See https://stackoverflow.com/questions/46793670/reuse-redis-connections-for-nodejs-lambda-function
    // and https://docs.aws.amazon.com/lambda/latest/dg/nodejs-context.html
    context.callbackWaitsForEmptyEventLoop = false;

    // https://docs.aws.amazon.com/lambda/latest/dg/response-streaming-tutorial.html
    const metadata = {
      statusCode: 200,
      headers: {
        "content-type": "text/plain",
        "access-control-max-age": "86400",
      },
    };

    // Attaches the (possibly mutated) metadata to the response stream; must
    // be called only after statusCode/headers are final.
    const wrap = () => {
      return awslambda.HttpResponseStream.from(responseStream, metadata);
    };

    // CORS preflight: respond with headers only.
    if (event.requestContext.http.method === "OPTIONS") {
      responseStream = wrap();
      responseStream.end();
      return;
    }

    let aiStream = null;
    if (event.rawPath === "/") {
      // NOTE(review): `resetRedisInfo` is not defined or imported anywhere in
      // this file, so a request to "/" will throw a ReferenceError. The "XXX"
      // marker suggests leftover debug code — confirm and remove or import.
      await resetRedisInfo(); // XXX
      responseStream = wrap();
      responseStream.write("Hello World!");
      responseStream.end();
    } else if (event.rawPath === "/empty") {
      responseStream = wrap();
      responseStream.end();
    } else if (event.rawPath.startsWith("/proxy/v1")) {
      try {
        await nodeProxyV1(
          event.requestContext.http.method,
          event.rawPath.slice("/proxy/v1".length),
          event.headers,
          event.body,
          (name, value) => {
            metadata.headers[name] = value;
          },
          (code) => {
            metadata.statusCode = code;
          },
          wrap,
        );
      } catch (err) {
        console.error(err);
        metadata.statusCode = 500;
        responseStream.write(`Internal Server Error: ${err}`);
        responseStream.end();
      }
    } else {
      metadata.statusCode = 404;
      responseStream = wrap();
      responseStream.write("Not Found");
      responseStream.end();
    }
  },
);

--------------------------------------------------------------------------------
/apis/node/src/local.ts:
--------------------------------------------------------------------------------
import express, { Response } from "express";
import dotenv
from "dotenv"; 3 | import cors from "cors"; 4 | import { pipeline } from "stream/promises"; 5 | 6 | import { nodeProxyV1 } from "./node-proxy"; 7 | import { resetEnv } from "./env"; 8 | 9 | dotenv.config({ path: ".env.local" }); 10 | resetEnv(); 11 | 12 | const app = express(); 13 | app.use(express.text({ type: "*/*", limit: "50mb" })); 14 | app.use(cors()); 15 | 16 | const host = "localhost"; 17 | const port = 8001; 18 | 19 | function processError(res: Response, err: any) { 20 | res.write(`!${err}`); 21 | res.end(); 22 | } 23 | 24 | app.get("/proxy/v1/*", async (req, res) => { 25 | const url = req.url.slice("/proxy/v1".length); 26 | try { 27 | await nodeProxyV1({ 28 | method: "GET", 29 | url, 30 | proxyHeaders: req.headers, 31 | body: null, 32 | setHeader: res.setHeader.bind(res), 33 | setStatusCode: res.status.bind(res), 34 | getRes: () => res, 35 | }); 36 | } catch (e: any) { 37 | console.error(e); 38 | processError(res, e); 39 | } 40 | }); 41 | 42 | app.post("/proxy/v1/*", async (req, res) => { 43 | const url = req.url.slice("/proxy/v1".length); 44 | try { 45 | await nodeProxyV1({ 46 | method: "POST", 47 | url, 48 | proxyHeaders: req.headers, 49 | body: req.body, 50 | setHeader: res.setHeader.bind(res), 51 | setStatusCode: res.status.bind(res), 52 | getRes: () => res, 53 | }); 54 | } catch (e: any) { 55 | console.error(e); 56 | processError(res, e); 57 | } 58 | }); 59 | 60 | app.listen(port, () => { 61 | console.log(`[server]: Server is running at http://${host}:${port}`); 62 | }); 63 | -------------------------------------------------------------------------------- /apis/node/src/login.ts: -------------------------------------------------------------------------------- 1 | import bsearch from "binary-search"; 2 | import { Env } from "./env"; 3 | import { APISecret } from "@braintrust/proxy/schema"; 4 | 5 | export async function lookupApiSecret( 6 | useCache: boolean, 7 | loginToken: string, 8 | model: string | null, 9 | org_name?: string, 10 | ) { 11 | const 
cacheKey = `${loginToken}:${model}`; 12 | const cached = useCache ? loginTokenToApiKey.get(cacheKey) : undefined; 13 | if (cached !== undefined) { 14 | return cached; 15 | } 16 | 17 | let secrets: APISecret[] = []; 18 | try { 19 | const response = await fetch(`${Env.braintrustApiUrl}/api/secret`, { 20 | method: "POST", 21 | headers: { 22 | Authorization: `Bearer ${loginToken}`, 23 | "Content-Type": "application/json", 24 | }, 25 | body: JSON.stringify({ 26 | model, 27 | org_name, 28 | mode: "full", 29 | }), 30 | }); 31 | if (response.ok) { 32 | secrets = (await response.json()).filter( 33 | (row: APISecret) => Env.orgName === "*" || row.org_name === Env.orgName, 34 | ); 35 | } else { 36 | throw new Error(await response.text()); 37 | } 38 | } catch (e) { 39 | throw new Error(`Failed to lookup api key: ${e}`); 40 | } 41 | 42 | if (secrets.length === 0) { 43 | return []; 44 | } 45 | 46 | // This is somewhat arbitrary. Cache the API key for an hour. 47 | loginTokenToApiKey.insert( 48 | cacheKey, 49 | secrets, 50 | Number(new Date()) / 1000 + 3600, 51 | ); 52 | 53 | return secrets; 54 | } 55 | 56 | function fixIndex(i: number) { 57 | return i >= 0 ? 
i : -i - 1; 58 | } 59 | 60 | class TTLCache<V> { 61 | maxSize: number; 62 | cache: { [key: string]: { value: V; expiration: number } }; 63 | expirations: { expiration: number; key: string }[]; 64 | 65 | constructor(maxSize = 128) { 66 | this.maxSize = maxSize; 67 | this.cache = {}; 68 | this.expirations = []; 69 | } 70 | 71 | insert(key: string, value: V, expiration: number) { 72 | while (Object.keys(this.cache).length >= this.maxSize) { 73 | const first = this.expirations.shift(); 74 | delete this.cache[first!.key]; 75 | } 76 | 77 | this.cache[key] = { value, expiration }; 78 | let pos = fixIndex( 79 | bsearch( 80 | this.expirations, 81 | { expiration, key }, 82 | (a, b) => a.expiration - b.expiration, 83 | ), 84 | ); 85 | if (pos < 0) { 86 | pos = -pos - 1; 87 | } 88 | this.expirations = this.expirations 89 | .slice(0, pos) 90 | .concat({ expiration, key }) 91 | .concat(this.expirations.slice(pos)); 92 | } 93 | 94 | get(key: string) { 95 | const now = Date.now() / 1000; 96 | this._garbageCollect(now); 97 | const entry = this.cache[key]; 98 | if (entry === undefined) { 99 | return undefined; 100 | } else if (entry.expiration < now) { 101 | delete this.cache[key]; 102 | return undefined; 103 | } else { 104 | return entry.value; 105 | } 106 | } 107 | 108 | _garbageCollect(now: number) { 109 | let last_expired = fixIndex( 110 | bsearch( 111 | this.expirations, 112 | { expiration: now, key: "" }, 113 | (a, b) => a.expiration - b.expiration, 114 | ), 115 | ); 116 | 117 | if ( 118 | last_expired >= this.expirations.length || 119 | this.expirations[last_expired].expiration >= now 120 | ) { 121 | last_expired -= 1; 122 | } 123 | 124 | if (last_expired >= 0) { 125 | for (let i = 0; i < last_expired + 1; i++) { 126 | delete this.cache[this.expirations[i].key]; 127 | } 128 | this.expirations = this.expirations.slice(last_expired + 1); 129 | } 130 | } 131 | } 132 | 133 | const dbTokenCache = new TTLCache<string>(128); 134 | const loginTokenToApiKey = new TTLCache<APISecret[]>(128); 135 | 
-------------------------------------------------------------------------------- /apis/node/src/node-proxy.ts: -------------------------------------------------------------------------------- 1 | import { Writable, Readable } from "node:stream"; 2 | import * as crypto from "crypto"; 3 | 4 | // https://stackoverflow.com/questions/73308289/typescript-error-converting-a-native-fetch-body-webstream-to-a-node-stream 5 | import type * as streamWeb from "node:stream/web"; 6 | 7 | import { proxyV1 } from "@braintrust/proxy"; 8 | 9 | import { getRedis } from "./cache"; 10 | import { lookupApiSecret } from "./login"; 11 | 12 | export async function nodeProxyV1({ 13 | method, 14 | url, 15 | proxyHeaders, 16 | body, 17 | setHeader, 18 | setStatusCode, 19 | getRes, 20 | }: { 21 | method: "GET" | "POST"; 22 | url: string; 23 | proxyHeaders: any; 24 | body: any; 25 | setHeader: (name: string, value: string) => void; 26 | setStatusCode: (code: number) => void; 27 | getRes: () => Writable; 28 | }): Promise<void> { 29 | // Unlike the Cloudflare worker API, which supports public access, this API 30 | // mandates authentication 31 | 32 | const cacheGet = async (encryptionKey: string, key: string) => { 33 | const redis = await getRedis(); 34 | if (!redis) { 35 | return null; 36 | } 37 | return await redis.get(key); 38 | }; 39 | const cachePut = async ( 40 | encryptionKey: string, 41 | key: string, 42 | value: string, 43 | ttl_seconds?: number, 44 | ) => { 45 | const redis = await getRedis(); 46 | if (!redis) { 47 | return; 48 | } 49 | // Await the write: the Lambda handler sets callbackWaitsForEmptyEventLoop = false, so a floating promise here could be dropped before it completes. await redis.set(key, value, { 50 | // Cache it for a week if no ttl_seconds is provided 51 | EX: ttl_seconds ?? 60 * 60 * 24 * 7, 52 | }); 53 | }; 54 | 55 | let { readable, writable } = new TransformStream(); 56 | 57 | // Note: we must resolve the proxy after forwarding the stream to `res`, 58 | // because the proxy promise resolves after its internal stream has finished 59 | // writing. 
60 | const proxyPromise = proxyV1({ 61 | method, 62 | url, 63 | proxyHeaders, 64 | body, 65 | setHeader, 66 | setStatusCode, 67 | res: writable, 68 | getApiSecrets: lookupApiSecret, 69 | cacheGet, 70 | cachePut, 71 | digest: async (message: string) => { 72 | return crypto.createHash("md5").update(message).digest("hex"); 73 | }, 74 | }); 75 | 76 | const res = getRes(); 77 | const readableNode = Readable.fromWeb(readable as streamWeb.ReadableStream); 78 | readableNode.pipe(res, { end: true }); 79 | await proxyPromise; 80 | } 81 | -------------------------------------------------------------------------------- /apis/node/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "es2015", 4 | "lib": ["dom", "dom.iterable", "esnext"], 5 | "allowJs": true, 6 | "preserveConstEnums": true, 7 | "sourceMap": false, 8 | "skipLibCheck": true, 9 | "strict": true, 10 | "forceConsistentCasingInFileNames": true, 11 | "noEmit": true, 12 | "esModuleInterop": true, 13 | "module": "commonjs", 14 | "moduleResolution": "node", 15 | "resolveJsonModule": true, 16 | "isolatedModules": true, 17 | "jsx": "preserve", 18 | "incremental": true, 19 | "paths": { 20 | "#/*": ["./*"] 21 | }, 22 | "plugins": [ 23 | { 24 | "name": "next" 25 | } 26 | ] 27 | }, 28 | "include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", ".next/types/**/*.ts"], 29 | "exclude": ["node_modules"] 30 | } 31 | -------------------------------------------------------------------------------- /apis/vercel/.eslintrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "root": true, 3 | "extends": "next/core-web-vitals" 4 | } 5 | -------------------------------------------------------------------------------- /apis/vercel/.gitignore: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 
2 | 3 | # Dependencies 4 | /node_modules 5 | /.pnp 6 | .pnp.js 7 | 8 | # Testing 9 | /coverage 10 | 11 | # Next.js 12 | /.next/ 13 | /out/ 14 | 15 | # Production 16 | /build 17 | 18 | # Misc 19 | .DS_Store 20 | *.pem 21 | 22 | # Debug 23 | npm-debug.log* 24 | yarn-debug.log* 25 | yarn-error.log* 26 | 27 | # Local ENV files 28 | .env.local 29 | .env.development.local 30 | .env.test.local 31 | .env.production.local 32 | 33 | # Vercel 34 | .vercel 35 | 36 | # Turborepo 37 | .turbo 38 | 39 | # typescript 40 | *.tsbuildinfo 41 | .env*.local 42 | -------------------------------------------------------------------------------- /apis/vercel/.npmrc: -------------------------------------------------------------------------------- 1 | # Enabled to avoid deps failing to use next@canary 2 | legacy-peer-deps=true 3 | -------------------------------------------------------------------------------- /apis/vercel/README.md: -------------------------------------------------------------------------------- 1 | # Braintrust AI Proxy (Vercel) 2 | 3 | This directory contains an implementation of the Braintrust AI Proxy that runs on 4 | [Vercel Edge Functions](https://vercel.com/docs/functions/edge-functions). Because 5 | of their global network, you get the benefit of low latency and can scale up to millions 6 | of users. 7 | 8 | ## Deploying 9 | 10 | ### Forking the repository 11 | 12 | Vercel is tightly integrated with Git, so the best way to deploy is to fork this repository. 
Then, 13 | create a new [Vercel project](https://vercel.com/new) and 14 | 15 | - Connect your forked repository to the project 16 | - Create a [KV storage](https://vercel.com/docs/storage/vercel-kv/quickstart) instance and connect it to the project 17 | 18 | ### Connecting to Vercel 19 | 20 | From this directory, link your project and pull down the KV configuration by running: 21 | 22 | ```bash copy 23 | npx vercel link 24 | npx vercel env pull 25 | ``` 26 | 27 | You should now have a file named `.env.local` with a bunch of `KV_` variables. 28 | 29 | ### Running locally 30 | 31 | To build the proxy, you'll need to install [pnpm](https://pnpm.io/installation), and then from the 32 | [repository's root](../..), run: 33 | 34 | ```bash copy 35 | pnpm install 36 | pnpm build 37 | ``` 38 | 39 | Then, back in this directory, you can run the proxy locally with 40 | 41 | ```bash copy 42 | pnpm dev 43 | ``` 44 | 45 | ### Deploying to Vercel 46 | 47 | If you've integrated the proxy into Vercel via Git, then it will automatically deploy on every push. 
48 | -------------------------------------------------------------------------------- /apis/vercel/app/404.html: -------------------------------------------------------------------------------- 1 | Not found 2 | -------------------------------------------------------------------------------- /apis/vercel/app/layout.tsx: -------------------------------------------------------------------------------- 1 | export const metadata = { 2 | title: "Next.js", 3 | description: "Generated by Next.js", 4 | }; 5 | 6 | export default function RootLayout({ 7 | children, 8 | }: { 9 | children: React.ReactNode; 10 | }) { 11 | return ( 12 | 13 | {children} 14 | 15 | ); 16 | } 17 | -------------------------------------------------------------------------------- /apis/vercel/components/headers.tsx: -------------------------------------------------------------------------------- 1 | import { useState, FC } from "react"; 2 | import { Button } from "@vercel/examples-ui"; 3 | 4 | const Headers: FC<{ path: string; children: string }> = ({ 5 | path, 6 | children, 7 | }) => { 8 | const [loading, setLoading] = useState(false); 9 | const [state, setState] = useState({ 10 | path, 11 | latency: null, 12 | status: null, 13 | headers: { 14 | "X-RateLimit-Limit": "", 15 | "X-RateLimit-Remaining": "", 16 | "X-RateLimit-Reset": "", 17 | }, 18 | data: null, 19 | }); 20 | const handleFetch = async () => { 21 | const start = Date.now(); 22 | setLoading(true); 23 | 24 | try { 25 | const res = await fetch(path); 26 | setState({ 27 | path, 28 | latency: `~${Math.round(Date.now() - start)}ms`, 29 | status: `${res.status}`, 30 | headers: { 31 | "X-RateLimit-Limit": res.headers.get("X-RateLimit-Limit"), 32 | "X-RateLimit-Remaining": res.headers.get("x-RateLimit-Remaining"), 33 | "X-RateLimit-Reset": res.headers.get("x-RateLimit-Reset"), 34 | }, 35 | data: res.headers.get("Content-Type")?.includes("application/json") 36 | ? 
await res.json() 37 | : null, 38 | }); 39 | } finally { 40 | setLoading(false); 41 | } 42 | }; 43 | 44 | return ( 45 |
46 | 49 |
54 |         {JSON.stringify(state, null, 2)}
55 |       
56 |
57 | ); 58 | }; 59 | 60 | export default Headers; 61 | -------------------------------------------------------------------------------- /apis/vercel/next-env.d.ts: -------------------------------------------------------------------------------- 1 | /// 2 | /// 3 | /// 4 | 5 | // NOTE: This file should not be edited 6 | // see https://nextjs.org/docs/basic-features/typescript for more information. 7 | -------------------------------------------------------------------------------- /apis/vercel/next.config.js: -------------------------------------------------------------------------------- 1 | /** @type {import('next').NextConfig} */ 2 | const nextConfig = { 3 | poweredByHeader: false, 4 | transpilePackages: ["ai-proxy"], 5 | typescript: { 6 | ignoreBuildErrors: true, 7 | }, 8 | eslint: { 9 | ignoreDuringBuilds: true, 10 | }, 11 | }; 12 | 13 | module.exports = nextConfig; 14 | -------------------------------------------------------------------------------- /apis/vercel/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "braintrust-proxy", 3 | "repository": "https://github.com/braintrustdata/braintrust-proxy", 4 | "license": "MIT", 5 | "private": true, 6 | "scripts": { 7 | "dev": "next dev", 8 | "build": "next build", 9 | "start": "next start", 10 | "lint": "next lint" 11 | }, 12 | "dependencies": { 13 | "@upstash/ratelimit": "^0.4.3", 14 | "@vercel/examples-ui": "^1.0.5", 15 | "@vercel/kv": "^0.2.2", 16 | "@braintrust/proxy": "workspace:*", 17 | "next": "14.2.3", 18 | "react": "latest", 19 | "react-dom": "latest" 20 | }, 21 | "devDependencies": { 22 | "@types/node": "^17.0.45", 23 | "@types/react": "latest", 24 | "autoprefixer": "^10.4.14", 25 | "eslint": "^8.36.0", 26 | "eslint-config-next": "canary", 27 | "postcss": "^8.4.21", 28 | "tailwindcss": "^3.2.7", 29 | "turbo": "^1.8.5", 30 | "typescript": "4.7.4" 31 | } 32 | } 33 | -------------------------------------------------------------------------------- 
/apis/vercel/pages/_app.tsx: -------------------------------------------------------------------------------- 1 | import type { AppProps } from "next/app"; 2 | import type { LayoutProps } from "@vercel/examples-ui/layout"; 3 | import { getLayout } from "@vercel/examples-ui"; 4 | import "@vercel/examples-ui/globals.css"; 5 | 6 | export default function MyApp({ Component, pageProps }: AppProps) { 7 | const Layout = getLayout(Component); 8 | 9 | return ( 10 | 18 | 19 | 20 | ); 21 | } 22 | -------------------------------------------------------------------------------- /apis/vercel/pages/api/ping.ts: -------------------------------------------------------------------------------- 1 | import type { NextRequest } from "next/server"; 2 | import { Ratelimit } from "@upstash/ratelimit"; 3 | import { kv } from "@vercel/kv"; 4 | 5 | const ratelimit = new Ratelimit({ 6 | redis: kv, 7 | // 5 requests from the same IP in 10 seconds 8 | limiter: Ratelimit.slidingWindow(1000, "10 s"), 9 | }); 10 | 11 | export const config = { 12 | runtime: "edge", 13 | }; 14 | 15 | let i = 0; 16 | export default async function handler(request: NextRequest) { 17 | // You could alternatively limit based on user ID or similar 18 | const ip = request.ip ?? 
"127.0.0.1"; 19 | /* 20 | let start = Date.now(); 21 | const { limit, reset, remaining } = await ratelimit.limit(ip); 22 | let end = Date.now(); 23 | console.log("Rate limit KV latency (ms):", end - start); 24 | */ 25 | await kv.set("foo", `${i}`); 26 | i += 1; 27 | 28 | let start = Date.now(); 29 | const foo = await kv.get("foo"); 30 | let end = Date.now(); 31 | console.log("Get ", foo, " KV latency (ms):", end - start); 32 | 33 | return new Response(JSON.stringify({ success: true }), { 34 | status: 200, 35 | headers: { 36 | "X-RateLimit-Limit": "0", 37 | "X-RateLimit-Remaining": "0", 38 | "X-RateLimit-Reset": "0", 39 | }, 40 | }); 41 | } 42 | -------------------------------------------------------------------------------- /apis/vercel/pages/api/v1/[...slug].ts: -------------------------------------------------------------------------------- 1 | import { kv } from "@vercel/kv"; 2 | import { EdgeProxyV1, CacheSetOptions } from "@braintrust/proxy/edge"; 3 | 4 | export const config = { 5 | runtime: "edge", 6 | }; 7 | 8 | const KVCache = { 9 | get: kv.get, 10 | set: async (key: string, value: T, opts: CacheSetOptions) => { 11 | await kv.set( 12 | key, 13 | value, 14 | opts.ttl !== undefined 15 | ? 
{ 16 | ex: opts.ttl, 17 | } 18 | : {}, 19 | ); 20 | }, 21 | }; 22 | 23 | export default EdgeProxyV1({ 24 | getRelativeURL: (request) => { 25 | const url = new URL(request.url); 26 | const params = url.searchParams.getAll("slug"); 27 | return "/" + params.join("/"); 28 | }, 29 | cors: true, 30 | credentialsCache: KVCache, 31 | completionsCache: KVCache, 32 | braintrustApiUrl: process.env.BRAINTRUST_APP_URL, 33 | }); 34 | -------------------------------------------------------------------------------- /apis/vercel/pages/index.tsx: -------------------------------------------------------------------------------- 1 | import { Layout, Page, Text, Link } from "@vercel/examples-ui"; 2 | import Headers from "@components/headers"; 3 | 4 | export default function Index() { 5 | return ( 6 | 7 | 8 | API Rate Limiting with Vercel KV 9 | 10 | 11 | By using Redis with Vercel KV, we can keep a counter of requests by IP 12 | address. 13 | 14 | 15 | For the demo below, you can send a maximum of{" "} 16 | 5 requests every 10 seconds. 17 | 18 | Make a request 19 | 20 | The pattern we're using in this example is inspired by the{" "} 21 | 26 | GitHub API 27 | 28 | . 
29 | 30 | 31 | ); 32 | } 33 | 34 | Index.Layout = Layout; 35 | -------------------------------------------------------------------------------- /apis/vercel/postcss.config.js: -------------------------------------------------------------------------------- 1 | // If you want to use other PostCSS plugins, see the following: 2 | // https://tailwindcss.com/docs/using-with-preprocessors 3 | module.exports = { 4 | plugins: { 5 | tailwindcss: {}, 6 | autoprefixer: {}, 7 | }, 8 | }; 9 | -------------------------------------------------------------------------------- /apis/vercel/public/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braintrustdata/braintrust-proxy/fd1d7f3821e36c3161bbb9d8f836f46fd4e7ba2b/apis/vercel/public/favicon.ico -------------------------------------------------------------------------------- /apis/vercel/tailwind.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | presets: [require("@vercel/examples-ui/tailwind")], 3 | content: [ 4 | "./pages/**/*.{js,ts,jsx,tsx}", 5 | "./components/**/*.{js,ts,jsx,tsx}", 6 | "./node_modules/@vercel/examples-ui/**/*.js", 7 | ], 8 | }; 9 | -------------------------------------------------------------------------------- /apis/vercel/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "baseUrl": ".", 4 | "target": "es2015", 5 | "lib": [ 6 | "dom", 7 | "dom.iterable", 8 | "esnext" 9 | ], 10 | "allowJs": true, 11 | "skipLibCheck": true, 12 | "strict": true, 13 | "forceConsistentCasingInFileNames": true, 14 | "noEmit": true, 15 | "esModuleInterop": true, 16 | "module": "esnext", 17 | "moduleResolution": "node", 18 | "resolveJsonModule": true, 19 | "isolatedModules": true, 20 | "jsx": "preserve", 21 | "paths": { 22 | "@lib/*": [ 23 | "lib/*" 24 | ], 25 | "@components": [ 26 | "components/index" 27 | 
], 28 | "@components/*": [ 29 | "components/*" 30 | ], 31 | "#/*": [ 32 | "./*" 33 | ] 34 | }, 35 | "incremental": true, 36 | "plugins": [ 37 | { 38 | "name": "next" 39 | } 40 | ] 41 | }, 42 | "include": [ 43 | "next-env.d.ts", 44 | "**/*.ts", 45 | "**/*.tsx", 46 | ".next/types/**/*.ts" 47 | ], 48 | "exclude": [ 49 | "node_modules" 50 | ] 51 | } 52 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "repository": "https://github.com/vercel/examples.git", 3 | "license": "MIT", 4 | "private": true, 5 | "workspaces": [ 6 | "apis/*", 7 | "packages/*" 8 | ], 9 | "scripts": { 10 | "build": "turbo run build", 11 | "dev": "turbo run dev", 12 | "start": "turbo run start", 13 | "lint": "turbo run lint", 14 | "clean": "turbo run clean", 15 | "test": "vitest run" 16 | }, 17 | "devDependencies": { 18 | "eslint": "^8.56.0", 19 | "eslint-config-turbo": "latest", 20 | "turbo": "^2.3.3", 21 | "vite-tsconfig-paths": "^4.3.2", 22 | "vitest": "^2.1.9" 23 | }, 24 | "packageManager": "pnpm@8.15.5", 25 | "resolutions": { 26 | "zod": "3.22.4" 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /packages/proxy/.eslintrc.cjs: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | extends: ["../../.eslintrc.json"], 3 | parser: "@typescript-eslint/parser", 4 | rules: { 5 | "@typescript-eslint/no-floating-promises": "error", 6 | }, 7 | plugins: ["@typescript-eslint"], 8 | parserOptions: { 9 | project: "./tsconfig.json", 10 | tsconfigRootDir: __dirname, 11 | }, 12 | // This is necessary because we're asking eslint to parse the files in this package, 13 | // and its tsconfig.json says to ignore these. 
14 | ignorePatterns: ["**/*.test.ts", "**/dist/**", "vitest.config.js"], 15 | }; 16 | -------------------------------------------------------------------------------- /packages/proxy/.gitignore: -------------------------------------------------------------------------------- 1 | dist 2 | -------------------------------------------------------------------------------- /packages/proxy/edge/deps.test.ts: -------------------------------------------------------------------------------- 1 | import skott from "skott"; 2 | import { describe, expect, it } from "vitest"; 3 | 4 | describe("proxy/edge", () => { 5 | it("no circ dependencies", async () => { 6 | const { useGraph } = await skott({ 7 | entrypoint: `${__dirname}/index.ts`, 8 | tsConfigPath: `${__dirname}/../tsconfig.json`, 9 | dependencyTracking: { 10 | builtin: false, 11 | thirdParty: true, 12 | typeOnly: true, 13 | }, 14 | }); 15 | 16 | const { findCircularDependencies } = useGraph(); 17 | 18 | expect(findCircularDependencies()).toEqual([]); 19 | }); 20 | }); 21 | -------------------------------------------------------------------------------- /packages/proxy/edge/exporter.ts: -------------------------------------------------------------------------------- 1 | import { diag } from "@opentelemetry/api"; 2 | import { 3 | Aggregation, 4 | AggregationTemporality, 5 | MetricReader, 6 | } from "@opentelemetry/sdk-metrics"; 7 | 8 | export class FlushingExporter extends MetricReader { 9 | /** 10 | * Constructor 11 | * @param config Exporter configuration 12 | * @param callback Callback to be called after a server was started 13 | */ 14 | constructor(private flushFn: (resourceMetrics: any) => Promise) { 15 | super({ 16 | aggregationSelector: (_instrumentType) => Aggregation.Default(), 17 | aggregationTemporalitySelector: (_instrumentType) => 18 | AggregationTemporality.CUMULATIVE, 19 | }); 20 | } 21 | 22 | override async onForceFlush(): Promise { 23 | // This is the main entry point, since the exporter is called by the SDK 
24 | const { resourceMetrics, errors } = await this.collect(); 25 | if (errors.length > 0) { 26 | for (const error of errors) { 27 | diag.error("Error while exporting metrics", error); 28 | } 29 | } 30 | const resp = await this.flushFn(resourceMetrics); 31 | 32 | if (!resp.ok) { 33 | const error = Error( 34 | `Error while flushing metrics: ${resp.status} (${ 35 | resp.statusText 36 | }): ${await resp.text()}`, 37 | ); 38 | console.log("Error while flushing metrics", error); 39 | throw error; 40 | } 41 | } 42 | 43 | /** 44 | * Shuts down the export server and clears the registry 45 | */ 46 | override async onShutdown(): Promise { 47 | // do nothing 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /packages/proxy/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "@braintrust/proxy", 3 | "version": "0.0.9", 4 | "description": "A proxy server that load balances across AI providers.", 5 | "main": "./dist/index.js", 6 | "module": "./dist/index.mjs", 7 | "types": "./dist/index.d.ts", 8 | "scripts": { 9 | "build": "tsup", 10 | "watch": "tsup --watch", 11 | "clean": "rm -r dist/*", 12 | "test": "vitest run" 13 | }, 14 | "exports": { 15 | "./package.json": "./package.json", 16 | ".": { 17 | "types": "./dist/index.d.ts", 18 | "import": "./dist/index.mjs", 19 | "module": "./dist/index.mjs", 20 | "require": "./dist/index.js" 21 | }, 22 | "./edge": { 23 | "types": "./edge/dist/index.d.ts", 24 | "import": "./edge/dist/index.mjs", 25 | "module": "./edge/dist/index.mjs", 26 | "require": "./edge/dist/index.js" 27 | }, 28 | "./schema": { 29 | "types": "./schema/dist/index.d.ts", 30 | "import": "./schema/dist/index.mjs", 31 | "module": "./schema/dist/index.mjs", 32 | "require": "./schema/dist/index.js" 33 | }, 34 | "./utils": { 35 | "types": "./utils/dist/index.d.ts", 36 | "import": "./utils/dist/index.mjs", 37 | "module": "./utils/dist/index.mjs", 38 | "require": 
"./utils/dist/index.js" 39 | }, 40 | "./types": { 41 | "types": "./types/dist/index.d.ts", 42 | "import": "./types/dist/index.mjs", 43 | "module": "./types/dist/index.mjs", 44 | "require": "./types/dist/index.js" 45 | } 46 | }, 47 | "files": [ 48 | "dist/**/*", 49 | "edge/dist/**/*", 50 | "schema/dist/**/*", 51 | "types/dist/**/*" 52 | ], 53 | "license": "MIT", 54 | "publishConfig": { 55 | "access": "public" 56 | }, 57 | "homepage": "https://www.braintrust.dev/docs/guides/proxy", 58 | "repository": { 59 | "type": "git", 60 | "url": "git+https://github.com/braintrustdata/braintrust-proxy.git" 61 | }, 62 | "bugs": { 63 | "url": "https://github.com/braintrustdata/braintrust-proxy/issues" 64 | }, 65 | "keywords": [ 66 | "ai", 67 | "proxy", 68 | "vercel", 69 | "cloudflare", 70 | "workers", 71 | "edge", 72 | "openai", 73 | "lambda", 74 | "express" 75 | ], 76 | "devDependencies": { 77 | "@types/content-disposition": "^0.5.8", 78 | "@types/jsonwebtoken": "^9.0.7", 79 | "@types/node": "^20.10.5", 80 | "@types/uuid": "^9.0.7", 81 | "@types/yargs": "^17.0.33", 82 | "@typescript-eslint/eslint-plugin": "^8.21.0", 83 | "esbuild": "^0.19.10", 84 | "msw": "^2.8.2", 85 | "npm-run-all": "^4.1.5", 86 | "skott": "^0.35.4", 87 | "tsup": "^8.4.0", 88 | "typescript": "5.5.4", 89 | "vite-tsconfig-paths": "^4.3.2", 90 | "vitest": "^2.1.9", 91 | "yargs": "^17.7.2" 92 | }, 93 | "dependencies": { 94 | "@anthropic-ai/sdk": "^0.39.0", 95 | "@apidevtools/json-schema-ref-parser": "^11.9.1", 96 | "@aws-sdk/client-bedrock-runtime": "^3.806.0", 97 | "@braintrust/core": "^0.0.87", 98 | "@breezystack/lamejs": "^1.2.7", 99 | "@google/genai": "^0.13.0", 100 | "@opentelemetry/api": "^1.7.0", 101 | "@opentelemetry/core": "^1.19.0", 102 | "@opentelemetry/resources": "^1.19.0", 103 | "@opentelemetry/sdk-metrics": "^1.19.0", 104 | "ai": "2.2.37", 105 | "cache-control-parser": "^2.0.6", 106 | "content-disposition": "^0.5.4", 107 | "date-fns": "^4.1.0", 108 | "eventsource-parser": "^1.1.1", 109 | "jose": 
"^5.9.6", 110 | "jsonwebtoken": "^9.0.2", 111 | "openai": "4.89.0", 112 | "uuid": "^9.0.1", 113 | "zod": "^3.22.4" 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /packages/proxy/schema/audio.ts: -------------------------------------------------------------------------------- 1 | import { z } from "zod"; 2 | 3 | export const mp3BitrateSchema = z.union([ 4 | z.literal(8), 5 | z.literal(16), 6 | z.literal(24), 7 | z.literal(32), 8 | z.literal(40), 9 | z.literal(48), 10 | z.literal(64), 11 | z.literal(80), 12 | z.literal(96), 13 | z.literal(112), 14 | z.literal(128), 15 | z.literal(160), 16 | z.literal(192), 17 | z.literal(224), 18 | z.literal(256), 19 | z.literal(320), 20 | ]); 21 | 22 | export type Mp3Bitrate = z.infer; 23 | 24 | export const pcmAudioFormatSchema = z 25 | .discriminatedUnion("name", [ 26 | z.object({ 27 | name: z.literal("pcm"), 28 | byte_order: z.enum(["little", "big"]).default("little"), 29 | number_encoding: z.enum(["int", "float"]).default("int"), 30 | bits_per_sample: z.union([z.literal(8), z.literal(16)]), 31 | }), 32 | z.object({ 33 | name: z.literal("g711"), 34 | algorithm: z.enum(["a", "mu"]), 35 | }), 36 | ]) 37 | .and( 38 | z.object({ 39 | // Common codec parameters. 
40 | channels: z.literal(1).or(z.literal(2)), 41 | sample_rate: z.union([ 42 | z.literal(8000), 43 | z.literal(11025), 44 | z.literal(12000), 45 | z.literal(16000), 46 | z.literal(22050), 47 | z.literal(24000), 48 | z.literal(32000), 49 | z.literal(44100), 50 | z.literal(48000), 51 | ]), 52 | }), 53 | ); 54 | 55 | export type PcmAudioFormat = z.infer; 56 | -------------------------------------------------------------------------------- /packages/proxy/schema/deps.test.ts: -------------------------------------------------------------------------------- 1 | import skott from "skott"; 2 | import { describe, expect, it } from "vitest"; 3 | 4 | describe("proxy/schema", () => { 5 | it("no circ dependencies", async () => { 6 | const { useGraph } = await skott({ 7 | entrypoint: `${__dirname}/index.ts`, 8 | tsConfigPath: `${__dirname}/../tsconfig.json`, 9 | dependencyTracking: { 10 | builtin: false, 11 | thirdParty: true, 12 | typeOnly: true, 13 | }, 14 | }); 15 | 16 | const { findCircularDependencies } = useGraph(); 17 | 18 | expect(findCircularDependencies()).toEqual([]); 19 | }); 20 | }); 21 | -------------------------------------------------------------------------------- /packages/proxy/schema/index.test.ts: -------------------------------------------------------------------------------- 1 | import { MessageCreateParamsBase } from "@anthropic-ai/sdk/resources/messages"; 2 | import { GenerateContentParameters } from "@google/genai"; 3 | import { ChatCompletionCreateParams } from "openai/resources"; 4 | import { expect, it } from "vitest"; 5 | import { ModelFormat, translateParams } from "./index"; 6 | 7 | const examples: Record< 8 | string, 9 | { 10 | openai: ChatCompletionCreateParams; 11 | } & ( // NOTE: these are not strictly the API params. 
12 | | { google: GenerateContentParameters } 13 | | { anthropic: MessageCreateParamsBase } 14 | ) 15 | > = { 16 | simple: { 17 | openai: { 18 | model: "gpt-4o", 19 | max_tokens: 1500, 20 | temperature: 0.7, 21 | top_p: 0.9, 22 | frequency_penalty: 0.1, 23 | presence_penalty: 0.2, 24 | messages: [ 25 | { role: "system", content: "You are a helpful assistant." }, 26 | { role: "user", content: "Hello, how are you?" }, 27 | ], 28 | stream: true, 29 | }, 30 | google: { 31 | maxOutputTokens: 1500, 32 | max_tokens: 1500, 33 | messages: [ 34 | { 35 | content: "You are a helpful assistant.", 36 | role: "system", 37 | }, 38 | { 39 | content: "Hello, how are you?", 40 | role: "user", 41 | }, 42 | ], 43 | model: "gpt-4o", 44 | stream: true, 45 | temperature: 0.7, 46 | top_p: 0.9, 47 | }, 48 | anthropic: { 49 | max_tokens: 1500, 50 | messages: [ 51 | { 52 | content: "You are a helpful assistant.", 53 | // @ts-expect-error -- TODO: shouldn't we have translated this to a non system role? 54 | role: "system", 55 | }, 56 | { 57 | content: "Hello, how are you?", 58 | role: "user", 59 | }, 60 | ], 61 | model: "gpt-4o", 62 | stream: true, 63 | temperature: 0.7, 64 | top_p: 0.9, 65 | }, 66 | }, 67 | reasoning_effort: { 68 | openai: { 69 | model: "gpt-4o", 70 | messages: [ 71 | { 72 | role: "system", 73 | content: "You are a detailed reasoning assistant.", 74 | }, 75 | { 76 | role: "user", 77 | content: "Explain how to solve 2x + 4 = 12 step by step.", 78 | }, 79 | ], 80 | temperature: 0, 81 | max_tokens: 1000, 82 | reasoning_effort: "high", 83 | stream: false, 84 | }, 85 | google: { 86 | model: "gpt-4o", 87 | // notice how this is still an intermediate param 88 | // google's api expects a content instead of messages, for example 89 | messages: [ 90 | { 91 | role: "system", 92 | content: "You are a detailed reasoning assistant.", 93 | }, 94 | { 95 | role: "user", 96 | content: "Explain how to solve 2x + 4 = 12 step by step.", 97 | }, 98 | ], 99 | temperature: 0, 100 | thinkingConfig: { 
101 | thinkingBudget: 800, 102 | includeThoughts: true, 103 | }, 104 | maxOutputTokens: 1000, 105 | max_tokens: 1000, 106 | stream: false, 107 | }, 108 | anthropic: { 109 | model: "gpt-4o", 110 | messages: [ 111 | { 112 | // @ts-expect-error -- we use the role to later manipulate the request 113 | role: "system", 114 | content: "You are a detailed reasoning assistant.", 115 | }, 116 | { 117 | role: "user", 118 | content: "Explain how to solve 2x + 4 = 12 step by step.", 119 | }, 120 | ], 121 | temperature: 1, 122 | stream: false, 123 | max_tokens: 1536, 124 | thinking: { 125 | budget_tokens: 1024, 126 | type: "enabled", 127 | }, 128 | }, 129 | }, 130 | "reasoning disable": { 131 | openai: { 132 | model: "gpt-4o", 133 | messages: [ 134 | { 135 | role: "system", 136 | content: "You are a detailed reasoning assistant.", 137 | }, 138 | { 139 | role: "user", 140 | content: "Explain how to solve 2x + 4 = 12 step by step.", 141 | }, 142 | ], 143 | temperature: 0, 144 | reasoning_enabled: false, 145 | reasoning_budget: 1024, 146 | stream: false, 147 | }, 148 | google: { 149 | model: "gpt-4o", 150 | // notice how this is still an intermediate param 151 | // google's api expects a content instead of messages, for example 152 | messages: [ 153 | { 154 | role: "system", 155 | content: "You are a detailed reasoning assistant.", 156 | }, 157 | { 158 | role: "user", 159 | content: "Explain how to solve 2x + 4 = 12 step by step.", 160 | }, 161 | ], 162 | temperature: 0, 163 | thinkingConfig: { 164 | thinkingBudget: 0, 165 | }, 166 | stream: false, 167 | }, 168 | anthropic: { 169 | model: "gpt-4o", 170 | messages: [ 171 | { 172 | // @ts-expect-error -- we use the role to later manipulate the request 173 | role: "system", 174 | content: "You are a detailed reasoning assistant.", 175 | }, 176 | { 177 | role: "user", 178 | content: "Explain how to solve 2x + 4 = 12 step by step.", 179 | }, 180 | ], 181 | temperature: 0, 182 | stream: false, 183 | max_tokens: 1024, 184 | thinking: { 
185 | type: "disabled", 186 | }, 187 | }, 188 | }, 189 | "reasoning budget": { 190 | openai: { 191 | model: "gpt-4o", 192 | messages: [ 193 | { 194 | role: "system", 195 | content: "You are a detailed reasoning assistant.", 196 | }, 197 | { 198 | role: "user", 199 | content: "Explain how to solve 2x + 4 = 12 step by step.", 200 | }, 201 | ], 202 | temperature: 0, 203 | reasoning_enabled: true, 204 | reasoning_budget: 1024, 205 | stream: false, 206 | }, 207 | google: { 208 | model: "gpt-4o", 209 | // notice how this is still an intermediate param 210 | // google's api expects a content instead of messages, for example 211 | messages: [ 212 | { 213 | role: "system", 214 | content: "You are a detailed reasoning assistant.", 215 | }, 216 | { 217 | role: "user", 218 | content: "Explain how to solve 2x + 4 = 12 step by step.", 219 | }, 220 | ], 221 | temperature: 0, 222 | thinkingConfig: { 223 | thinkingBudget: 1024, 224 | includeThoughts: true, 225 | }, 226 | stream: false, 227 | }, 228 | anthropic: { 229 | model: "gpt-4o", 230 | messages: [ 231 | { 232 | // @ts-expect-error -- we use the role to later manipulate the request 233 | role: "system", 234 | content: "You are a detailed reasoning assistant.", 235 | }, 236 | { 237 | role: "user", 238 | content: "Explain how to solve 2x + 4 = 12 step by step.", 239 | }, 240 | ], 241 | temperature: 1, 242 | stream: false, 243 | max_tokens: 1536, 244 | thinking: { 245 | budget_tokens: 1024, 246 | type: "enabled", 247 | }, 248 | }, 249 | }, 250 | }; 251 | 252 | Object.entries(examples).forEach(([example, { openai, ...providers }]) => { 253 | Object.entries(providers).forEach(([provider, expected]) => { 254 | it(`[${example}] translate openai to ${provider} params`, () => { 255 | const result = translateParams( 256 | provider as ModelFormat, 257 | openai as unknown as Record, 258 | ); 259 | try { 260 | expect(result).toEqual(expected); 261 | } catch (error) { 262 | console.warn( 263 | `Exact openai -> ${provider} translation 
failed. Found:`, 264 | JSON.stringify(result, null, 2), 265 | ); 266 | expect.soft(result).toEqual(expected); 267 | } 268 | }); 269 | }); 270 | }); 271 | -------------------------------------------------------------------------------- /packages/proxy/schema/models.test.ts: -------------------------------------------------------------------------------- 1 | import { expect } from "vitest"; 2 | import { it } from "vitest"; 3 | import raw_models from "./model_list.json"; 4 | import { ModelSchema } from "./models"; 5 | import { z } from "zod"; 6 | 7 | it("parse model list", () => { 8 | const models = z.record(z.unknown()).parse(raw_models); 9 | for (const [key, value] of Object.entries(models)) { 10 | const result = ModelSchema.safeParse(value); 11 | if (!result.success) { 12 | console.log("failed to parse ", key, result.error); 13 | } 14 | expect(result.success).toBe(true); 15 | } 16 | }); 17 | -------------------------------------------------------------------------------- /packages/proxy/schema/models.ts: -------------------------------------------------------------------------------- 1 | import { z } from "zod"; 2 | 3 | export const PromptInputs = ["chat", "completion"] as const; 4 | export type PromptInputType = (typeof PromptInputs)[number]; 5 | 6 | export const ModelFormats = [ 7 | "openai", 8 | "anthropic", 9 | "google", 10 | "window", 11 | "js", 12 | "converse", 13 | ] as const; 14 | export type ModelFormat = (typeof ModelFormats)[number]; 15 | 16 | export const ModelEndpointType = [ 17 | "openai", 18 | "anthropic", 19 | "google", 20 | "mistral", 21 | "bedrock", 22 | "vertex", 23 | "together", 24 | "fireworks", 25 | "perplexity", 26 | "xAI", 27 | "groq", 28 | "azure", 29 | "databricks", 30 | "lepton", 31 | "cerebras", 32 | "ollama", 33 | "replicate", 34 | "js", 35 | ] as const; 36 | export type ModelEndpointType = (typeof ModelEndpointType)[number]; 37 | 38 | export const ModelSchema = z.object({ 39 | format: z.enum(ModelFormats), 40 | flavor: 
z.enum(PromptInputs), 41 | multimodal: z.boolean().nullish(), 42 | input_cost_per_token: z.number().nullish(), 43 | output_cost_per_token: z.number().nullish(), 44 | input_cost_per_mil_tokens: z.number().nullish(), 45 | output_cost_per_mil_tokens: z.number().nullish(), 46 | input_cache_read_cost_per_mil_tokens: z.number().nullish(), 47 | input_cache_write_cost_per_mil_tokens: z.number().nullish(), 48 | displayName: z 49 | .string() 50 | .nullish() 51 | .describe("The model is the latest production/stable"), 52 | o1_like: z.boolean().nullish().describe('DEPRECATED use "reasoning" instead'), 53 | reasoning: z 54 | .boolean() 55 | .nullish() 56 | .describe("The model supports reasoning/thinking tokens"), 57 | reasoning_budget: z 58 | .boolean() 59 | .nullish() 60 | .describe("The model supports reasoning/thinking budgets"), 61 | experimental: z 62 | .boolean() 63 | .nullish() 64 | .describe("The model is not allowed production load or API is unstable."), 65 | deprecated: z 66 | .boolean() 67 | .nullish() 68 | .describe( 69 | "Discourage the use of the model (we will hide the model in the UI).", 70 | ), 71 | parent: z.string().nullish().describe("The model was replaced this model."), 72 | endpoint_types: z.array(z.enum(ModelEndpointType)).nullish(), 73 | locations: z.array(z.string()).nullish(), 74 | description: z.string().nullish(), 75 | }); 76 | 77 | export type ModelSpec = z.infer; 78 | 79 | import models from "./model_list.json"; 80 | export const AvailableModels = models as { [name: string]: ModelSpec }; 81 | -------------------------------------------------------------------------------- /packages/proxy/schema/secrets.ts: -------------------------------------------------------------------------------- 1 | import { z } from "zod"; 2 | import { ModelSchema } from "./models"; 3 | 4 | export const BaseMetadataSchema = z 5 | .object({ 6 | models: z.array(z.string()).nullish(), 7 | customModels: z.record(ModelSchema).nullish(), 8 | excludeDefaultModels: 
z.boolean().nullish(), 9 | additionalHeaders: z.record(z.string(), z.string()).nullish(), 10 | supportsStreaming: z.boolean().default(true), 11 | }) 12 | .strict(); 13 | 14 | export const AzureMetadataSchema = BaseMetadataSchema.merge( 15 | z.object({ 16 | api_base: z.string().url(), 17 | api_version: z.string().default("2023-07-01-preview"), 18 | deployment: z.string().nullish(), 19 | auth_type: z.enum(["api_key", "entra_api"]).default("api_key"), 20 | no_named_deployment: z 21 | .boolean() 22 | .default(false) 23 | .describe( 24 | "If true, the deployment name will not be used in the request path.", 25 | ), 26 | }), 27 | ).strict(); 28 | 29 | export const AzureEntraSecretSchema = z.object({ 30 | client_id: z.string().min(1, "Client ID cannot be empty"), 31 | client_secret: z.string().min(1, "Client secret cannot be empty"), 32 | tenant_id: z.string().min(1, "Tenant ID cannot be empty"), 33 | scope: z.string().min(1, "Scope cannot be empty"), 34 | }); 35 | export type AzureEntraSecret = z.infer; 36 | 37 | export const BedrockMetadataSchema = BaseMetadataSchema.merge( 38 | z.object({ 39 | region: z.string().min(1, "Region cannot be empty"), 40 | access_key: z.string().min(1, "Access key cannot be empty"), 41 | session_token: z.string().nullish(), 42 | }), 43 | ).strict(); 44 | export type BedrockMetadata = z.infer; 45 | 46 | export const VertexMetadataSchema = BaseMetadataSchema.merge( 47 | z.object({ 48 | project: z.string().min(1, "Project cannot be empty"), 49 | authType: z.enum(["access_token", "service_account_key"]), 50 | api_base: z.union([z.string().url(), z.string().length(0)]).nullish(), 51 | }), 52 | ).strict(); 53 | 54 | export const DatabricksMetadataSchema = BaseMetadataSchema.merge( 55 | z.object({ 56 | api_base: z.string().url(), 57 | auth_type: z.enum(["pat", "service_principal_oauth"]).default("pat"), 58 | }), 59 | ).strict(); 60 | 61 | export const DatabricksOAuthSecretSchema = z.object({ 62 | client_id: z.string().min(1, "Client ID cannot be 
empty"), 63 | client_secret: z.string().min(1, "Client secret cannot be empty"), 64 | }); 65 | export type DatabricksOAuthSecret = z.infer; 66 | 67 | export const OpenAIMetadataSchema = BaseMetadataSchema.merge( 68 | z.object({ 69 | api_base: z.union([ 70 | z.string().url().optional(), 71 | z.string().length(0), 72 | z.null(), 73 | ]), 74 | organization_id: z.string().nullish(), 75 | }), 76 | ).strict(); 77 | 78 | export const MistralMetadataSchema = BaseMetadataSchema.merge( 79 | z.object({ 80 | api_base: z.union([z.string().url(), z.string().length(0)]).nullish(), 81 | }), 82 | ).strict(); 83 | 84 | const APISecretBaseSchema = z 85 | .object({ 86 | id: z.string().uuid().nullish(), 87 | org_name: z.string().nullish(), 88 | name: z.string().nullish(), 89 | secret: z.string(), 90 | metadata: z.record(z.unknown()).nullish(), 91 | }) 92 | .strict(); 93 | 94 | export const APISecretSchema = z.union([ 95 | APISecretBaseSchema.merge( 96 | z.object({ 97 | type: z.enum([ 98 | "perplexity", 99 | "anthropic", 100 | "google", 101 | "replicate", 102 | "together", 103 | "ollama", 104 | "groq", 105 | "lepton", 106 | "fireworks", 107 | "cerebras", 108 | "xAI", 109 | "js", 110 | ]), 111 | metadata: BaseMetadataSchema.nullish(), 112 | }), 113 | ), 114 | APISecretBaseSchema.merge( 115 | z.object({ 116 | type: z.literal("openai"), 117 | metadata: OpenAIMetadataSchema.nullish(), 118 | }), 119 | ), 120 | APISecretBaseSchema.merge( 121 | z.object({ 122 | type: z.literal("azure"), 123 | metadata: AzureMetadataSchema.nullish(), 124 | }), 125 | ), 126 | APISecretBaseSchema.merge( 127 | z.object({ 128 | type: z.literal("bedrock"), 129 | metadata: BedrockMetadataSchema.nullish(), 130 | }), 131 | ), 132 | APISecretBaseSchema.merge( 133 | z.object({ 134 | type: z.literal("vertex"), 135 | metadata: VertexMetadataSchema.nullish(), 136 | }), 137 | ), 138 | APISecretBaseSchema.merge( 139 | z.object({ 140 | type: z.literal("databricks"), 141 | metadata: DatabricksMetadataSchema.nullish(), 142 | }), 
143 | ), 144 | APISecretBaseSchema.merge( 145 | z.object({ 146 | type: z.literal("mistral"), 147 | metadata: MistralMetadataSchema.nullish(), 148 | }), 149 | ), 150 | ]); 151 | 152 | export type APISecret = z.infer; 153 | 154 | export const proxyLoggingParamSchema = z 155 | .object({ 156 | project_name: z.string(), 157 | compress_audio: z.boolean().default(true), 158 | }) 159 | .describe( 160 | "If present, proxy will log requests to the given Braintrust project name.", 161 | ); 162 | 163 | export type ProxyLoggingParam = z.infer; 164 | 165 | export const credentialsRequestSchema = z 166 | .object({ 167 | model: z 168 | .string() 169 | .nullish() 170 | .describe( 171 | "Granted model name. Null/undefined to grant usage of all models.", 172 | ), 173 | ttl_seconds: z 174 | .number() 175 | .max(60 * 60 * 24) 176 | .default(60 * 10) 177 | .describe("TTL of the temporary credential. 10 minutes by default."), 178 | logging: proxyLoggingParamSchema.nullish(), 179 | }) 180 | .describe("Payload for requesting temporary credentials."); 181 | export type CredentialsRequest = z.infer; 182 | 183 | export const tempCredentialsCacheValueSchema = z 184 | .object({ 185 | authToken: z.string().describe("Braintrust API key."), 186 | }) 187 | .describe("Schema for the proxy's internal credential cache."); 188 | export type TempCredentialsCacheValue = z.infer< 189 | typeof tempCredentialsCacheValueSchema 190 | >; 191 | 192 | export const tempCredentialJwtPayloadSchema = z 193 | .object({ 194 | iss: z.literal("braintrust_proxy"), 195 | aud: z.literal("braintrust_proxy"), 196 | jti: z 197 | .string() 198 | .min(1) 199 | .describe("JWT ID, a unique identifier for this token."), 200 | exp: z.number().describe("Standard JWT expiration field."), 201 | iat: z.number().describe("Standard JWT issued-at field"), 202 | bt: z 203 | .object({ 204 | org_name: z.string().nullish(), 205 | model: z.string().nullish(), 206 | secret: z.string().min(1), 207 | logging: proxyLoggingParamSchema.nullish(), 
208 | }) 209 | .describe("Braintrust-specific grants. See credentialsRequestSchema."), 210 | }) 211 | .describe("Braintrust Proxy JWT payload."); 212 | export type TempCredentialJwtPayload = z.infer< 213 | typeof tempCredentialJwtPayloadSchema 214 | >; 215 | -------------------------------------------------------------------------------- /packages/proxy/src/constants.ts: -------------------------------------------------------------------------------- 1 | export const DEFAULT_BRAINTRUST_APP_URL = "https://www.braintrust.dev"; 2 | -------------------------------------------------------------------------------- /packages/proxy/src/deps.test.ts: -------------------------------------------------------------------------------- 1 | import skott from "skott"; 2 | import { describe, expect, it } from "vitest"; 3 | 4 | describe("proxy/src", () => { 5 | it("no circ dependencies", async () => { 6 | const { useGraph } = await skott({ 7 | entrypoint: `${__dirname}/index.ts`, 8 | tsConfigPath: `${__dirname}/../tsconfig.json`, 9 | dependencyTracking: { 10 | builtin: false, 11 | thirdParty: true, 12 | typeOnly: true, 13 | }, 14 | }); 15 | 16 | const { findCircularDependencies } = useGraph(); 17 | 18 | expect(findCircularDependencies()).toEqual([]); 19 | }); 20 | }); 21 | -------------------------------------------------------------------------------- /packages/proxy/src/index.ts: -------------------------------------------------------------------------------- 1 | export * from "./util"; 2 | export * from "./proxy"; 3 | export * from "./metrics"; 4 | -------------------------------------------------------------------------------- /packages/proxy/src/metrics.ts: -------------------------------------------------------------------------------- 1 | import { 2 | DataPoint, 3 | DataPointType, 4 | Histogram, 5 | MeterProvider, 6 | MetricData, 7 | MetricReader, 8 | ResourceMetrics, 9 | } from "@opentelemetry/sdk-metrics"; 10 | import { Resource } from "@opentelemetry/resources"; 11 | import 
{ hrTimeToMicroseconds } from "@opentelemetry/core"; 12 | import { HrTime } from "@opentelemetry/api"; 13 | import { PrometheusSerializer } from "./PrometheusSerializer"; 14 | 15 | export { NOOP_METER_PROVIDER } from "@opentelemetry/api/build/src/metrics/NoopMeterProvider"; 16 | 17 | export function initMetrics( 18 | metricReader: MetricReader, 19 | resourceLabels?: Record, 20 | ) { 21 | const resource = Resource.default().merge( 22 | new Resource({ 23 | ...resourceLabels, 24 | }), 25 | ); 26 | 27 | const myServiceMeterProvider = new MeterProvider({ 28 | resource, 29 | }); 30 | myServiceMeterProvider.addMetricReader(metricReader); 31 | return myServiceMeterProvider; 32 | } 33 | 34 | export async function flushMetrics(meterProvider: MeterProvider) { 35 | await meterProvider.forceFlush(); 36 | } 37 | 38 | // These are copied from prom-client 39 | // https://github.com/siimon/prom-client/blob/master/lib/bucketGenerators.js 40 | export function linearBuckets(start: number, width: number, count: number) { 41 | if (count < 1) { 42 | throw new Error("Linear buckets needs a positive count"); 43 | } 44 | 45 | const buckets = new Array(count); 46 | buckets[0] = 0; 47 | for (let i = 1; i < count; i++) { 48 | buckets[i] = start + i * width; 49 | } 50 | return buckets; 51 | } 52 | 53 | export function exponentialBuckets( 54 | start: number, 55 | factor: number, 56 | count: number, 57 | ) { 58 | if (start <= 0) { 59 | throw new Error("Exponential buckets needs a positive start"); 60 | } 61 | if (count < 1) { 62 | throw new Error("Exponential buckets needs a positive count"); 63 | } 64 | if (factor <= 1) { 65 | throw new Error("Exponential buckets needs a factor greater than 1"); 66 | } 67 | const buckets = new Array(count); 68 | buckets[0] = 0; 69 | for (let i = 1; i < count; i++) { 70 | buckets[i] = start; 71 | start *= factor; 72 | } 73 | return buckets; 74 | } 75 | 76 | export function nowMs() { 77 | return performance?.now ? 
performance.now() : Date.now(); 78 | } 79 | 80 | export async function aggregateMetrics( 81 | metrics: ResourceMetrics, 82 | cacheGet: (key: string) => Promise, 83 | cachePut: (key: string, value: MetricData) => void, 84 | ): Promise { 85 | for (const scopeMetrics of metrics.scopeMetrics) { 86 | for (const metric of scopeMetrics.metrics) { 87 | for (let i = 0; i < metric.dataPoints.length; i++) { 88 | // NOTE: We should be able to batch these get operations 89 | // into sets of keys at most 128 in length 90 | const metricKey = 91 | "otel_metric_" + 92 | JSON.stringify({ 93 | name: metric.descriptor.name, 94 | dataPointType: metric.dataPointType, 95 | labels: metric.dataPoints[i].attributes, 96 | }); 97 | 98 | let existing = (await cacheGet(metricKey)) || { 99 | ...metric, 100 | dataPoints: [], 101 | }; 102 | if (existing && existing.dataPointType !== metric.dataPointType) { 103 | throw new Error("Invalid data point (type mismatch)"); 104 | } 105 | 106 | let newValue = undefined; 107 | switch (metric.dataPointType) { 108 | case DataPointType.SUM: 109 | newValue = coalesceFn( 110 | existing.dataPoints[0] as DataPoint, 111 | metric.dataPoints[i], 112 | mergeCounters, 113 | ); 114 | break; 115 | case DataPointType.GAUGE: 116 | newValue = coalesceFn( 117 | existing.dataPoints[0] as DataPoint, 118 | metric.dataPoints[i], 119 | mergeGauges, 120 | ); 121 | break; 122 | case DataPointType.HISTOGRAM: 123 | newValue = coalesceFn( 124 | existing.dataPoints[0] as DataPoint, 125 | metric.dataPoints[i], 126 | mergeHistograms, 127 | ); 128 | break; 129 | case DataPointType.EXPONENTIAL_HISTOGRAM: 130 | throw new Error("Not Implemented: Exponential Histogram"); 131 | } 132 | 133 | if (newValue !== undefined) { 134 | (existing as any).descriptor = metric.descriptor; // Update the descriptor in case the code changes it 135 | existing.dataPoints[0] = newValue; 136 | // See "Write buffer behavior" in https://developers.cloudflare.com/durable-objects/api/transactional-storage-api/ 137 | 
// The only reason to await this put is to apply backpressure, which should be unnecessary given the small # of metrics 138 | // we're aggregating over 139 | cachePut(metricKey, existing); 140 | } 141 | } 142 | } 143 | } 144 | } 145 | 146 | export function prometheusSerialize(metrics: ResourceMetrics): string { 147 | const serializer = new PrometheusSerializer("", false /*appendTimestamp*/); 148 | return serializer.serialize(metrics); 149 | } 150 | 151 | function mergeHistograms( 152 | base: DataPoint, 153 | delta: DataPoint, 154 | ): DataPoint { 155 | if ( 156 | JSON.stringify(base.value.buckets.boundaries) !== 157 | JSON.stringify(delta.value.buckets.boundaries) 158 | ) { 159 | throw new Error( 160 | "Unsupported: merging histograms with different bucket boundaries", 161 | ); 162 | } 163 | 164 | return { 165 | startTime: minHrTime(base.startTime, delta.startTime), 166 | endTime: maxHrTime(base.endTime, delta.endTime), 167 | attributes: { ...base.attributes } /* these are assumed to be the same */, 168 | value: { 169 | buckets: { 170 | boundaries: [...base.value.buckets.boundaries], 171 | counts: base.value.buckets.counts.map( 172 | (count, i) => count + delta.value.buckets.counts[i], 173 | ), 174 | }, 175 | sum: (base.value.sum || 0) + (delta.value.sum || 0), 176 | count: base.value.count + delta.value.count, 177 | min: coalesceFn(base.value.max, delta.value.max, Math.min), 178 | max: coalesceFn(base.value.max, delta.value.max, Math.max), 179 | }, 180 | }; 181 | } 182 | 183 | function mergeGauges( 184 | base: DataPoint, 185 | delta: DataPoint, 186 | ): DataPoint { 187 | const baseT = hrTimeToMicroseconds(base.endTime); 188 | const deltaT = hrTimeToMicroseconds(delta.endTime); 189 | return { 190 | startTime: deltaT >= baseT ? base.startTime : delta.startTime, 191 | endTime: maxHrTime(base.endTime, delta.endTime), 192 | attributes: { ...base.attributes } /* these are assumed to be the same */, 193 | value: deltaT >= baseT ? 
delta.value : base.value, 194 | }; 195 | } 196 | 197 | function mergeCounters( 198 | base: DataPoint, 199 | delta: DataPoint, 200 | ): DataPoint { 201 | return { 202 | startTime: minHrTime(base.startTime, delta.startTime), 203 | endTime: maxHrTime(base.endTime, delta.endTime), 204 | attributes: { ...base.attributes } /* these are assumed to be the same */, 205 | value: base.value + delta.value, 206 | }; 207 | } 208 | 209 | function minHrTime(a: HrTime, b: HrTime): HrTime { 210 | const at = hrTimeToMicroseconds(a); 211 | const bt = hrTimeToMicroseconds(b); 212 | return at <= bt ? a : b; 213 | } 214 | 215 | function maxHrTime(a: HrTime, b: HrTime): HrTime { 216 | const at = hrTimeToMicroseconds(a); 217 | const bt = hrTimeToMicroseconds(b); 218 | return at >= bt ? a : b; 219 | } 220 | 221 | function coalesceFn( 222 | a: T | undefined, 223 | b: T | undefined, 224 | coalesce: (a: T, b: T) => T, 225 | ): T | undefined { 226 | return a === undefined ? b : b === undefined ? a : coalesce(a, b); 227 | } 228 | -------------------------------------------------------------------------------- /packages/proxy/src/providers/anthropic.test.ts: -------------------------------------------------------------------------------- 1 | import { describe, it, expect } from "vitest"; 2 | import { callProxyV1 } from "../../utils/tests"; 3 | import { 4 | OpenAIChatCompletion, 5 | OpenAIChatCompletionChunk, 6 | OpenAIChatCompletionCreateParams, 7 | } from "@types"; 8 | 9 | it("should convert OpenAI streaming request to Anthropic and back", async () => { 10 | const { events } = await callProxyV1< 11 | OpenAIChatCompletionCreateParams, 12 | OpenAIChatCompletionChunk 13 | >({ 14 | body: { 15 | model: "claude-2", 16 | messages: [ 17 | { role: "system", content: "You are a helpful assistant." }, 18 | { role: "user", content: "Tell me a short joke about programming." 
}, 19 | ], 20 | stream: true, 21 | max_tokens: 150, 22 | }, 23 | }); 24 | 25 | const streamedEvents = events(); 26 | 27 | expect(streamedEvents.length).toBeGreaterThan(0); 28 | 29 | streamedEvents.forEach((event) => { 30 | expect(event.type).toBe("event"); 31 | 32 | const data = event.data; 33 | expect(data.id).toBeTruthy(); 34 | expect(data.object).toBe("chat.completion.chunk"); 35 | expect(data.created).toBeTruthy(); 36 | expect(Array.isArray(data.choices)).toBe(true); 37 | 38 | if (data.choices[0]?.delta?.content) { 39 | expect(data.choices[0].delta.content.trim()).not.toBe(""); 40 | } 41 | }); 42 | 43 | const hasContent = streamedEvents.some( 44 | (event) => event.data.choices[0]?.delta?.content !== undefined, 45 | ); 46 | expect(hasContent).toBe(true); 47 | }); 48 | 49 | it("should convert OpenAI non-streaming request to Anthropic and back", async () => { 50 | const { json } = await callProxyV1< 51 | OpenAIChatCompletionCreateParams, 52 | OpenAIChatCompletion 53 | >({ 54 | body: { 55 | model: "claude-2.1", 56 | messages: [ 57 | { role: "system", content: "You are a helpful assistant." }, 58 | { role: "user", content: "Tell me a short joke about programming." 
}, 59 | ], 60 | stream: false, 61 | max_tokens: 150, 62 | }, 63 | }); 64 | 65 | expect(json()).toEqual({ 66 | choices: [ 67 | { 68 | finish_reason: "stop", 69 | index: 0, 70 | logprobs: null, 71 | message: { 72 | content: expect.any(String), 73 | refusal: null, 74 | role: "assistant", 75 | }, 76 | }, 77 | ], 78 | created: expect.any(Number), 79 | id: expect.any(String), 80 | model: "claude-2.1", 81 | object: "chat.completion", 82 | usage: { 83 | completion_tokens: expect.any(Number), 84 | prompt_tokens: expect.any(Number), 85 | total_tokens: expect.any(Number), 86 | prompt_tokens_details: { 87 | cache_creation_tokens: expect.any(Number), 88 | cached_tokens: expect.any(Number), 89 | }, 90 | }, 91 | }); 92 | }); 93 | 94 | it("should accept and return reasoning/thinking params and detail streaming", async () => { 95 | const { events } = await callProxyV1< 96 | OpenAIChatCompletionCreateParams, 97 | OpenAIChatCompletionChunk 98 | >({ 99 | body: { 100 | model: "claude-3-7-sonnet-latest", 101 | reasoning_effort: "medium", 102 | messages: [ 103 | { 104 | role: "user", 105 | content: "How many rs in 'ferrocarril'", 106 | }, 107 | { 108 | role: "assistant", 109 | content: "There are 4 letter 'r's in the word \"ferrocarril\".", 110 | refusal: null, 111 | reasoning: [ 112 | { 113 | id: "ErUBCkYIAxgCIkDWT/7OwDfkVSgdtjIwGqUpzIHQXkiBQQpIqzh6WnHHoGxN1ilJxIlnJQNarUI4Jo/3WWrmRnnqOU3LtAakLr4REgwvY1G5jTSbLHWOo4caDKNco+CyDfNT56iXBCIwrNSFdvNJNsBaa0hpbTZ6N4Q4z4/6l+gu8hniKnftBhS+IuzcncsuJqKxWKs/EVyjKh3tvH/eDeYovKskosVSO5x64iebuze1S8JbavI3UBgC", 114 | content: 115 | "To count the number of 'r's in the word 'ferrocarril', I'll just go through the word letter by letter.\n\n'ferrocarril' has the following letters:\nf-e-r-r-o-c-a-r-r-i-l\n\nLooking at each letter:\n- 'f': not an 'r'\n- 'e': not an 'r'\n- 'r': This is an 'r', so that's 1.\n- 'r': This is an 'r', so that's 2.\n- 'o': not an 'r'\n- 'c': not an 'r'\n- 'a': not an 'r'\n- 'r': This is an 'r', so that's 3.\n- 'r': This is an 'r', so 
that's 4.\n- 'i': not an 'r'\n- 'l': not an 'r'\n\nSo there are 4 'r's in the word 'ferrocarril'.", 116 | }, 117 | ], 118 | }, 119 | { 120 | role: "user", 121 | content: "How many e in what you said?", 122 | }, 123 | ], 124 | stream: true, 125 | }, 126 | }); 127 | 128 | const streamedEvents = events(); 129 | expect(streamedEvents.length).toBeGreaterThan(0); 130 | 131 | const hasReasoning = streamedEvents.some( 132 | (event) => event.data.choices[0]?.delta?.reasoning?.content !== undefined, 133 | ); 134 | expect(hasReasoning).toBe(true); 135 | 136 | const hasContent = streamedEvents.some( 137 | (event) => event.data.choices[0]?.delta?.content !== undefined, 138 | ); 139 | expect(hasContent).toBe(true); 140 | }); 141 | 142 | it("should accept and return reasoning/thinking params and detail non-streaming", async () => { 143 | const { json } = await callProxyV1< 144 | OpenAIChatCompletionCreateParams, 145 | OpenAIChatCompletionChunk 146 | >({ 147 | body: { 148 | model: "claude-3-7-sonnet-20250219", 149 | reasoning_effort: "medium", 150 | stream: false, 151 | messages: [ 152 | { 153 | role: "user", 154 | content: "How many rs in 'ferrocarril'", 155 | }, 156 | { 157 | role: "assistant", 158 | content: "There are 4 letter 'r's in the word \"ferrocarril\".", 159 | refusal: null, 160 | reasoning: [ 161 | { 162 | id: "ErUBCkYIAxgCIkDWT/7OwDfkVSgdtjIwGqUpzIHQXkiBQQpIqzh6WnHHoGxN1ilJxIlnJQNarUI4Jo/3WWrmRnnqOU3LtAakLr4REgwvY1G5jTSbLHWOo4caDKNco+CyDfNT56iXBCIwrNSFdvNJNsBaa0hpbTZ6N4Q4z4/6l+gu8hniKnftBhS+IuzcncsuJqKxWKs/EVyjKh3tvH/eDeYovKskosVSO5x64iebuze1S8JbavI3UBgC", 163 | content: 164 | "To count the number of 'r's in the word 'ferrocarril', I'll just go through the word letter by letter.\n\n'ferrocarril' has the following letters:\nf-e-r-r-o-c-a-r-r-i-l\n\nLooking at each letter:\n- 'f': not an 'r'\n- 'e': not an 'r'\n- 'r': This is an 'r', so that's 1.\n- 'r': This is an 'r', so that's 2.\n- 'o': not an 'r'\n- 'c': not an 'r'\n- 'a': not an 'r'\n- 'r': This is an 'r', so 
that's 3.\n- 'r': This is an 'r', so that's 4.\n- 'i': not an 'r'\n- 'l': not an 'r'\n\nSo there are 4 'r's in the word 'ferrocarril'.", 165 | }, 166 | ], 167 | }, 168 | { 169 | role: "user", 170 | content: "How many e in what you said?", 171 | }, 172 | ], 173 | }, 174 | }); 175 | 176 | expect(json()).toEqual({ 177 | choices: [ 178 | { 179 | finish_reason: "stop", 180 | index: 0, 181 | logprobs: null, 182 | message: { 183 | content: expect.any(String), 184 | reasoning: [ 185 | { 186 | content: expect.any(String), 187 | id: expect.any(String), 188 | }, 189 | ], 190 | refusal: null, 191 | role: "assistant", 192 | }, 193 | }, 194 | ], 195 | created: expect.any(Number), 196 | id: expect.any(String), 197 | model: "claude-3-7-sonnet-20250219", 198 | object: "chat.completion", 199 | usage: { 200 | completion_tokens: expect.any(Number), 201 | prompt_tokens: expect.any(Number), 202 | total_tokens: expect.any(Number), 203 | prompt_tokens_details: { 204 | cache_creation_tokens: expect.any(Number), 205 | cached_tokens: expect.any(Number), 206 | }, 207 | }, 208 | }); 209 | }); 210 | 211 | it("should disable reasoning/thinking params non-streaming", async () => { 212 | const { json } = await callProxyV1< 213 | OpenAIChatCompletionCreateParams, 214 | OpenAIChatCompletionChunk 215 | >({ 216 | body: { 217 | model: "claude-3-7-sonnet-20250219", 218 | reasoning_enabled: false, 219 | stream: false, 220 | messages: [ 221 | { 222 | role: "user", 223 | content: "How many rs in 'ferrocarril'", 224 | }, 225 | { 226 | role: "assistant", 227 | content: "There are 4 letter 'r's in the word \"ferrocarril\".", 228 | refusal: null, 229 | reasoning: [ 230 | { 231 | id: "ErUBCkYIAxgCIkDWT/7OwDfkVSgdtjIwGqUpzIHQXkiBQQpIqzh6WnHHoGxN1ilJxIlnJQNarUI4Jo/3WWrmRnnqOU3LtAakLr4REgwvY1G5jTSbLHWOo4caDKNco+CyDfNT56iXBCIwrNSFdvNJNsBaa0hpbTZ6N4Q4z4/6l+gu8hniKnftBhS+IuzcncsuJqKxWKs/EVyjKh3tvH/eDeYovKskosVSO5x64iebuze1S8JbavI3UBgC", 232 | content: 233 | "To count the number of 'r's in the word 'ferrocarril', I'll 
just go through the word letter by letter.\n\n'ferrocarril' has the following letters:\nf-e-r-r-o-c-a-r-r-i-l\n\nLooking at each letter:\n- 'f': not an 'r'\n- 'e': not an 'r'\n- 'r': This is an 'r', so that's 1.\n- 'r': This is an 'r', so that's 2.\n- 'o': not an 'r'\n- 'c': not an 'r'\n- 'a': not an 'r'\n- 'r': This is an 'r', so that's 3.\n- 'r': This is an 'r', so that's 4.\n- 'i': not an 'r'\n- 'l': not an 'r'\n\nSo there are 4 'r's in the word 'ferrocarril'.", 234 | }, 235 | ], 236 | }, 237 | { 238 | role: "user", 239 | content: "How many e in what you said?", 240 | }, 241 | ], 242 | }, 243 | }); 244 | 245 | expect(json()).toEqual({ 246 | choices: [ 247 | { 248 | finish_reason: "stop", 249 | index: 0, 250 | logprobs: null, 251 | message: { 252 | content: expect.any(String), 253 | refusal: null, 254 | role: "assistant", 255 | }, 256 | }, 257 | ], 258 | created: expect.any(Number), 259 | id: expect.any(String), 260 | model: "claude-3-7-sonnet-20250219", 261 | object: "chat.completion", 262 | usage: { 263 | completion_tokens: expect.any(Number), 264 | prompt_tokens: expect.any(Number), 265 | total_tokens: expect.any(Number), 266 | prompt_tokens_details: { 267 | cache_creation_tokens: expect.any(Number), 268 | cached_tokens: expect.any(Number), 269 | }, 270 | }, 271 | }); 272 | }); 273 | -------------------------------------------------------------------------------- /packages/proxy/src/providers/azure.ts: -------------------------------------------------------------------------------- 1 | import { z } from "zod"; 2 | import { AzureEntraSecretSchema } from "@braintrust/proxy/schema"; 3 | 4 | const azureEntraResponseSchema = z.union([ 5 | z.object({ 6 | access_token: z.string(), 7 | token_type: z.literal("Bearer"), 8 | expires_in: z.number(), 9 | }), 10 | z.object({ 11 | error: z.string(), 12 | }), 13 | ]); 14 | 15 | export async function getAzureEntraAccessToken({ 16 | secret, 17 | digest, 18 | cacheGet, 19 | cachePut, 20 | }: { 21 | secret: z.infer; 22 | digest: 
(message: string) => Promise; 23 | cacheGet: (encryptionKey: string, key: string) => Promise; 24 | cachePut: ( 25 | encryptionKey: string, 26 | key: string, 27 | value: string, 28 | ttl_seconds?: number, 29 | ) => Promise; 30 | }): Promise { 31 | const { client_id, tenant_id, scope, client_secret } = secret; 32 | const tokenUrl = `https://login.microsoftonline.com/${tenant_id}/oauth2/v2.0/token`; 33 | const body = new URLSearchParams({ 34 | client_id, 35 | tenant: tenant_id, 36 | scope, 37 | grant_type: "client_credentials", 38 | client_secret, 39 | }); 40 | 41 | const cachePath = await digest( 42 | `${client_id}:${tenant_id}:${scope}:${client_secret}`, 43 | ); 44 | const cacheKey = `aiproxy/proxy/entra/${cachePath}`; 45 | const encryptionKey = await digest(`${cachePath}:${client_secret}`); 46 | 47 | const cached = await cacheGet(encryptionKey, cacheKey); 48 | if (cached) { 49 | return cached; 50 | } 51 | 52 | const res = await fetch(tokenUrl, { 53 | method: "POST", 54 | headers: { 55 | "Content-Type": "application/x-www-form-urlencoded", 56 | }, 57 | body, 58 | }); 59 | if (!res.ok) { 60 | throw new Error( 61 | `Azure Entra error (${res.status}): ${res.statusText} ${await res.text()}`, 62 | ); 63 | } 64 | const data = await res.json(); 65 | const parsed = azureEntraResponseSchema.parse(data); 66 | if ("error" in parsed) { 67 | throw new Error(`Azure Entra error: ${parsed.error}`); 68 | } 69 | 70 | // Give it a 1 minute buffer. 
71 | const cacheTtl = Math.max(parsed.expires_in - 60, 0); 72 | if (cacheTtl > 0) { 73 | await cachePut(encryptionKey, cacheKey, parsed.access_token, cacheTtl); 74 | } 75 | return parsed.access_token; 76 | } 77 | -------------------------------------------------------------------------------- /packages/proxy/src/providers/databricks.ts: -------------------------------------------------------------------------------- 1 | import { z } from "zod"; 2 | import { DatabricksOAuthSecretSchema } from "@braintrust/proxy/schema"; 3 | 4 | const databricksOAuthResponseSchema = z.union([ 5 | z.object({ 6 | access_token: z.string(), 7 | token_type: z.literal("Bearer"), 8 | expires_in: z.number(), 9 | }), 10 | z.object({ 11 | error: z.string(), 12 | }), 13 | ]); 14 | 15 | export async function getDatabricksOAuthAccessToken({ 16 | secret, 17 | apiBase, 18 | digest, 19 | cacheGet, 20 | cachePut, 21 | }: { 22 | secret: z.infer; 23 | apiBase: string; 24 | digest: (message: string) => Promise; 25 | cacheGet: (encryptionKey: string, key: string) => Promise; 26 | cachePut: ( 27 | encryptionKey: string, 28 | key: string, 29 | value: string, 30 | ttl_seconds?: number, 31 | ) => Promise; 32 | }): Promise { 33 | const { client_id, client_secret } = secret; 34 | const tokenUrl = `${apiBase}/oidc/v1/token`; 35 | 36 | const cachePath = await digest(`${client_id}:${client_secret}:${apiBase}`); 37 | const cacheKey = `aiproxy/proxy/databricks/${cachePath}`; 38 | const encryptionKey = await digest(`${cachePath}:${client_secret}`); 39 | 40 | const cached = await cacheGet(encryptionKey, cacheKey); 41 | if (cached) { 42 | return cached; 43 | } 44 | 45 | // Create credentials for basic auth. 
46 | const credentials = Buffer.from(`${client_id}:${client_secret}`).toString( 47 | "base64", 48 | ); 49 | const res = await fetch(tokenUrl, { 50 | method: "POST", 51 | headers: { 52 | "Content-Type": "application/x-www-form-urlencoded", 53 | Authorization: `Basic ${credentials}`, 54 | }, 55 | body: new URLSearchParams({ 56 | grant_type: "client_credentials", 57 | scope: "all-apis", 58 | }), 59 | }); 60 | if (!res.ok) { 61 | throw new Error( 62 | `Databricks OAuth error (${res.status}): ${res.statusText} ${await res.text()}`, 63 | ); 64 | } 65 | 66 | const data = await res.json(); 67 | const parsed = databricksOAuthResponseSchema.parse(data); 68 | if ("error" in parsed) { 69 | throw new Error(`Databricks OAuth error: ${parsed.error}`); 70 | } 71 | 72 | // Give it a 1 minute buffer. 73 | const cacheTtl = Math.max(parsed.expires_in - 60, 0); 74 | if (cacheTtl > 0) { 75 | await cachePut(encryptionKey, cacheKey, parsed.access_token, cacheTtl); 76 | } 77 | 78 | return parsed.access_token; 79 | } 80 | -------------------------------------------------------------------------------- /packages/proxy/src/providers/google.test.ts: -------------------------------------------------------------------------------- 1 | import { describe, it, expect } from "vitest"; 2 | import { callProxyV1 } from "../../utils/tests"; 3 | import { 4 | OpenAIChatCompletionChunk, 5 | OpenAIChatCompletionCreateParams, 6 | } from "@types"; 7 | 8 | for (const model of [ 9 | "gemini-2.5-flash-preview-05-20", 10 | // TODO: re-enable when we have a working CI/CD solution 11 | // "publishers/google/models/gemini-2.5-flash-preview-05-20", 12 | ]) { 13 | describe(model, () => { 14 | it("should accept and should not return reasoning/thinking params and detail streaming", async () => { 15 | const { events, json } = await callProxyV1< 16 | OpenAIChatCompletionCreateParams, 17 | OpenAIChatCompletionChunk 18 | >({ 19 | body: { 20 | model, 21 | reasoning_effort: "medium", 22 | messages: [ 23 | { 24 | role: "user", 
25 | content: "How many rs in 'ferrocarril'", 26 | }, 27 | { 28 | role: "assistant", 29 | content: "There are 4 letter 'r's in the word \"ferrocarril\".", 30 | refusal: null, 31 | reasoning: [ 32 | { 33 | id: "", 34 | content: 35 | "To count the number of 'r's in the word 'ferrocarril', I'll just go through the word letter by letter.\n\n'ferrocarril' has the following letters:\nf-e-r-r-o-c-a-r-r-i-l\n\nLooking at each letter:\n- 'f': not an 'r'\n- 'e': not an 'r'\n- 'r': This is an 'r', so that's 1.\n- 'r': This is an 'r', so that's 2.\n- 'o': not an 'r'\n- 'c': not an 'r'\n- 'a': not an 'r'\n- 'r': This is an 'r', so that's 3.\n- 'r': This is an 'r', so that's 4.\n- 'i': not an 'r'\n- 'l': not an 'r'\n\nSo there are 4 'r's in the word 'ferrocarril'.", 36 | }, 37 | ], 38 | }, 39 | { 40 | role: "user", 41 | content: "How many e in what you said?", 42 | }, 43 | ], 44 | stream: true, 45 | }, 46 | }); 47 | 48 | const streamedEvents = events(); 49 | expect(streamedEvents.length).toBeGreaterThan(0); 50 | 51 | const hasContent = streamedEvents.some( 52 | (event) => event.data.choices[0]?.delta?.content !== undefined, 53 | ); 54 | expect(hasContent).toBe(true); 55 | 56 | const hasReasoning = streamedEvents.some( 57 | (event) => 58 | event.data.choices[0]?.delta?.reasoning?.content !== undefined, 59 | ); 60 | expect(hasReasoning).toBe(true); 61 | }); 62 | 63 | it("should accept and return reasoning/thinking params and detail non-streaming", async () => { 64 | const { json } = await callProxyV1< 65 | OpenAIChatCompletionCreateParams, 66 | OpenAIChatCompletionChunk 67 | >({ 68 | body: { 69 | model, 70 | reasoning_effort: "medium", 71 | stream: false, 72 | messages: [ 73 | { 74 | role: "user", 75 | content: "How many rs in 'ferrocarril'", 76 | }, 77 | { 78 | role: "assistant", 79 | content: "There are 4 letter 'r's in the word \"ferrocarril\".", 80 | refusal: null, 81 | reasoning: [ 82 | { 83 | id: "", 84 | content: 85 | "To count the number of 'r's in the word 'ferrocarril', 
I'll just go through the word letter by letter.\n\n'ferrocarril' has the following letters:\nf-e-r-r-o-c-a-r-r-i-l\n\nLooking at each letter:\n- 'f': not an 'r'\n- 'e': not an 'r'\n- 'r': This is an 'r', so that's 1.\n- 'r': This is an 'r', so that's 2.\n- 'o': not an 'r'\n- 'c': not an 'r'\n- 'a': not an 'r'\n- 'r': This is an 'r', so that's 3.\n- 'r': This is an 'r', so that's 4.\n- 'i': not an 'r'\n- 'l': not an 'r'\n\nSo there are 4 'r's in the word 'ferrocarril'.", 86 | }, 87 | ], 88 | }, 89 | { 90 | role: "user", 91 | content: "How many e in what you said?", 92 | }, 93 | ], 94 | }, 95 | }); 96 | 97 | expect(json()).toEqual({ 98 | choices: [ 99 | { 100 | finish_reason: "stop", 101 | index: 0, 102 | logprobs: null, 103 | message: { 104 | content: expect.any(String), 105 | reasoning: [ 106 | { 107 | id: expect.any(String), 108 | content: expect.any(String), 109 | }, 110 | ], 111 | refusal: null, 112 | role: "assistant", 113 | }, 114 | }, 115 | ], 116 | created: expect.any(Number), 117 | id: expect.any(String), 118 | model, 119 | object: "chat.completion", 120 | usage: { 121 | completion_tokens: expect.any(Number), 122 | completion_tokens_details: { 123 | reasoning_tokens: expect.any(Number), 124 | }, 125 | prompt_tokens: expect.any(Number), 126 | total_tokens: expect.any(Number), 127 | }, 128 | }); 129 | }); 130 | 131 | it("should disable reasoning/thinking non-streaming", async () => { 132 | const { json } = await callProxyV1< 133 | OpenAIChatCompletionCreateParams, 134 | OpenAIChatCompletionChunk 135 | >({ 136 | body: { 137 | model, 138 | reasoning_enabled: true, 139 | reasoning_budget: 0, 140 | stream: false, 141 | messages: [ 142 | { 143 | role: "user", 144 | content: "How many rs in 'ferrocarril'", 145 | }, 146 | { 147 | role: "assistant", 148 | content: "There are 4 letter 'r's in the word \"ferrocarril\".", 149 | refusal: null, 150 | reasoning: [ 151 | { 152 | id: "", 153 | content: 154 | "To count the number of 'r's in the word 'ferrocarril', I'll just 
go through the word letter by letter.\n\n'ferrocarril' has the following letters:\nf-e-r-r-o-c-a-r-r-i-l\n\nLooking at each letter:\n- 'f': not an 'r'\n- 'e': not an 'r'\n- 'r': This is an 'r', so that's 1.\n- 'r': This is an 'r', so that's 2.\n- 'o': not an 'r'\n- 'c': not an 'r'\n- 'a': not an 'r'\n- 'r': This is an 'r', so that's 3.\n- 'r': This is an 'r', so that's 4.\n- 'i': not an 'r'\n- 'l': not an 'r'\n\nSo there are 4 'r's in the word 'ferrocarril'.", 155 | }, 156 | ], 157 | }, 158 | { 159 | role: "user", 160 | content: "How many e in what you said?", 161 | }, 162 | ], 163 | }, 164 | }); 165 | 166 | expect(json()).toEqual({ 167 | choices: [ 168 | { 169 | finish_reason: "stop", 170 | index: 0, 171 | logprobs: null, 172 | message: { 173 | content: expect.any(String), 174 | refusal: null, 175 | role: "assistant", 176 | }, 177 | }, 178 | ], 179 | created: expect.any(Number), 180 | id: expect.any(String), 181 | model, 182 | object: "chat.completion", 183 | usage: { 184 | completion_tokens: expect.any(Number), 185 | prompt_tokens: expect.any(Number), 186 | total_tokens: expect.any(Number), 187 | }, 188 | }); 189 | }); 190 | }); 191 | } 192 | -------------------------------------------------------------------------------- /packages/proxy/src/providers/openai.ts: -------------------------------------------------------------------------------- 1 | import { 2 | ChatCompletionChunk, 3 | ChatCompletion, 4 | ChatCompletionMessageParam, 5 | ChatCompletionContentPart, 6 | } from "openai/resources"; 7 | import { base64ToUrl, convertBase64Media, convertMediaToBase64 } from "./util"; 8 | import { parseFileMetadataFromUrl } from "../util"; 9 | 10 | function openAIChatCompletionToChatEvent( 11 | completion: ChatCompletion, 12 | ): ChatCompletionChunk { 13 | return { 14 | id: completion.id, 15 | choices: completion.choices.map((choice) => ({ 16 | index: choice.index, 17 | delta: { 18 | role: choice.message.role, 19 | content: choice.message.content || "", 20 | tool_calls: 
choice.message.tool_calls 21 | ? choice.message.tool_calls.map((tool_call, index) => ({ 22 | index, 23 | id: tool_call.id, 24 | function: tool_call.function, 25 | type: tool_call.type, 26 | })) 27 | : undefined, 28 | }, 29 | finish_reason: choice.finish_reason, 30 | })), 31 | created: completion.created, 32 | model: completion.model, 33 | object: "chat.completion.chunk", 34 | usage: completion.usage, 35 | }; 36 | } 37 | 38 | export function makeFakeOpenAIStreamTransformer() { 39 | let responseChunks: Uint8Array[] = []; 40 | return new TransformStream({ 41 | transform(chunk, controller) { 42 | responseChunks.push(chunk); 43 | }, 44 | flush(controller) { 45 | const decoder = new TextDecoder(); 46 | const responseText = responseChunks 47 | .map((c) => decoder.decode(c)) 48 | .join(""); 49 | let responseJson: ChatCompletion = { 50 | id: "invalid", 51 | choices: [], 52 | created: 0, 53 | model: "invalid", 54 | object: "chat.completion", 55 | usage: { 56 | prompt_tokens: 0, 57 | completion_tokens: 0, 58 | total_tokens: 0, 59 | }, 60 | }; 61 | try { 62 | responseJson = JSON.parse(responseText); 63 | } catch (e) { 64 | console.error("Failed to parse response as JSON", responseText); 65 | } 66 | controller.enqueue( 67 | new TextEncoder().encode( 68 | `data: ${JSON.stringify(openAIChatCompletionToChatEvent(responseJson))}\n\n`, 69 | ), 70 | ); 71 | controller.enqueue(new TextEncoder().encode(`data: [DONE]\n\n`)); 72 | controller.terminate(); 73 | }, 74 | }); 75 | } 76 | 77 | export async function normalizeOpenAIMessages( 78 | messages: ChatCompletionMessageParam[], 79 | ): Promise { 80 | return Promise.all( 81 | messages.map(async (message) => { 82 | if ( 83 | message.role === "user" && 84 | message.content && 85 | typeof message.content !== "string" 86 | ) { 87 | message.content = await Promise.all( 88 | message.content.map( 89 | async (c): Promise => 90 | await normalizeOpenAIContent(c), 91 | ), 92 | ); 93 | } 94 | // not part of the openai spec 95 | if ("reasoning" in 
message) { 96 | delete message.reasoning; 97 | } 98 | return message; 99 | }), 100 | ); 101 | } 102 | 103 | // https://platform.openai.com/docs/guides/pdf-files?api-mode=chat 104 | export async function normalizeOpenAIContent( 105 | content: ChatCompletionContentPart, 106 | ): Promise { 107 | if (typeof content === "string") { 108 | return content; 109 | } 110 | switch (content.type) { 111 | case "image_url": 112 | const mediaBlock = convertBase64Media(content.image_url.url); 113 | if (mediaBlock?.media_type.startsWith("image/")) { 114 | return content; 115 | } else if (mediaBlock) { 116 | // Let OpenAI validate the mime type of the base64 encoded input file 117 | // As of 05/20/25 this supports .pdf and appears to have limited support for .csv, .xlsx, .docx, and .pptx 118 | // but is not clearly documented 119 | return { 120 | type: "file", 121 | file: { 122 | filename: "file_from_base64", 123 | file_data: content.image_url.url, 124 | }, 125 | }; 126 | } 127 | 128 | const parsed = parseFileMetadataFromUrl(content.image_url.url); 129 | if ( 130 | parsed?.filename?.endsWith(".pdf") || 131 | parsed?.contentType === "application/pdf" 132 | ) { 133 | const base64 = await convertMediaToBase64({ 134 | media: content.image_url.url, 135 | allowedMediaTypes: ["application/pdf"], 136 | maxMediaBytes: 20 * 1024 * 1024, 137 | }); 138 | return { 139 | type: "file", 140 | file: { 141 | filename: parsed.filename, 142 | file_data: base64ToUrl(base64), 143 | }, 144 | }; 145 | } else if ( 146 | content.image_url.url.startsWith("http://127.0.0.1") || 147 | content.image_url.url.startsWith("http://localhost") 148 | ) { 149 | const base64 = await convertMediaToBase64({ 150 | media: content.image_url.url, 151 | allowedMediaTypes: null, 152 | maxMediaBytes: 20 * 1024 * 1024, 153 | }); 154 | return { 155 | type: "image_url", 156 | image_url: { 157 | url: base64ToUrl(base64), 158 | }, 159 | }; 160 | } 161 | return content; 162 | default: 163 | return content; 164 | } 165 | } 166 | 
-------------------------------------------------------------------------------- /packages/proxy/src/providers/util.ts: -------------------------------------------------------------------------------- 1 | import { arrayBufferToBase64 } from "utils"; 2 | 3 | const base64MediaPattern = 4 | /^data:([a-zA-Z0-9]+\/[a-zA-Z0-9+.-]+);base64,([A-Za-z0-9+/]+={0,2})$/; 5 | 6 | interface MediaBlock { 7 | media_type: string; 8 | data: string; 9 | } 10 | 11 | export function convertBase64Media(media: string): MediaBlock | null { 12 | const match = media.match(base64MediaPattern); 13 | if (!match) { 14 | return null; 15 | } 16 | 17 | const [, media_type, data] = match; 18 | return { 19 | media_type, 20 | data, 21 | }; 22 | } 23 | 24 | async function convertMediaUrl({ 25 | url, 26 | allowedMediaTypes, 27 | maxMediaBytes, 28 | }: { 29 | url: string; 30 | allowedMediaTypes: string[] | null; 31 | maxMediaBytes: number | null; 32 | }): Promise { 33 | const response = await fetch(url); 34 | if (!response.ok) { 35 | throw new Error(`Failed to fetch media: ${response.statusText}`); 36 | } 37 | 38 | const contentType = response.headers.get("content-type"); 39 | if (!contentType) { 40 | throw new Error("Failed to get content type of the media"); 41 | } 42 | const baseContentType = contentType.split(";")[0].trim(); 43 | if ( 44 | allowedMediaTypes !== null && 45 | !allowedMediaTypes.includes(baseContentType) 46 | ) { 47 | throw new Error(`Unsupported media type: ${baseContentType}`); 48 | } 49 | 50 | const arrayBuffer = await response.arrayBuffer(); 51 | if (maxMediaBytes !== null && arrayBuffer.byteLength > maxMediaBytes) { 52 | throw new Error( 53 | `Media size exceeds the ${maxMediaBytes / 1024 / 1024} MB limit`, 54 | ); 55 | } 56 | 57 | const data = arrayBufferToBase64(arrayBuffer); 58 | 59 | return { 60 | media_type: baseContentType, 61 | data, 62 | }; 63 | } 64 | 65 | export async function convertMediaToBase64({ 66 | media, 67 | allowedMediaTypes, 68 | maxMediaBytes, 69 | }: { 70 | 
media: string; 71 | allowedMediaTypes: string[] | null; 72 | maxMediaBytes: number | null; 73 | }): Promise { 74 | const mediaBlock = convertBase64Media(media); 75 | if (mediaBlock) { 76 | return mediaBlock; 77 | } else { 78 | return await convertMediaUrl({ 79 | url: media, 80 | allowedMediaTypes, 81 | maxMediaBytes, 82 | }); 83 | } 84 | } 85 | 86 | export function base64ToUrl(base64: MediaBlock): string { 87 | return `data:${base64.media_type};base64,${base64.data}`; 88 | } 89 | -------------------------------------------------------------------------------- /packages/proxy/src/util.test.ts: -------------------------------------------------------------------------------- 1 | import { describe, expect, test } from "vitest"; 2 | import { parseFileMetadataFromUrl } from "./util"; 3 | 4 | describe("parseFileMetadataFromUrl", () => { 5 | test("handles basic URLs", () => { 6 | expect(parseFileMetadataFromUrl("https://example.com/file.pdf")).toEqual({ 7 | filename: "file.pdf", 8 | url: expect.any(URL), 9 | }); 10 | expect(parseFileMetadataFromUrl("http://foo.com/bar/example.pdf")).toEqual({ 11 | filename: "example.pdf", 12 | url: expect.any(URL), 13 | }); 14 | }); 15 | 16 | test("handles URLs with query parameters", () => { 17 | expect( 18 | parseFileMetadataFromUrl("https://example.com/file.pdf?query=value"), 19 | ).toEqual({ filename: "file.pdf", url: expect.any(URL) }); 20 | expect( 21 | parseFileMetadataFromUrl("http://foo.com/doc.pdf?v=1&id=123"), 22 | ).toEqual({ filename: "doc.pdf", url: expect.any(URL) }); 23 | expect( 24 | parseFileMetadataFromUrl("https://site.com/download.pdf?token=abc123"), 25 | ).toEqual({ filename: "download.pdf", url: expect.any(URL) }); 26 | expect( 27 | parseFileMetadataFromUrl( 28 | "http://example.com/report.pdf?token=example%20with%20spaces", 29 | ), 30 | ).toEqual({ filename: "report.pdf", url: expect.any(URL) }); 31 | }); 32 | 33 | test("handles filenames with spaces and special characters", () => { 34 | expect( 35 | 
parseFileMetadataFromUrl("https://example.com/my%20file.pdf"), 36 | ).toEqual({ filename: "my file.pdf", url: expect.any(URL) }); 37 | expect(parseFileMetadataFromUrl("http://foo.com/report-2023.pdf")).toEqual({ 38 | filename: "report-2023.pdf", 39 | url: expect.any(URL), 40 | }); 41 | expect(parseFileMetadataFromUrl("https://site.com/exa%20mple.pdf")).toEqual( 42 | { filename: "exa mple.pdf", url: expect.any(URL) }, 43 | ); 44 | expect( 45 | parseFileMetadataFromUrl("http://example.com/file%20with%20spaces.pdf"), 46 | ).toEqual({ filename: "file with spaces.pdf", url: expect.any(URL) }); 47 | expect( 48 | parseFileMetadataFromUrl( 49 | "https://example.com/file-name_with.special-chars.pdf", 50 | ), 51 | ).toEqual({ 52 | filename: "file-name_with.special-chars.pdf", 53 | url: expect.any(URL), 54 | }); 55 | expect( 56 | parseFileMetadataFromUrl("http://site.org/file%25with%25percent.pdf"), 57 | ).toEqual({ filename: "file%with%percent.pdf", url: expect.any(URL) }); 58 | expect( 59 | parseFileMetadataFromUrl("https://example.com/file+with+plus.pdf"), 60 | ).toEqual({ filename: "file+with+plus.pdf", url: expect.any(URL) }); 61 | }); 62 | 63 | test("handles pathless URLs", () => { 64 | expect(parseFileMetadataFromUrl("https://example.pdf")).toBeUndefined(); 65 | expect(parseFileMetadataFromUrl("file.pdf")).toBeUndefined(); 66 | expect(parseFileMetadataFromUrl("folder/file.pdf")).toBeUndefined(); 67 | }); 68 | 69 | test("handles URLs with fragments", () => { 70 | expect( 71 | parseFileMetadataFromUrl("https://example.com/document.pdf#page=1"), 72 | ).toEqual({ filename: "document.pdf", url: expect.any(URL) }); 73 | expect( 74 | parseFileMetadataFromUrl("http://site.com/resume.pdf#section"), 75 | ).toEqual({ filename: "resume.pdf", url: expect.any(URL) }); 76 | expect( 77 | parseFileMetadataFromUrl( 78 | "https://example.com/file.pdf#fragment=with=equals", 79 | ), 80 | ).toEqual({ filename: "file.pdf", url: expect.any(URL) }); 81 | }); 82 | 83 | test("handles URLs with 
both query parameters and fragments", () => { 84 | expect( 85 | parseFileMetadataFromUrl( 86 | "https://example.com/report.pdf?version=2#page=5", 87 | ), 88 | ).toEqual({ filename: "report.pdf", url: expect.any(URL) }); 89 | expect( 90 | parseFileMetadataFromUrl( 91 | "http://site.org/document.pdf?dl=true#section=summary", 92 | ), 93 | ).toEqual({ filename: "document.pdf", url: expect.any(URL) }); 94 | expect( 95 | parseFileMetadataFromUrl("https://example.com/file.pdf?a=1&b=2#c=3&d=4"), 96 | ).toEqual({ filename: "file.pdf", url: expect.any(URL) }); 97 | }); 98 | 99 | test("returns undefined for URLs with uninferrable file names", () => { 100 | expect( 101 | parseFileMetadataFromUrl("http://foo.com/bar/?file=example.pdf"), 102 | ).toBeUndefined(); 103 | expect(parseFileMetadataFromUrl("http://foo.com/bar/")).toBeUndefined(); 104 | expect(parseFileMetadataFromUrl("http://foo.com")).toBeUndefined(); 105 | }); 106 | 107 | test("returns undefined for non-standard URL formats", () => { 108 | expect( 109 | parseFileMetadataFromUrl("http://foo.com/bar/?file=example.pdf"), 110 | ).toBeUndefined(); 111 | expect(parseFileMetadataFromUrl("gs://bucket/file.pdf")).toBeUndefined(); 112 | expect( 113 | parseFileMetadataFromUrl("ftp://files.org/documents/sample.pdf"), 114 | ).toBeUndefined(); 115 | expect( 116 | parseFileMetadataFromUrl("s3://my-bucket/backup/archive.pdf"), 117 | ).toBeUndefined(); 118 | expect( 119 | parseFileMetadataFromUrl("file:///C:/Users/name/Documents/file.pdf"), 120 | ).toBeUndefined(); 121 | expect( 122 | parseFileMetadataFromUrl( 123 | "sftp://username:password@server.com/path/to/file.pdf", 124 | ), 125 | ).toBeUndefined(); 126 | }); 127 | 128 | test("returns undefined for URLs without filename", () => { 129 | expect(parseFileMetadataFromUrl("https://example.com/")).toBeUndefined(); 130 | expect(parseFileMetadataFromUrl("http://site.org")).toBeUndefined(); 131 | expect(parseFileMetadataFromUrl("")).toBeUndefined(); 132 | 
expect(parseFileMetadataFromUrl(" ")).toBeUndefined(); 133 | expect(parseFileMetadataFromUrl(null as unknown as string)).toBeUndefined(); 134 | expect( 135 | parseFileMetadataFromUrl(undefined as unknown as string), 136 | ).toBeUndefined(); 137 | }); 138 | 139 | test("handles different file extensions", () => { 140 | expect( 141 | parseFileMetadataFromUrl("https://example.com/document.docx"), 142 | ).toEqual({ filename: "document.docx", url: expect.any(URL) }); 143 | expect( 144 | parseFileMetadataFromUrl("https://example.com/spreadsheet.xlsx"), 145 | ).toEqual({ filename: "spreadsheet.xlsx", url: expect.any(URL) }); 146 | expect( 147 | parseFileMetadataFromUrl("https://example.com/presentation.pptx"), 148 | ).toEqual({ filename: "presentation.pptx", url: expect.any(URL) }); 149 | expect(parseFileMetadataFromUrl("https://example.com/archive.zip")).toEqual( 150 | { filename: "archive.zip", url: expect.any(URL) }, 151 | ); 152 | expect(parseFileMetadataFromUrl("https://example.com/image.jpg")).toEqual({ 153 | filename: "image.jpg", 154 | url: expect.any(URL), 155 | }); 156 | expect(parseFileMetadataFromUrl("https://example.com/video.mp4")).toEqual({ 157 | filename: "video.mp4", 158 | url: expect.any(URL), 159 | }); 160 | expect(parseFileMetadataFromUrl("https://example.com/data.json")).toEqual({ 161 | filename: "data.json", 162 | url: expect.any(URL), 163 | }); 164 | expect(parseFileMetadataFromUrl("https://example.com/page.html")).toEqual({ 165 | filename: "page.html", 166 | url: expect.any(URL), 167 | }); 168 | }); 169 | 170 | test("handles complex URL encodings", () => { 171 | expect( 172 | parseFileMetadataFromUrl( 173 | "https://example.com/file%20with%20spaces%20and%20%23%20symbols.pdf", 174 | ), 175 | ).toEqual({ 176 | filename: "file with spaces and # symbols.pdf", 177 | url: expect.any(URL), 178 | }); 179 | expect( 180 | parseFileMetadataFromUrl("https://example.com/%E6%96%87%E4%BB%B6.pdf"), 181 | ).toEqual({ filename: "文件.pdf", url: expect.any(URL) }); 182 
| expect( 183 | parseFileMetadataFromUrl("https://example.com/r%C3%A9sum%C3%A9.pdf"), 184 | ).toEqual({ filename: "résumé.pdf", url: expect.any(URL) }); 185 | expect( 186 | parseFileMetadataFromUrl("https://example.com/file%2Bwith%2Bplus.pdf"), 187 | ).toEqual({ filename: "file+with+plus.pdf", url: expect.any(URL) }); 188 | expect( 189 | parseFileMetadataFromUrl( 190 | "https://example.com/file%3Fwith%3Fquestion.pdf", 191 | ), 192 | ).toEqual({ filename: "file?with?question.pdf", url: expect.any(URL) }); 193 | }); 194 | 195 | test("handles S3 pre-signed URLs", () => { 196 | expect( 197 | parseFileMetadataFromUrl( 198 | "https://somes3subdomain.s3.amazonaws.com/files/e1ebccc2-4006-434e-a739-cba3b3fd85dd?X-Amz-Expires=86400&response-content-disposition=attachment%3B%20filename%3D%22test.pdf%22&response-content-type=application%2Fpdf&x-id=GetObject", 199 | ), 200 | ).toEqual({ 201 | filename: "test.pdf", 202 | contentType: "application/pdf", 203 | url: expect.any(URL), 204 | }); 205 | }); 206 | }); 207 | -------------------------------------------------------------------------------- /packages/proxy/src/util.ts: -------------------------------------------------------------------------------- 1 | import contentDisposition from "content-disposition"; 2 | export interface ModelResponse { 3 | stream: ReadableStream | null; 4 | response: Response; 5 | } 6 | 7 | export function parseAuthHeader( 8 | headers: Record, 9 | ): string | null { 10 | const authHeader = headers["authorization"]; 11 | let authValue = null; 12 | if (Array.isArray(authHeader)) { 13 | authValue = authHeader[authHeader.length - 1]; 14 | } else { 15 | authValue = authHeader; 16 | } 17 | 18 | if (authValue) { 19 | const parts = authValue.split(" "); 20 | if (parts.length !== 2) { 21 | return null; 22 | } 23 | return parts[1]; 24 | } 25 | 26 | // Anthropic uses x-api-key instead of authorization. 
27 | const apiKeyHeader = headers["x-api-key"]; 28 | if (apiKeyHeader) { 29 | return Array.isArray(apiKeyHeader) 30 | ? apiKeyHeader[apiKeyHeader.length - 1] 31 | : apiKeyHeader; 32 | } 33 | 34 | return null; 35 | } 36 | 37 | export function parseNumericHeader( 38 | headers: Record, 39 | headerKey: string, 40 | ): number | null { 41 | let value = headers[headerKey]; 42 | if (Array.isArray(value)) { 43 | value = value[value.length - 1]; 44 | } 45 | 46 | if (value !== undefined) { 47 | try { 48 | return parseInt(value, 10); 49 | } catch (e) {} 50 | } 51 | 52 | return null; 53 | } 54 | 55 | // This is duplicated from app/utils/object.ts 56 | export function isObject(value: any): value is { [key: string]: any } { 57 | return value instanceof Object && !(value instanceof Array); 58 | } 59 | 60 | export function getTimestampInSeconds() { 61 | return Math.floor(Date.now() / 1000); 62 | } 63 | 64 | export function flattenChunksArray(allChunks: Uint8Array[]): Uint8Array { 65 | const flatArray = new Uint8Array(allChunks.reduce((a, b) => a + b.length, 0)); 66 | for (let i = 0, offset = 0; i < allChunks.length; i++) { 67 | flatArray.set(allChunks[i], offset); 68 | offset += allChunks[i].length; 69 | } 70 | return flatArray; 71 | } 72 | 73 | export function flattenChunks(allChunks: Uint8Array[]) { 74 | const flatArray = flattenChunksArray(allChunks); 75 | return new TextDecoder().decode(flatArray); 76 | } 77 | 78 | export function isEmpty(a: any): a is null | undefined { 79 | return a === undefined || a === null; 80 | } 81 | 82 | export function getRandomInt(max: number) { 83 | return Math.floor(Math.random() * max); 84 | } 85 | 86 | export class ProxyBadRequestError extends Error { 87 | constructor(public message: string) { 88 | super(message); 89 | } 90 | } 91 | 92 | export function parseFileMetadataFromUrl( 93 | url: string, 94 | ): { filename: string; contentType?: string; url: URL } | undefined { 95 | try { 96 | // Handle empty string 97 | if (!url || url.trim() === "") { 
98 | return undefined; 99 | } 100 | 101 | // Use URL to parse complex URLs rather than string splitting 102 | let parsedUrl: URL | undefined; 103 | try { 104 | parsedUrl = new URL(url); 105 | } catch (e) { 106 | return undefined; 107 | } 108 | 109 | // If the URL is not http(s), file cannot be accessed 110 | // If pathname is empty or ends with "/", there's no filename to extract 111 | if (parsedUrl.protocol !== "http:" && parsedUrl.protocol !== "https:") { 112 | return undefined; 113 | } else if ( 114 | !parsedUrl.pathname || 115 | parsedUrl.pathname === "/" || 116 | parsedUrl.pathname.endsWith("/") 117 | ) { 118 | return undefined; 119 | } 120 | 121 | // Get the last segment of the path 122 | let filename = parsedUrl.pathname.split("/").pop(); 123 | if (!filename) { 124 | return undefined; 125 | } 126 | 127 | let contentType = undefined; 128 | 129 | // Handle case where this is an S3 pre-signed URL 130 | if (parsedUrl.searchParams.get("X-Amz-Expires") !== null) { 131 | const disposition = contentDisposition.parse( 132 | parsedUrl.searchParams.get("response-content-disposition") || "", 133 | ); 134 | filename = disposition.parameters.filename 135 | ? decodeURIComponent(disposition.parameters.filename) 136 | : filename; 137 | contentType = 138 | parsedUrl.searchParams.get("response-content-type") ?? 
undefined; 139 | } 140 | 141 | try { 142 | filename = decodeURIComponent(filename); 143 | } catch (e) { 144 | // If the filename is not valid UTF-8, we'll just return the original filename 145 | } 146 | 147 | return { filename, contentType, url: parsedUrl }; 148 | } catch (e) { 149 | return undefined; 150 | } 151 | } 152 | 153 | export const writeToReadable = (response: string) => { 154 | return new ReadableStream({ 155 | start(controller) { 156 | controller.enqueue(new TextEncoder().encode(response)); 157 | controller.close(); 158 | }, 159 | }); 160 | }; 161 | -------------------------------------------------------------------------------- /packages/proxy/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "declaration": true, 4 | "lib": ["es2015", "dom"], 5 | "target": "ES2018", 6 | "strict": true, 7 | "moduleResolution": "node", 8 | "baseUrl": ".", 9 | "paths": { 10 | "@lib/*": ["src/*"], 11 | "@schema": ["schema/index"], 12 | "@schema/*": ["schema/*"], 13 | "@types": ["types/index"] 14 | }, 15 | "resolveJsonModule": true, 16 | "esModuleInterop": true 17 | }, 18 | "include": ["."], 19 | "exclude": ["node_modules/**", "**/dist/**"] 20 | } 21 | -------------------------------------------------------------------------------- /packages/proxy/tsup.config.ts: -------------------------------------------------------------------------------- 1 | import { defineConfig } from "tsup"; 2 | 3 | export default defineConfig([ 4 | { 5 | entry: ["src/index.ts"], 6 | format: ["cjs", "esm"], 7 | outDir: "dist", 8 | dts: true, 9 | }, 10 | { 11 | entry: ["edge/index.ts"], 12 | format: ["cjs", "esm"], 13 | outDir: "edge/dist", 14 | dts: true, 15 | }, 16 | { 17 | entry: ["schema/index.ts"], 18 | format: ["cjs", "esm"], 19 | outDir: "schema/dist", 20 | dts: true, 21 | }, 22 | { 23 | entry: ["utils/index.ts"], 24 | format: ["cjs", "esm"], 25 | outDir: "utils/dist", 26 | dts: true, 27 | }, 28 | { 29 | entry: 
["types/index.ts"], 30 | format: ["cjs", "esm"], 31 | outDir: "types/dist", 32 | dts: true, 33 | }, 34 | ]); 35 | -------------------------------------------------------------------------------- /packages/proxy/turbo.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": ["//"], 3 | "tasks": { 4 | "build": { 5 | "outputs": ["**/dist/**"] 6 | } 7 | } 8 | } 9 | -------------------------------------------------------------------------------- /packages/proxy/types/anthropic.ts: -------------------------------------------------------------------------------- 1 | import { z } from "zod"; 2 | 3 | const cacheControlSchema = z.object({ 4 | type: z.enum(["ephemeral"]), 5 | }); 6 | 7 | const anthropicBase64ImageSourceSchema = z.object({ 8 | type: z.literal("base64"), 9 | media_type: z.enum(["image/jpeg", "image/png", "image/gif", "image/webp"]), 10 | data: z.string(), 11 | }); 12 | 13 | const anthropicUrlImageSourceSchema = z.object({ 14 | type: z.literal("url"), 15 | url: z.string(), 16 | }); 17 | 18 | const anthropicFileSourceSchema = z.object({ 19 | type: z.literal("file"), 20 | file_id: z.string(), 21 | }); 22 | 23 | export const anthropicContentPartImageSchema = z.object({ 24 | type: z.literal("image"), 25 | source: z.union([ 26 | anthropicBase64ImageSourceSchema, 27 | anthropicUrlImageSourceSchema, 28 | anthropicFileSourceSchema, 29 | ]), 30 | cache_control: cacheControlSchema.optional(), 31 | }); 32 | 33 | const anthropicContentPartTextSchema = z.object({ 34 | type: z.literal("text"), 35 | text: z.string(), 36 | cache_control: cacheControlSchema.optional(), 37 | }); 38 | 39 | const anthropicToolUseContentPartSchema = z.object({ 40 | type: z.literal("tool_use"), 41 | id: z.string(), 42 | name: z.string(), 43 | input: z.record(z.any()), 44 | cache_control: cacheControlSchema.optional(), 45 | }); 46 | 47 | const anthropicServerToolUseContentPartSchema = z.object({ 48 | type: z.literal("server_tool_use"), 49 | id: 
z.string(), 50 | name: z.enum(["web_search", "code_execution"]), 51 | input: z.record(z.any()), 52 | cache_control: cacheControlSchema.optional(), 53 | }); 54 | 55 | const anthropicWebSearchToolResultErrorSchema = z.object({ 56 | type: z.literal("web_search_tool_result_error"), 57 | errorCode: z.enum([ 58 | "invalid_tool_input", 59 | "unavailable", 60 | "max_uses_exceeded", 61 | "too_many_requests", 62 | "query_too_long", 63 | ]), 64 | }); 65 | 66 | const anthropicWebSearchToolResultSuccessSchema = z.object({ 67 | type: z.literal("web_search_result"), 68 | url: z.string(), 69 | page_age: z.number().nullish(), 70 | title: z.string(), 71 | encrypted_content: z.string(), 72 | }); 73 | 74 | const anthropicWebSearchToolResultContentPartSchema = z.object({ 75 | type: z.literal("web_search_tool_result"), 76 | tool_use_id: z.string(), 77 | content: z.union([ 78 | anthropicWebSearchToolResultErrorSchema, 79 | z.array(anthropicWebSearchToolResultSuccessSchema), 80 | ]), 81 | cache_control: cacheControlSchema.nullish(), 82 | }); 83 | 84 | const anthropicCodeExecutionToolResultErrorSchema = z.object({ 85 | type: z.literal("code_execution_tool_result_error"), 86 | errorCode: z.enum([ 87 | "invalid_tool_input", 88 | "unavailable", 89 | "too_many_requests", 90 | "query_too_long", 91 | ]), 92 | }); 93 | 94 | const anthropicCodeExecutionToolResultSuccessSchema = z.object({ 95 | type: z.literal("code_execution_result"), 96 | return_code: z.number(), 97 | stderr: z.string(), 98 | stdout: z.string(), 99 | content: z.array( 100 | z.object({ 101 | type: z.literal("code_execution_output"), 102 | file_id: z.string(), 103 | }), 104 | ), 105 | }); 106 | 107 | const anthropicCodeExecutionToolResultContentPartSchema = z.object({ 108 | type: z.literal("code_execution_tool_result"), 109 | tool_use_id: z.string(), 110 | content: z.union([ 111 | anthropicCodeExecutionToolResultErrorSchema, 112 | anthropicCodeExecutionToolResultSuccessSchema, 113 | ]), 114 | cache_control: 
cacheControlSchema.nullish(), 115 | }); 116 | 117 | const anthropicMCPToolUseContentPartSchema = z.object({ 118 | type: z.literal("mcp_tool_use"), 119 | id: z.string(), 120 | name: z.string(), 121 | input: z.record(z.any()), 122 | server_name: z.string(), 123 | cache_control: cacheControlSchema.nullish(), 124 | }); 125 | 126 | const anthropicMCPToolResultContentPartSchema = z.object({ 127 | type: z.literal("mcp_tool_result"), 128 | tool_use_id: z.string(), 129 | is_error: z.boolean(), 130 | content: z.union([ 131 | z.string(), 132 | z.array( 133 | z.object({ 134 | type: z.literal("text"), 135 | text: z.string(), 136 | // This is a simplification of the strict citation schema 137 | citations: z.array(z.record(z.any())).nullish(), 138 | cache_control: cacheControlSchema.nullish(), 139 | }), 140 | ), 141 | ]), 142 | cache_control: cacheControlSchema.nullish(), 143 | }); 144 | 145 | const anthropicTextImageContentBlockSchema = z.union([ 146 | z.string(), 147 | z.array( 148 | z.union([anthropicContentPartTextSchema, anthropicContentPartImageSchema]), 149 | ), 150 | ]); 151 | 152 | const anthropicToolResultContentPartSchema = z.object({ 153 | type: z.literal("tool_result"), 154 | tool_use_id: z.string(), 155 | content: anthropicTextImageContentBlockSchema.optional(), 156 | is_error: z.boolean().optional(), 157 | cache_control: cacheControlSchema.nullish(), 158 | }); 159 | 160 | const anthropicPDFSchema = z.object({ 161 | media_type: z.literal("application/pdf"), 162 | data: z.string(), 163 | type: z.literal("base64"), 164 | }); 165 | 166 | const anthropicPlainTextSchema = z.object({ 167 | media_type: z.literal("text/plain"), 168 | data: z.string(), 169 | type: z.literal("text"), 170 | }); 171 | 172 | const anthropicURLPDFSchema = z.object({ 173 | url: z.string(), 174 | type: z.literal("url"), 175 | }); 176 | 177 | const anthropicDocumentContentPartSchema = z.object({ 178 | type: z.literal("document"), 179 | source: z.union([ 180 | anthropicPDFSchema, 181 | 
anthropicPlainTextSchema, 182 | anthropicURLPDFSchema, 183 | anthropicTextImageContentBlockSchema, 184 | anthropicFileSourceSchema, 185 | ]), 186 | citations: z 187 | .object({ 188 | enabled: z.boolean().optional(), 189 | }) 190 | .optional(), 191 | context: z.string().nullish(), 192 | title: z.string().nullish(), 193 | cache_control: cacheControlSchema.nullish(), 194 | }); 195 | 196 | const anthropicThinkingContentPartSchema = z.object({ 197 | type: z.literal("thinking"), 198 | thinking: z.string(), 199 | signature: z.string(), 200 | }); 201 | 202 | const anthropicRedactedThinkingContentPartSchema = z.object({ 203 | type: z.literal("redacted_thinking"), 204 | data: z.string(), 205 | }); 206 | 207 | const anthropicContainerUploadContentPartSchema = z.object({ 208 | type: z.literal("container_upload"), 209 | file_id: z.string(), 210 | cache_control: cacheControlSchema.nullish(), 211 | }); 212 | 213 | export const anthropicContentPartSchema = z.union([ 214 | anthropicContentPartTextSchema, 215 | anthropicContentPartImageSchema, 216 | anthropicToolUseContentPartSchema, 217 | anthropicToolResultContentPartSchema, 218 | anthropicServerToolUseContentPartSchema, 219 | anthropicWebSearchToolResultContentPartSchema, 220 | anthropicCodeExecutionToolResultContentPartSchema, 221 | anthropicMCPToolUseContentPartSchema, 222 | anthropicMCPToolResultContentPartSchema, 223 | anthropicDocumentContentPartSchema, 224 | anthropicThinkingContentPartSchema, 225 | anthropicRedactedThinkingContentPartSchema, 226 | anthropicContainerUploadContentPartSchema, 227 | ]); 228 | 229 | // System blocks are provided as a separate parameter to the Anthropic client, rather than in the messages parameter. 230 | // However, we include it as an input on LLM spans created by the anthropic wrapper, so we need to support it here. 
231 | export const anthropicMessageParamSchema = z.object({ 232 | role: z.enum(["system", "user", "assistant"]), 233 | content: z.union([z.string(), z.array(anthropicContentPartSchema)]), 234 | }); 235 | 236 | export type AnthropicContentPart = z.infer; 237 | export type AnthropicMessageParam = z.infer; 238 | -------------------------------------------------------------------------------- /packages/proxy/types/index.ts: -------------------------------------------------------------------------------- 1 | export * from "./openai"; 2 | export type * from "./openai"; 3 | export * from "./anthropic"; 4 | export type * from "./anthropic"; 5 | -------------------------------------------------------------------------------- /packages/proxy/types/openai.ts: -------------------------------------------------------------------------------- 1 | // TODO: move from core 2 | import { chatCompletionMessageParamSchema } from "@braintrust/core/typespecs"; 3 | 4 | import { z } from "zod"; 5 | 6 | import { 7 | ChatCompletion, 8 | ChatCompletionChunk, 9 | ChatCompletionCreateParams, 10 | } from "openai/resources"; 11 | 12 | export type OpenAIChatCompletionMessage = z.infer< 13 | typeof chatCompletionMessageParamSchema 14 | >; 15 | 16 | export type OpenAIChatCompletionChoice = ChatCompletion.Choice & { 17 | message: OpenAIChatCompletionMessage; 18 | }; 19 | 20 | export type OpenAIChatCompletion = ChatCompletion & { 21 | choices: Array; 22 | }; 23 | 24 | export const chatCompletionMessageReasoningSchema = z 25 | .object({ 26 | id: z 27 | .string() 28 | .nullish() 29 | .transform((x) => x ?? undefined), 30 | content: z 31 | .string() 32 | .nullish() 33 | .transform((x) => x ?? 
undefined), 34 | }) 35 | .describe( 36 | "Note: This is not part of the OpenAI API spec, but we added it for interoperability with multiple reasoning models.", 37 | ); 38 | 39 | export type OpenAIReasoning = z.infer< 40 | typeof chatCompletionMessageReasoningSchema 41 | >; 42 | 43 | export type OpenAIChatCompletionChunkChoiceDelta = 44 | ChatCompletionChunk.Choice.Delta & { 45 | reasoning?: OpenAIReasoning; 46 | }; 47 | 48 | export type OpenAIChatCompletionChunkChoice = ChatCompletionChunk.Choice & { 49 | delta: OpenAIChatCompletionChunkChoiceDelta; 50 | }; 51 | 52 | export type OpenAIChatCompletionChunk = ChatCompletionChunk & { 53 | choices: Array; 54 | }; 55 | 56 | export type OpenAIChatCompletionCreateParams = ChatCompletionCreateParams & { 57 | messages: Array; 58 | reasoning_enabled?: boolean; 59 | reasoning_budget?: number; 60 | }; 61 | 62 | // overrides 63 | declare module "openai/resources/chat/completions" { 64 | interface ChatCompletionCreateParamsBase { 65 | reasoning_enabled?: boolean; 66 | reasoning_budget?: number; 67 | } 68 | interface ChatCompletionAssistantMessageParam { 69 | reasoning?: OpenAIReasoning[]; 70 | } 71 | namespace ChatCompletion { 72 | interface Choice { 73 | reasoning?: OpenAIReasoning[]; 74 | } 75 | } 76 | namespace ChatCompletionChunk { 77 | namespace Choice { 78 | interface Delta { 79 | reasoning?: OpenAIReasoning; 80 | } 81 | } 82 | } 83 | } 84 | 85 | export const completionUsageSchema = z.object({ 86 | completion_tokens: z.number(), 87 | prompt_tokens: z.number(), 88 | total_tokens: z.number(), 89 | completion_tokens_details: z 90 | .object({ 91 | accepted_prediction_tokens: z.number().optional(), 92 | audio_tokens: z.number().optional(), 93 | reasoning_tokens: z.number().optional(), 94 | rejected_prediction_tokens: z.number().optional(), 95 | }) 96 | .optional(), 97 | prompt_tokens_details: z 98 | .object({ 99 | audio_tokens: z.number().optional(), 100 | cached_tokens: z.number().optional(), 101 | cache_creation_tokens: z 102 
| .number() 103 | .optional() 104 | .describe( 105 | "Extension to support Anthropic `cache_creation_input_tokens`", 106 | ), 107 | }) 108 | .optional(), 109 | }); 110 | 111 | export type OpenAICompletionUsage = z.infer; 112 | -------------------------------------------------------------------------------- /packages/proxy/utils/audioEncoder.ts: -------------------------------------------------------------------------------- 1 | import { Mp3Bitrate, PcmAudioFormat } from "@schema/audio"; 2 | import { Mp3Encoder } from "@breezystack/lamejs"; 3 | 4 | export function makeWavFile( 5 | format: PcmAudioFormat, 6 | buffers: ArrayBufferLike[], 7 | ): Blob { 8 | if ( 9 | format.name === "pcm" && 10 | (format.byte_order !== "little" || format.bits_per_sample % 8 !== 0) 11 | ) { 12 | throw new Error(`Unsupported PCM format: ${JSON.stringify(format)}`); 13 | } 14 | 15 | if (format.name === "pcm" && format.number_encoding === "float") { 16 | // TODO(kevin): This path is untested. 17 | // TODO(kevin): This should probably just result in a float WAV file. 18 | format = { 19 | ...format, 20 | number_encoding: "int", 21 | byte_order: "little", 22 | bits_per_sample: 16, 23 | }; 24 | buffers = buffers.map((buffer) => 25 | floatTo16BitPCM(new Float32Array(buffer)), 26 | ); 27 | } 28 | 29 | const dataLength = buffers.reduce((sum, b) => sum + b.byteLength, 0); 30 | 31 | const bitsPerSample = format.name === "pcm" ? format.bits_per_sample : 8; 32 | 33 | // http://soundfile.sapp.org/doc/WaveFormat/ 34 | const blobParts = [ 35 | // Header. 36 | "RIFF", 37 | // Length. 38 | pack( 39 | 1, 40 | 4 + // "WAVE" length. 41 | (8 + 16) + // Chunk 1 length. 42 | (8 + dataLength), // Chunk 2 length. 43 | ), 44 | "WAVE", 45 | // Chunk 1. 46 | "fmt ", 47 | pack(1, 16), // Chunk length. 48 | pack(0, wavFormatCode(format)), // Audio format (1 is linear quantization). 
    pack(0, format.channels),
    pack(1, format.sample_rate),
    pack(1, (format.sample_rate * format.channels * bitsPerSample) / 8), // Byte rate.
    pack(0, (format.channels * bitsPerSample) / 8), // Block align (bytes per sample frame).
    pack(0, bitsPerSample),
    // Chunk 2.
    "data",
    pack(1, dataLength), // Chunk length.
    ...buffers,
  ];

  return new Blob(blobParts, { type: "audio/wav" });
}

// Map our codec description to a WAVE format tag:
// 0x0001 = WAVE_FORMAT_PCM, 0x0006 = WAVE_FORMAT_ALAW, 0x0007 = WAVE_FORMAT_MULAW.
function wavFormatCode(format: PcmAudioFormat) {
  const name = format.name; // Need local variable to pass type checker.
  switch (name) {
    case "pcm":
      return 0x0001;
    case "g711": {
      switch (format.algorithm) {
        case "a":
          return 0x0006;
        case "mu":
          return 0x0007;
        default:
          // Exhaustiveness check: fails to compile if a new algorithm is added.
          const x: never = format.algorithm;
          throw new Error(x);
      }
    }
    default:
      // Exhaustiveness check for the format name union.
      const x: never = name;
      throw new Error(x);
  }
}

/**
 * Pack a number into a byte array (little-endian, as WAV requires).
 * @param size Pass `0` for 16-bit output, or `1` for 32-bit output. Large
 * values will be truncated.
 * @param arg Integer to pack.
 * @returns Byte array with the integer.
 */
function pack(size: 0 | 1, arg: number) {
  return new Uint8Array(
    size === 0 ? [arg, arg >> 8] : [arg, arg >> 8, arg >> 16, arg >> 24],
  );
}

// Convert float samples (clamped to [-1, 1]) to little-endian signed 16-bit PCM.
function floatTo16BitPCM(float32Array: Float32Array) {
  const buffer = new ArrayBuffer(float32Array.length * 2);
  const view = new DataView(buffer);
  let offset = 0;
  for (let i = 0; i < float32Array.length; i++, offset += 2) {
    const s = Math.max(-1, Math.min(1, float32Array[i]));
    view.setInt16(offset, s < 0 ?
s * 0x8000 : s * 0x7fff, true); 105 | } 106 | return new Int16Array(buffer); 107 | } 108 | 109 | export function makeMp3File( 110 | inputCodec: PcmAudioFormat, 111 | bitrate: Mp3Bitrate, 112 | buffers: ArrayBufferLike[], 113 | ): Blob { 114 | if (inputCodec.name !== "pcm") { 115 | throw new Error("Unsupported input codec"); 116 | } 117 | if ( 118 | inputCodec.bits_per_sample !== 16 || 119 | inputCodec.byte_order !== "little" || 120 | inputCodec.channels !== 1 121 | ) { 122 | throw new Error("Unsupported input encoding"); 123 | } 124 | const minBitrate: Mp3Bitrate = 40; 125 | if (bitrate < minBitrate) { 126 | // Possible bug in lamejs that results in a silent file when bitrate <= 32. 127 | console.warn(`Adjusting bitrate ${bitrate} -> ${minBitrate}`); 128 | bitrate = minBitrate; 129 | } 130 | 131 | const encoder = new Mp3Encoder( 132 | inputCodec.channels, 133 | inputCodec.sample_rate, 134 | bitrate, 135 | ); 136 | 137 | const blobParts: ArrayBuffer[] = []; 138 | 139 | for (const buffer of buffers) { 140 | const int16Buffer = 141 | inputCodec.number_encoding === "int" 142 | ? 
new Int16Array(buffer) 143 | : floatTo16BitPCM(new Float32Array(buffer)); 144 | const encoded = encoder.encodeBuffer(int16Buffer); 145 | if (encoded.length) { 146 | blobParts.push(encoded); 147 | } 148 | } 149 | 150 | blobParts.push(encoder.flush()); 151 | 152 | return new Blob(blobParts, { type: "audio/mpeg" }); 153 | } 154 | -------------------------------------------------------------------------------- /packages/proxy/utils/deps.test.ts: -------------------------------------------------------------------------------- 1 | import skott from "skott"; 2 | import { expect, it } from "vitest"; 3 | 4 | it("no circ dependencies", async () => { 5 | const { useGraph } = await skott({ 6 | entrypoint: `${__dirname}/index.ts`, 7 | tsConfigPath: `${__dirname}/../tsconfig.json`, 8 | dependencyTracking: { 9 | builtin: false, 10 | thirdParty: true, 11 | typeOnly: true, 12 | }, 13 | }); 14 | 15 | const { findCircularDependencies } = useGraph(); 16 | 17 | expect(findCircularDependencies()).toEqual([]); 18 | }); 19 | -------------------------------------------------------------------------------- /packages/proxy/utils/encrypt.ts: -------------------------------------------------------------------------------- 1 | import { z } from "zod"; 2 | 3 | let _issuedCryptoSubtleWarning = false; 4 | function issueCryptoSubtleWarning() { 5 | if (!_issuedCryptoSubtleWarning) { 6 | console.warn( 7 | "Crypto utils are not supported in this browser. 
Skipping any crypto-related functionality (such as realtime)", 8 | ); 9 | _issuedCryptoSubtleWarning = true; 10 | } 11 | } 12 | 13 | function getSubtleCrypto() { 14 | return globalThis.crypto.subtle; 15 | } 16 | 17 | export function isCryptoAvailable(): boolean { 18 | const ret = !!getSubtleCrypto(); 19 | if (!ret) { 20 | issueCryptoSubtleWarning(); 21 | } 22 | return ret; 23 | } 24 | 25 | export function base64ToArrayBuffer(base64: string) { 26 | var binaryString = atob(base64); 27 | var bytes = new Uint8Array(binaryString.length); 28 | for (var i = 0; i < binaryString.length; i++) { 29 | bytes[i] = binaryString.charCodeAt(i); 30 | } 31 | return bytes.buffer; 32 | } 33 | 34 | export function arrayBufferToBase64(buffer: ArrayBuffer) { 35 | var binary = ""; 36 | var bytes = new Uint8Array(buffer); 37 | var len = bytes.byteLength; 38 | for (var i = 0; i < len; i++) { 39 | binary += String.fromCharCode(bytes[i]); 40 | } 41 | return btoa(binary); 42 | } 43 | 44 | // https://github.com/mdn/dom-examples/blob/main/web-crypto/encrypt-decrypt/aes-gcm.js 45 | export async function decryptMessage( 46 | keyString: string, 47 | iv: string, 48 | message: string, 49 | ): Promise { 50 | if (!isCryptoAvailable()) return undefined; 51 | 52 | const key = await getSubtleCrypto().importKey( 53 | "raw", 54 | base64ToArrayBuffer(keyString), 55 | { name: "AES-GCM", length: 256 }, 56 | false, 57 | ["decrypt"], 58 | ); 59 | 60 | const decoded = await getSubtleCrypto().decrypt( 61 | { 62 | name: "AES-GCM", 63 | iv: base64ToArrayBuffer(iv), 64 | }, 65 | key, 66 | base64ToArrayBuffer(message), 67 | ); 68 | 69 | return new TextDecoder().decode(decoded); 70 | } 71 | 72 | export const encryptedMessageSchema = z.strictObject({ 73 | iv: z.string(), 74 | data: z.string(), 75 | }); 76 | export type EncryptedMessage = z.infer; 77 | 78 | export async function encryptMessage( 79 | keyString: string, 80 | message: string, 81 | ): Promise { 82 | if (!isCryptoAvailable()) return undefined; 83 | 84 | const 
key = await getSubtleCrypto().importKey( 85 | "raw", 86 | base64ToArrayBuffer(keyString), 87 | { name: "AES-GCM", length: 256 }, 88 | false, 89 | ["encrypt"], 90 | ); 91 | 92 | const iv = crypto.getRandomValues(new Uint8Array(12)); 93 | const decoded = await getSubtleCrypto().encrypt( 94 | { 95 | name: "AES-GCM", 96 | iv, 97 | }, 98 | key, 99 | new TextEncoder().encode(message), 100 | ); 101 | 102 | return { 103 | iv: arrayBufferToBase64(new Uint8Array(iv)), 104 | data: arrayBufferToBase64(decoded), 105 | }; 106 | } 107 | -------------------------------------------------------------------------------- /packages/proxy/utils/index.ts: -------------------------------------------------------------------------------- 1 | export { 2 | parseOpenAIStream, 3 | isChatCompletionChunk, 4 | isCompletion, 5 | } from "./openai"; 6 | export * from "./encrypt"; 7 | 8 | export { 9 | isTempCredential, 10 | makeTempCredentials, 11 | verifyTempCredentials, 12 | } from "./tempCredentials"; 13 | 14 | export { makeWavFile, makeMp3File } from "./audioEncoder"; 15 | 16 | export function getCurrentUnixTimestamp(): number { 17 | return Date.now() / 1000; 18 | } 19 | 20 | export const effortToBudgetMultiplier = { 21 | low: 0.2, 22 | medium: 0.5, 23 | high: 0.8, 24 | } as const; 25 | 26 | export const getBudgetMultiplier = ( 27 | effort: keyof typeof effortToBudgetMultiplier, 28 | ) => { 29 | return effortToBudgetMultiplier[effort] || effortToBudgetMultiplier.low; 30 | }; 31 | -------------------------------------------------------------------------------- /packages/proxy/utils/openai.ts: -------------------------------------------------------------------------------- 1 | import { 2 | OpenAIChatCompletionChunk, 3 | OpenAIChatCompletionCreateParams, 4 | } from "@types"; 5 | import { trimStartOfStreamHelper } from "ai"; 6 | import { ChatCompletionCreateParams, Completion } from "openai/resources"; 7 | 8 | /** 9 | * Creates a parser function for processing the OpenAI stream data. 
10 | * The parser extracts and trims text content from the JSON data. This parser 11 | * can handle data for chat or completion models. 12 | * 13 | * @return {(data: string) => string | void} A parser function that takes a JSON string as input and returns the extracted text content or nothing. 14 | */ 15 | export function parseOpenAIStream(): (data: string) => string | void { 16 | const extract = chunkToText(); 17 | return (data) => extract(JSON.parse(data) as OpenAIStreamReturnTypes); 18 | } 19 | 20 | function chunkToText(): (chunk: OpenAIStreamReturnTypes) => string | void { 21 | const trimStartOfStream = trimStartOfStreamHelper(); 22 | let isFunctionStreamingIn: boolean; 23 | return (json) => { 24 | if (isChatCompletionChunk(json)) { 25 | const delta = json.choices[0]?.delta; 26 | if (delta.function_call?.name) { 27 | isFunctionStreamingIn = true; 28 | return `{"function_call": {"name": "${delta.function_call.name}", "arguments": "`; 29 | } else if (delta.tool_calls?.[0]?.function?.name) { 30 | isFunctionStreamingIn = true; 31 | const toolCall = delta.tool_calls[0]; 32 | if (toolCall.index === 0) { 33 | return `{"tool_calls":[ {"id": "${toolCall.id}", "type": "function", "function": {"name": "${toolCall.function?.name}", "arguments": "`; 34 | } else { 35 | return `"}}, {"id": "${toolCall.id}", "type": "function", "function": {"name": "${toolCall.function?.name}", "arguments": "`; 36 | } 37 | } else if (delta.function_call?.arguments) { 38 | return cleanupArguments(delta.function_call?.arguments); 39 | } else if (delta.tool_calls?.[0]?.function?.arguments) { 40 | return cleanupArguments(delta.tool_calls?.[0]?.function?.arguments); 41 | } else if ( 42 | isFunctionStreamingIn && 43 | (json.choices[0]?.finish_reason === "function_call" || 44 | json.choices[0]?.finish_reason === "stop") 45 | ) { 46 | isFunctionStreamingIn = false; // Reset the flag 47 | return '"}}'; 48 | } else if ( 49 | isFunctionStreamingIn && 50 | json.choices[0]?.finish_reason === "tool_calls" 
51 | ) { 52 | isFunctionStreamingIn = false; // Reset the flag 53 | return '"}}]}'; 54 | } 55 | } 56 | 57 | const text = trimStartOfStream( 58 | isChatCompletionChunk(json) && json.choices[0].delta.content 59 | ? json.choices[0].delta.content 60 | : isCompletion(json) 61 | ? json.choices[0].text 62 | : "", 63 | ); 64 | return text; 65 | }; 66 | 67 | function cleanupArguments(argumentChunk: string) { 68 | let escapedPartialJson = argumentChunk 69 | .replace(/\\/g, "\\\\") // Replace backslashes first to prevent double escaping 70 | .replace(/\//g, "\\/") // Escape slashes 71 | .replace(/"/g, '\\"') // Escape double quotes 72 | .replace(/\n/g, "\\n") // Escape new lines 73 | .replace(/\r/g, "\\r") // Escape carriage returns 74 | .replace(/\t/g, "\\t") // Escape tabs 75 | .replace(/\f/g, "\\f"); // Escape form feeds 76 | 77 | return `${escapedPartialJson}`; 78 | } 79 | } 80 | 81 | const __internal__OpenAIFnMessagesSymbol = Symbol( 82 | "internal_openai_fn_messages", 83 | ); 84 | 85 | type AzureChatCompletions = any; 86 | 87 | type AsyncIterableOpenAIStreamReturnTypes = 88 | | AsyncIterable 89 | | AsyncIterable 90 | | AsyncIterable; 91 | 92 | type ExtractType = T extends AsyncIterable ? 
U : never; 93 | 94 | type OpenAIStreamReturnTypes = 95 | ExtractType; 96 | 97 | export function isChatCompletionChunk( 98 | data: unknown, 99 | ): data is OpenAIChatCompletionChunk { 100 | if (!data || typeof data !== "object") { 101 | return false; 102 | } 103 | return ( 104 | "choices" in data && 105 | data.choices && 106 | Array.isArray(data.choices) && 107 | data.choices[0] && 108 | "delta" in data.choices[0] 109 | ); 110 | } 111 | 112 | export function isCompletion(data: unknown): data is Completion { 113 | if (!data || typeof data !== "object") { 114 | return false; 115 | } 116 | return ( 117 | "choices" in data && 118 | data.choices && 119 | Array.isArray(data.choices) && 120 | data.choices[0] && 121 | "text" in data.choices[0] 122 | ); 123 | } 124 | 125 | /** 126 | * Cleans the OpenAI parameters by removing extra braintrust fields. 127 | * 128 | * @param {OpenAIChatCompletionCreateParams} params - The OpenAI parameters to clean. 129 | * @returns {ChatCompletionCreateParams} - The cleaned OpenAI parameters. 
130 | */ 131 | export function cleanOpenAIParams({ 132 | reasoning_effort, 133 | reasoning_budget, 134 | reasoning_enabled, 135 | ...openai 136 | }: OpenAIChatCompletionCreateParams): ChatCompletionCreateParams { 137 | return openai; 138 | } 139 | -------------------------------------------------------------------------------- /packages/proxy/utils/tempCredentials.test.ts: -------------------------------------------------------------------------------- 1 | import { expect, test } from "vitest"; 2 | import { 3 | isTempCredential, 4 | makeTempCredentialsJwt, 5 | verifyTempCredentials, 6 | verifyJwtOnly, 7 | makeTempCredentials, 8 | } from "./tempCredentials"; 9 | import { 10 | sign as jwtSign, 11 | verify as jwtVerify, 12 | decode as jwtDecode, 13 | } from "jsonwebtoken"; 14 | import { base64ToArrayBuffer } from "./encrypt"; 15 | import { 16 | tempCredentialJwtPayloadSchema, 17 | TempCredentialsCacheValue, 18 | } from "@schema"; 19 | 20 | test("isTempCredential", () => { 21 | expect(isTempCredential("")).toStrictEqual(false); 22 | expect(isTempCredential("not a jwt")).toStrictEqual(false); 23 | expect(isTempCredential("foo.bar.baz")).toStrictEqual(false); 24 | 25 | // Generated by https://jwt.io/ with empty object payload. 26 | const jwtEmptyPayload = 27 | "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Et9HFtf9R3GEMA0IICOfFMVXY7kkTX1wr4qCyhIf58U"; 28 | expect(isTempCredential(jwtEmptyPayload)).toStrictEqual(false); 29 | 30 | // Payload contains { iss: "other" }. 31 | const jwtWithOtherIss = 32 | "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJvdGhlciIsImlhdCI6MTUxNjIzOTAyMn0.AEa0ufe56lGXsudWgXkGQFgCHASl01lgg9QOOOxVDrk"; 33 | expect(isTempCredential(jwtWithOtherIss)).toStrictEqual(false); 34 | 35 | // Payload contains { iss: "braintrust_proxy" }. 
36 | const jwtWithBraintrustIss = 37 | "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJicmFpbnRydXN0X3Byb3h5IiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.hpplNcSv9qiWEpk_vKXSWZWnXBjiFy4F6phxdKUG30s"; 38 | expect(isTempCredential(jwtWithBraintrustIss)).toStrictEqual(true); 39 | 40 | // Payload contains { aud: "braintrust_proxy" }. 41 | const jwtWithBraintrustAud = 42 | "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJicmFpbnRydXN0X3Byb3h5IiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.nrmMZvokcpywnPvDRhgG635FTY_bBzkYpswafWqogTs"; 43 | expect(isTempCredential(jwtWithBraintrustAud)).toStrictEqual(true); 44 | }); 45 | 46 | test("makeTempCredentialsJwt signing", () => { 47 | const result = makeTempCredentialsJwt({ 48 | request: { model: "model", ttl_seconds: 100 }, 49 | authToken: "auth token", 50 | orgName: "my org name", 51 | }); 52 | 53 | // Some HTTP servers have a header size limit. 54 | expect(result.jwt.length).toBeLessThan(2000); 55 | 56 | // Throws if JWT signature verification fails. 
57 | const rawPayload = jwtVerify(result.jwt, "auth token", { complete: false }); 58 | 59 | expect(rawPayload).toBeTruthy(); 60 | expect(rawPayload).toBeTypeOf("object"); 61 | 62 | // Example: 63 | // { 64 | // "aud": "braintrust_proxy", 65 | // "bt": { 66 | // "model": "model", 67 | // "org_name": "my org name", 68 | // "secret": "nCCxgkBoyy/zyOJlikuHILBMoK78bHFosEzy03SjJF0=", 69 | // }, 70 | // "exp": 1729928077, 71 | // "iat": 1729927977, 72 | // "iss": "braintrust_proxy", 73 | // "jti": "bt_tmp:331278af-937c-4f97-9d42-42c83631001a", 74 | // } 75 | const payload = tempCredentialJwtPayloadSchema.parse(rawPayload); 76 | 77 | expect(payload.bt.model).toStrictEqual("model"); 78 | expect(payload.bt.org_name).toStrictEqual("my org name"); 79 | 80 | expect(payload.bt.secret).not.toHaveLength(0); 81 | expect(payload.bt.secret).toStrictEqual(result.cacheEncryptionKey); 82 | 83 | expect(payload.jti).not.toHaveLength(0); 84 | expect(payload.jti).toStrictEqual(result.credentialId); 85 | 86 | expect(payload.exp - payload.iat).toStrictEqual(100); 87 | 88 | expect(base64ToArrayBuffer(result.cacheEncryptionKey).byteLength).toEqual( 89 | 256 / 8, 90 | ); 91 | 92 | expect(result.cachePayloadPlaintext).toEqual({ authToken: "auth token" }); 93 | expect(result.credentialId).toMatch( 94 | /^bt_tmp:[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/, 95 | ); 96 | }); 97 | 98 | test("makeTempCredentialsJwt no secret reuse", () => { 99 | const args = { 100 | request: { model: "model", ttl_seconds: 100 }, 101 | authToken: "auth token", 102 | orgName: "my org name", 103 | }; 104 | const result1 = makeTempCredentialsJwt({ ...args }); 105 | const result2 = makeTempCredentialsJwt({ ...args }); 106 | 107 | expect(result1.credentialId).not.toStrictEqual(result2.credentialId); 108 | expect(result1.cacheEncryptionKey).not.toStrictEqual( 109 | result2.cacheEncryptionKey, 110 | ); 111 | 112 | const raw1 = jwtDecode(result1.jwt, { complete: false, json: true }); 113 | const raw2 = 
jwtDecode(result2.jwt, { complete: false, json: true }); 114 | 115 | const payload1 = tempCredentialJwtPayloadSchema.parse(raw1); 116 | const payload2 = tempCredentialJwtPayloadSchema.parse(raw2); 117 | 118 | expect(payload1.bt.secret).not.toStrictEqual(payload2.bt.secret); 119 | expect(payload1.jti).not.toStrictEqual(payload2.jti); 120 | }); 121 | 122 | test("makeTempCredentials no wrapping other temp credential", async () => { 123 | const result = makeTempCredentialsJwt({ 124 | request: { model: "model", ttl_seconds: 100 }, 125 | authToken: "auth token", 126 | orgName: "my org name", 127 | }); 128 | 129 | // Use the previous temp credential JWT to issue another one. 130 | await expect( 131 | makeTempCredentials({ 132 | authToken: result.jwt, 133 | body: { 134 | model: null, 135 | ttl_seconds: 200, 136 | }, 137 | cachePut: async () => undefined, 138 | }), 139 | ).rejects.toThrow(); 140 | }); 141 | 142 | test("verifyJwtOnly basic", () => { 143 | const credentialCacheValue: TempCredentialsCacheValue = { 144 | authToken: "auth token", 145 | }; 146 | 147 | expect(() => 148 | verifyJwtOnly({ jwt: "not a jwt", credentialCacheValue }), 149 | ).toThrow("jwt malformed"); 150 | 151 | expect(() => verifyJwtOnly({ jwt: "a.b.c", credentialCacheValue })).toThrow( 152 | "invalid token", 153 | ); 154 | }); 155 | 156 | test("verifyTempCredentials wrong payload type", async () => { 157 | const cacheGet = async () => `{ "authToken": "auth token" }`; 158 | 159 | // Object that does not conform to schema. 160 | const jwtWrongSchema = jwtSign({ wrong: "schema" }, "auth token", { 161 | algorithm: "HS256", 162 | }); 163 | await expect( 164 | verifyTempCredentials({ jwt: jwtWrongSchema, cacheGet }), 165 | ).rejects.toThrow("invalid_literal"); 166 | 167 | // Non object. 
168 | const jwtWrongType = jwtSign("not an object", "auth token", { 169 | algorithm: "HS256", 170 | }); 171 | await expect( 172 | verifyTempCredentials({ jwt: jwtWrongType, cacheGet }), 173 | ).rejects.toThrow("not valid JSON"); 174 | }); 175 | 176 | test("verifyTempCredentials signature verification", async () => { 177 | const { 178 | jwt, 179 | cacheEncryptionKey, 180 | credentialId, 181 | cachePayloadPlaintext: credentialCacheValue, 182 | } = makeTempCredentialsJwt({ 183 | request: { model: "model", ttl_seconds: 100 }, 184 | authToken: "auth token", 185 | orgName: "my org name", 186 | }); 187 | 188 | // Valid JWT. 189 | expect(() => verifyJwtOnly({ jwt, credentialCacheValue })).not.toThrow(); 190 | 191 | // Valid JWT, valid cache. 192 | const cacheGet = async ( 193 | encryptionKey: string, 194 | key: string, 195 | ): Promise => 196 | encryptionKey === cacheEncryptionKey && key === credentialId 197 | ? JSON.stringify(credentialCacheValue) 198 | : null; 199 | 200 | await expect(verifyTempCredentials({ jwt, cacheGet })).resolves.toEqual({ 201 | jwtPayload: jwtDecode(jwt, { complete: false, json: true }), 202 | credentialCacheValue, 203 | }); 204 | 205 | // Valid JWT, failed cache call. 206 | const badCacheGet = async () => null; 207 | await expect( 208 | verifyTempCredentials({ jwt, cacheGet: badCacheGet }), 209 | ).rejects.toThrow(); 210 | 211 | // Incorrect signature, nonnull cache response. 212 | const wrongSecretCacheGet = async () => `{ "authToken": "wrong auth token" }`; 213 | await expect( 214 | verifyTempCredentials({ jwt, cacheGet: wrongSecretCacheGet }), 215 | ).rejects.toThrow("invalid signature"); 216 | 217 | // Correct signature, incorrect scheme. 
218 | const jwtPayloadRaw = jwtDecode(jwt, { complete: false, json: true }); 219 | if (!jwtPayloadRaw) { 220 | throw new Error("This should not happen"); 221 | } 222 | const jwtWrongAlgorithm = jwtSign(jwtPayloadRaw, "auth token", { 223 | algorithm: "HS512", 224 | }); 225 | await expect( 226 | verifyTempCredentials({ jwt: jwtWrongAlgorithm, cacheGet }), 227 | ).rejects.toThrow("invalid algorithm"); 228 | }); 229 | 230 | test("verifyTempCredentials expiration", async () => { 231 | const { jwt, cachePayloadPlaintext: credentialCacheValue } = 232 | makeTempCredentialsJwt({ 233 | request: { ttl_seconds: 0 }, 234 | authToken: "auth token", 235 | }); 236 | 237 | // Make sure the token is truly expired. 238 | // Probably not the best practice in a unit test. 239 | await new Promise((r) => setTimeout(r, 1000)); 240 | 241 | const cacheGet = async () => JSON.stringify(credentialCacheValue); 242 | await expect(verifyTempCredentials({ jwt, cacheGet })).rejects.toThrow( 243 | "jwt expired", 244 | ); 245 | }); 246 | -------------------------------------------------------------------------------- /packages/proxy/utils/tempCredentials.ts: -------------------------------------------------------------------------------- 1 | import { 2 | CredentialsRequest, 3 | credentialsRequestSchema, 4 | TempCredentialJwtPayload, 5 | tempCredentialJwtPayloadSchema, 6 | TempCredentialsCacheValue, 7 | tempCredentialsCacheValueSchema, 8 | } from "@schema/secrets"; 9 | import { v4 as uuidv4 } from "uuid"; 10 | import { arrayBufferToBase64 } from "./encrypt"; 11 | import jsonwebtoken from "jsonwebtoken"; 12 | import { isEmpty } from "@lib/util"; 13 | 14 | const JWT_ALGORITHM = "HS256"; 15 | 16 | export interface MakeTempCredentialResult { 17 | /** 18 | * A generated ID to identify the temporary credential request. The caller 19 | * uses this key for the credential cache. 20 | */ 21 | credentialId: string; 22 | /** 23 | * The plaintext payload for the credential cache. 
The caller is expected to 24 | * encrypt this value and insert it into the credential cache. 25 | */ 26 | cachePayloadPlaintext: TempCredentialsCacheValue; 27 | /** 28 | * The encryption key to be used for the credential cache. The caller should 29 | * not retain this value after it is used for insertion into the credential 30 | * cache. 31 | */ 32 | cacheEncryptionKey: string; 33 | /** 34 | * The new temporary credential encoded as a JWT. 35 | */ 36 | jwt: string; 37 | } 38 | 39 | /** 40 | * Generate a new temporary credential in the JWT format. 41 | * 42 | * @param param0 43 | * @param param0.request The temporary credential request to sign. 44 | * @param param0.authToken The user's Braintrust API key. 45 | * @param param0.orgName (Optional) The oranization name associated with the 46 | * Braintrust API key, to be used by the proxy at request time for looking up AI 47 | * provider keys. 48 | * @returns See {@link MakeTempCredentialResult}. 49 | */ 50 | export function makeTempCredentialsJwt({ 51 | request, 52 | authToken, 53 | orgName, 54 | }: { 55 | request: CredentialsRequest; 56 | authToken: string; 57 | orgName?: string; 58 | }): MakeTempCredentialResult { 59 | const credentialId = `bt_tmp:${uuidv4()}`; 60 | 61 | // Generate 256-bit key since our cache uses AES-256. 62 | const keyLengthBytes = 256 / 8; 63 | const cacheEncryptionKey = arrayBufferToBase64( 64 | crypto.getRandomValues(new Uint8Array(keyLengthBytes)), 65 | ); 66 | 67 | // The partial payload is missing timestamps (`iat`, `exp`). They will be 68 | // populated at signing time with the `mutatePayload` option. 69 | const jwtPayload: Partial = { 70 | iss: "braintrust_proxy", 71 | aud: "braintrust_proxy", 72 | jti: credentialId, 73 | bt: { 74 | org_name: orgName, 75 | model: request.model ?? undefined, 76 | secret: cacheEncryptionKey, 77 | logging: request.logging ?? 
undefined, 78 | }, 79 | }; 80 | const jwt = jsonwebtoken.sign(jwtPayload, authToken, { 81 | expiresIn: request.ttl_seconds, 82 | mutatePayload: true, 83 | algorithm: JWT_ALGORITHM, 84 | }); 85 | 86 | if (!tempCredentialJwtPayloadSchema.safeParse(jwtPayload).success) { 87 | // This should not happen. 88 | throw new Error("JWT payload didn't pass schema check after signing"); 89 | } 90 | 91 | return { 92 | credentialId, 93 | cachePayloadPlaintext: { authToken }, 94 | cacheEncryptionKey, 95 | jwt, 96 | }; 97 | } 98 | 99 | /** 100 | * Generate a new temporary credential and insert it into the credential cache. 101 | * 102 | * @param param0 103 | * @param param0.authToken The user's Braintrust API key. 104 | * @param param0.body The credential request body after JSON decoding. 105 | * @param param0.orgName (Optional) The oranization name associated with the 106 | * Braintrust API key, to be used by the proxy at request time for looking up AI 107 | * provider keys. 108 | * @param param0.cachePut: Function to encrypt and insert into the credential 109 | * cache. 110 | * @returns 111 | */ 112 | export async function makeTempCredentials({ 113 | authToken, 114 | body: rawBody, 115 | orgName, 116 | cachePut, 117 | }: { 118 | authToken: string; 119 | body: unknown; 120 | orgName?: string; 121 | cachePut: ( 122 | encryptionKey: string, 123 | key: string, 124 | value: string, 125 | ttl_seconds?: number, 126 | ) => Promise; 127 | }) { 128 | if (isTempCredential(authToken)) { 129 | // Disallow issuing a temp credential that wraps another temp credential. 130 | // This is fine from a security standpoint because such a credential will 131 | // fail later while fetching API secrets. However, allowing this is a 132 | // confusing and perhaps hard to debug behavior, so we prefer to fail fast. 
133 | throw new Error( 134 | "Temporary credential cannot be used to issue another temp credential.", 135 | ); 136 | } 137 | 138 | const body = credentialsRequestSchema.safeParse(rawBody); 139 | if (!body.success) { 140 | throw new Error(body.error.message); 141 | } 142 | 143 | const { credentialId, cachePayloadPlaintext, cacheEncryptionKey, jwt } = 144 | makeTempCredentialsJwt({ request: body.data, authToken, orgName }); 145 | 146 | const { ttl_seconds } = body.data; 147 | 148 | await cachePut( 149 | cacheEncryptionKey, 150 | credentialId, 151 | JSON.stringify(cachePayloadPlaintext), 152 | ttl_seconds, 153 | ); 154 | 155 | return jwt; 156 | } 157 | 158 | /** 159 | * Check whether the JWT appears to be a Braintrust temporary credential. This 160 | * function only checks for a syntactically valid JWT with a Braintrust `iss` 161 | * or `aud` field. 162 | * 163 | * In case this function returns some false positives when sniffing whether a 164 | * token is a Braintrust temp credential, this does not affect confidentiality 165 | * or integrity. However, we still want to be precise so we can show the proper 166 | * error message in case there are multiple token types using JWT. 167 | * 168 | * @param jwt The encoded JWT to check. 169 | * @returns True if the `jwt` satisfies the checks. 170 | */ 171 | export function isTempCredential(jwt: string): boolean { 172 | const looseJwtPayloadSchema = tempCredentialJwtPayloadSchema 173 | .pick({ iss: true }) 174 | .or(tempCredentialJwtPayloadSchema.pick({ aud: true })); 175 | return looseJwtPayloadSchema.safeParse( 176 | jsonwebtoken.decode(jwt, { complete: false, json: true }), 177 | ).success; 178 | } 179 | 180 | /** 181 | * Throws if the jwt has an invalid signature or is expired. Does not verify 182 | * Braintrust payload. 
183 | * 184 | * @throws uncaught exceptions from the `jsonwebtoken` library: 185 | * https://www.npmjs.com/package/jsonwebtoken?activeTab=readme#errors--codes 186 | */ 187 | export function verifyJwtOnly({ 188 | jwt, 189 | credentialCacheValue, 190 | }: { 191 | jwt: string; 192 | credentialCacheValue: TempCredentialsCacheValue; 193 | }): void { 194 | jsonwebtoken.verify(jwt, credentialCacheValue.authToken, { 195 | algorithms: [JWT_ALGORITHM], 196 | }); 197 | } 198 | 199 | export interface VerifyTempCredentialsResult { 200 | jwtPayload: TempCredentialJwtPayload; 201 | credentialCacheValue: TempCredentialsCacheValue; 202 | } 203 | /** 204 | * Check whether the JWT has a valid signature and expiration, then use the 205 | * payload to retrieve and decrypt the cached user credential. 206 | * 207 | * @throws an exception if the credential is invalid for any reason. The 208 | * `message` does not contain sensitive information and can be safely returned 209 | * to the user. 210 | * 211 | * @param param0 212 | * @param param0.jwt The encoded JWT to check. 213 | * @param param0.cacheGet Function to get and decrypt from the credential cache. 214 | * @returns See {@link VerifyTempCredentialsResult}. 215 | */ 216 | export async function verifyTempCredentials({ 217 | jwt, 218 | cacheGet, 219 | }: { 220 | jwt: string; 221 | cacheGet: (encryptionKey: string, key: string) => Promise; 222 | }): Promise { 223 | // Decode, but do not verify, just to get the ID and encryption key. 224 | const jwtPayloadRaw = jsonwebtoken.decode(jwt, { 225 | complete: false, 226 | json: true, 227 | }); 228 | if (isEmpty(jwtPayloadRaw)) { 229 | throw new Error("Could not parse JWT format"); 230 | } 231 | 232 | // Safe to show exception message to the client because they already know the 233 | // request contents. 
234 | const jwtPayload = tempCredentialJwtPayloadSchema.parse(jwtPayloadRaw); 235 | 236 | let credentialCacheValue: TempCredentialsCacheValue | undefined; 237 | try { 238 | const cacheValueString = await cacheGet( 239 | jwtPayload.bt.secret, 240 | jwtPayload.jti, 241 | ); 242 | if (!cacheValueString) { 243 | throw new Error("expired"); 244 | } 245 | credentialCacheValue = tempCredentialsCacheValueSchema.parse( 246 | JSON.parse(cacheValueString), 247 | ); 248 | } catch (error) { 249 | // Hide error detail to avoid accidentally disclosing Braintrust auth token. 250 | if (error instanceof Error && error.message !== "expired") { 251 | console.error( 252 | "Credential cache error:", 253 | error.stack || "stack trace not available", 254 | ); 255 | } 256 | throw new Error("Could not access credential cache"); 257 | } 258 | 259 | // Safe to show exception message to the client. 260 | verifyJwtOnly({ jwt, credentialCacheValue }); 261 | 262 | // At this point, the JWT signature has been verified. We can safely return 263 | // the previously decoded result. 
264 | return { jwtPayload, credentialCacheValue }; 265 | } 266 | -------------------------------------------------------------------------------- /packages/proxy/utils/tests.ts: -------------------------------------------------------------------------------- 1 | /* eslint-disable turbo/no-undeclared-env-vars */ 2 | 3 | import { TextDecoder } from "util"; 4 | import { Buffer } from "node:buffer"; 5 | import { proxyV1 } from "../src/proxy"; 6 | import { getModelEndpointTypes } from "@schema"; 7 | import { createParser, ParsedEvent, ParseEvent } from "eventsource-parser"; 8 | 9 | export function createResponseStream(): [ 10 | WritableStream, 11 | Promise, 12 | ] { 13 | const chunks: Uint8Array[] = []; 14 | let resolveChunks: (chunks: Uint8Array[]) => void; 15 | let rejectChunks: (error: Error) => void; 16 | 17 | const chunksPromise = new Promise((resolve, reject) => { 18 | resolveChunks = resolve; 19 | rejectChunks = reject; 20 | }); 21 | 22 | const writableStream = new WritableStream({ 23 | write(chunk) { 24 | chunks.push(chunk); 25 | }, 26 | close() { 27 | resolveChunks(chunks); 28 | }, 29 | abort(reason) { 30 | rejectChunks(new Error(`Stream aborted: ${reason}`)); 31 | }, 32 | }); 33 | 34 | return [writableStream, chunksPromise]; 35 | } 36 | 37 | export function createHeaderHandlers() { 38 | const headers: Record = {}; 39 | let statusCode = 200; 40 | 41 | const setHeader = (name: string, value: string) => { 42 | headers[name] = value; 43 | }; 44 | 45 | const setStatusCode = (code: number) => { 46 | statusCode = code; 47 | }; 48 | 49 | return { headers, statusCode, setHeader, setStatusCode }; 50 | } 51 | 52 | export const getKnownApiSecrets: Parameters< 53 | typeof proxyV1 54 | >[0]["getApiSecrets"] = async ( 55 | useCache: boolean, 56 | authToken: string, 57 | model: string | null, 58 | ) => { 59 | const endpointTypes = model && getModelEndpointTypes(model); 60 | if (!endpointTypes?.length) throw new Error(`Unknown model: ${model}`); 61 | 62 | return [ 63 | { 64 | 
type: "anthropic" as const, 65 | secret: process.env.ANTHROPIC_API_KEY || "", 66 | name: "anthropic", 67 | }, 68 | { 69 | type: "google" as const, 70 | secret: process.env.GEMINI_API_KEY || "", 71 | name: "google", 72 | }, 73 | { 74 | type: "openai" as const, 75 | secret: process.env.OPENAI_API_KEY || "", 76 | name: "openai", 77 | }, 78 | { 79 | type: "vertex" as const, 80 | secret: process.env.VERTEX_AI_API_KEY || "", 81 | name: "vertex", 82 | metadata: { 83 | project: process.env.GCP_PROJECT_ID || "", 84 | authType: "access_token" as const, 85 | api_base: "", 86 | supportsStreaming: true, 87 | excludeDefaultModels: false, 88 | }, 89 | }, 90 | { 91 | type: "bedrock" as const, 92 | secret: process.env.AWS_SECRET_ACCESS_KEY || "", 93 | name: "bedrock" as const, 94 | metadata: { 95 | region: process.env.AWS_REGION || "", 96 | access_key: process.env.AWS_ACCESS_KEY_ID || "", 97 | session_token: process.env.AWS_SESSION_TOKEN || "", 98 | supportsStreaming: true, 99 | excludeDefaultModels: false, 100 | }, 101 | }, 102 | ].filter((secret) => !!secret.secret && endpointTypes.includes(secret.type)); 103 | }; 104 | 105 | export async function callProxyV1({ 106 | body, 107 | ...request 108 | }: Partial, "body">> & { 109 | body: Input; 110 | }) { 111 | const [writableStream, chunksPromise] = createResponseStream(); 112 | const { headers, statusCode, setHeader, setStatusCode } = 113 | createHeaderHandlers(); 114 | 115 | let timeoutId: NodeJS.Timeout | null = null; 116 | const timeoutPromise = new Promise((_, reject) => { 117 | timeoutId = setTimeout(() => { 118 | reject(new Error(`Request timed out after 30s`)); 119 | }, 30000); 120 | }); 121 | 122 | try { 123 | const requestBody = typeof body === "string" ? 
body : JSON.stringify(body); 124 | 125 | const proxyPromise = proxyV1({ 126 | method: "POST", 127 | url: "/chat/completions", 128 | proxyHeaders: { 129 | "content-type": "application/json", 130 | authorization: `Bearer dummy-token`, 131 | }, 132 | setHeader, 133 | setStatusCode, 134 | res: writableStream, 135 | getApiSecrets: getKnownApiSecrets, 136 | cacheGet: async () => null, 137 | cachePut: async () => {}, 138 | digest: async (message: string) => 139 | Buffer.from(message).toString("base64"), 140 | ...request, 141 | body: requestBody, 142 | }); 143 | 144 | await proxyPromise; 145 | 146 | const chunks = await Promise.race([chunksPromise, timeoutPromise]); 147 | const responseText = new TextDecoder().decode(Buffer.concat(chunks)); 148 | 149 | return { 150 | chunks, 151 | headers, 152 | statusCode, 153 | responseText, 154 | events() { 155 | return chucksToEvents(chunks); 156 | }, 157 | json() { 158 | try { 159 | return JSON.parse(responseText) as Output; 160 | } catch (e) { 161 | return null; 162 | } 163 | }, 164 | }; 165 | } catch (error) { 166 | throw error; 167 | } finally { 168 | if (timeoutId) { 169 | clearTimeout(timeoutId); 170 | } 171 | } 172 | } 173 | 174 | const chucksToEvents = (chunks: Uint8Array[]) => { 175 | const textDecoder = new TextDecoder(); 176 | const results: (Omit & { data: ChunkData })[] = []; 177 | 178 | const parser = createParser((event) => { 179 | if (event.type === "event" && event.data !== "[DONE]") { 180 | results.push({ 181 | ...event, 182 | data: JSON.parse(event.data) as ChunkData, 183 | }); 184 | } 185 | }); 186 | 187 | for (const chunk of chunks) { 188 | parser.feed(textDecoder.decode(chunk)); 189 | } 190 | 191 | return results; 192 | }; 193 | -------------------------------------------------------------------------------- /packages/proxy/vitest.config.js: -------------------------------------------------------------------------------- 1 | import tsconfigPaths from "vite-tsconfig-paths"; 2 | 3 | const config = { 4 | plugins: 
[tsconfigPaths()], 5 | test: { 6 | exclude: ["**/node_modules/**"], 7 | testTimeout: 30_000, 8 | }, 9 | }; 10 | export default config; 11 | -------------------------------------------------------------------------------- /pnpm-workspace.yaml: -------------------------------------------------------------------------------- 1 | packages: 2 | - "apis/*" 3 | - "packages/*" 4 | -------------------------------------------------------------------------------- /turbo.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://turbo.build/schema.json", 3 | "tasks": { 4 | "build": { 5 | "dependsOn": ["^build"], 6 | "outputs": [".next/**", "!.next/cache/**", "dist/**"], 7 | "env": ["BRAINTRUST_APP_URL", "ORG_NAME", "REDIS_HOST", "REDIS_PORT"] 8 | }, 9 | "test": { 10 | "dependsOn": ["^build"], 11 | "outputs": [] 12 | }, 13 | "lint": { 14 | "outputs": [] 15 | }, 16 | "dev": { 17 | "cache": false 18 | }, 19 | "start": { 20 | "cache": false 21 | }, 22 | "clean": { 23 | "cache": false 24 | } 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /vitest.config.js: -------------------------------------------------------------------------------- 1 | import tsconfigPaths from "vite-tsconfig-paths"; 2 | 3 | const config = { 4 | plugins: [tsconfigPaths()], 5 | test: { 6 | exclude: ["**/node_modules/**"], 7 | testTimeout: 30_000, 8 | }, 9 | }; 10 | export default config; 11 | --------------------------------------------------------------------------------