├── .cz.toml
├── .github
│   └── workflows
│       ├── ci.yml
│       ├── docker.yml
│       └── release.yml
├── .gitignore
├── CHANGELOG.md
├── Cargo.lock
├── Cargo.toml
├── Dockerfile
├── LICENSE
├── README.md
├── config-example.yaml
├── helm
│   ├── Chart.yaml
│   ├── templates
│   │   ├── _helpers.tpl
│   │   ├── configmap.yaml
│   │   ├── deployment.yaml
│   │   └── service.yaml
│   └── values.yaml
├── img
│   ├── logo-dark.png
│   └── logo-light.png
├── src
│   ├── ai_models
│   │   ├── instance.rs
│   │   ├── mod.rs
│   │   └── registry.rs
│   ├── config
│   │   ├── constants.rs
│   │   ├── lib.rs
│   │   ├── mod.rs
│   │   └── models.rs
│   ├── lib.rs
│   ├── main.rs
│   ├── models
│   │   ├── chat.rs
│   │   ├── completion.rs
│   │   ├── content.rs
│   │   ├── embeddings.rs
│   │   ├── logprob.rs
│   │   ├── mod.rs
│   │   ├── response_format.rs
│   │   ├── streaming.rs
│   │   ├── tool_calls.rs
│   │   ├── tool_choice.rs
│   │   ├── tool_definition.rs
│   │   └── usage.rs
│   ├── pipelines
│   │   ├── mod.rs
│   │   ├── otel.rs
│   │   └── pipeline.rs
│   ├── providers
│   │   ├── anthropic
│   │   │   ├── mod.rs
│   │   │   ├── models.rs
│   │   │   └── provider.rs
│   │   ├── azure
│   │   │   ├── mod.rs
│   │   │   └── provider.rs
│   │   ├── bedrock
│   │   │   ├── logs
│   │   │   │   ├── ai21_j2_mid_v1_completions.json
│   │   │   │   ├── ai21_jamba_1_5_mini_v1_0_chat_completions.json
│   │   │   │   ├── amazon_titan_embed_text_v2_0_embeddings.json
│   │   │   │   ├── anthropic_claude_3_haiku_20240307_v1_0_chat_completion.json
│   │   │   │   └── us_amazon_nova_lite_v1_0_chat_completion.json
│   │   │   ├── mod.rs
│   │   │   ├── models.rs
│   │   │   ├── provider.rs
│   │   │   └── test.rs
│   │   ├── mod.rs
│   │   ├── openai
│   │   │   ├── mod.rs
│   │   │   └── provider.rs
│   │   ├── provider.rs
│   │   ├── registry.rs
│   │   └── vertexai
│   │       ├── mod.rs
│   │       ├── models.rs
│   │       ├── provider.rs
│   │       └── tests.rs
│   ├── routes.rs
│   └── state.rs
└── tests
    └── cassettes
        └── vertexai
            ├── chat_completions.json
            ├── chat_completions_with_tools.json
            ├── completions.json
            └── embeddings.json
/.cz.toml: -------------------------------------------------------------------------------- 1 | [tool.commitizen] 2 | name = "cz_conventional_commits" 3 | tag_format = "v$version" 4 | major_version_zero = true 5 | update_changelog_on_bump = true 6 | version = "0.0.0" 7 | version_files = ["Cargo.toml"] 8 | version_provider = "cargo" 9 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Build and test 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | pull_request: 7 | branches: ["main"] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | build: 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Check formatting 19 | run: cargo fmt --check 20 | - name: Check clippy 21 | run: cargo clippy -- -D warnings 22 | - name: Build 23 | run: cargo build --verbose 24 | - name: Run tests 25 | run: cargo test --verbose 26 | -------------------------------------------------------------------------------- /.github/workflows/docker.yml: -------------------------------------------------------------------------------- 1 | name: Publish Docker Images 2 | 3 | on: 4 | push: 5 | tags: 6 | - "v*" 7 | workflow_dispatch: 8 | 9 | jobs: 10 | publish: 11 | name: Publish Docker Images 12 | runs-on: ubuntu-latest 13 | permissions: 14 | id-token: write 15 | contents: read 16 | packages: write 17 | 18 | steps: 19 | - name: Checkout 20 | uses: actions/checkout@v4 21 | 22 | - name: Set up QEMU 23 | uses: docker/setup-qemu-action@v3 24 | - name: Set up Docker Buildx 25 | uses: docker/setup-buildx-action@v3 26 | 27 | - name: Log in to the GitHub Container registry 28 | uses: docker/login-action@v2 29 | with: 30 | registry: ghcr.io 31 | username: ${{ 
github.actor }} 32 | password: ${{ secrets.GITHUB_TOKEN }} 33 | 34 | - name: Login to Docker Hub 35 | uses: docker/login-action@v2 36 | with: 37 | username: ${{ secrets.DOCKERHUB_USERNAME }} 38 | password: ${{ secrets.DOCKERHUB_TOKEN }} 39 | 40 | - name: Extract metadata for Docker 41 | id: docker-metadata 42 | uses: docker/metadata-action@v4 43 | with: 44 | images: | 45 | ghcr.io/traceloop/hub # GitHub 46 | traceloop/hub # Docker Hub 47 | tags: | 48 | type=sha 49 | type=semver,pattern={{version}} 50 | type=semver,pattern={{major}}.{{minor}} 51 | - name: Build and push Docker image 52 | uses: docker/build-push-action@v4 53 | with: 54 | context: . 55 | push: true 56 | tags: ${{ steps.docker-metadata.outputs.tags }} 57 | labels: ${{ steps.docker-metadata.outputs.labels }} 58 | platforms: | 59 | linux/amd64, linux/arm64 60 | deploy: 61 | name: Deploy to Traceloop 62 | runs-on: ubuntu-latest 63 | needs: publish 64 | steps: 65 | - name: Install Octopus CLI 66 | uses: OctopusDeploy/install-octopus-cli-action@v3 67 | with: 68 | version: "*" 69 | 70 | - name: Create Octopus Release 71 | env: 72 | OCTOPUS_API_KEY: ${{ secrets.OCTOPUS_API_KEY }} 73 | OCTOPUS_URL: ${{ secrets.OCTOPUS_URL }} 74 | OCTOPUS_SPACE: ${{ secrets.OCTOPUS_SPACE }} 75 | run: octopus release create --project hub --version=sha-${GITHUB_SHA::7} --packageVersion=sha-${GITHUB_SHA::7} --no-prompt 76 | 77 | - name: Deploy Octopus release 78 | env: 79 | OCTOPUS_API_KEY: ${{ secrets.OCTOPUS_API_KEY }} 80 | OCTOPUS_URL: ${{ secrets.OCTOPUS_URL }} 81 | OCTOPUS_SPACE: ${{ secrets.OCTOPUS_SPACE }} 82 | run: octopus release deploy --project hub --version=sha-${GITHUB_SHA::7} --environment Staging --no-prompt 83 | push-helm-chart: 84 | name: Push Helm Chart to Dockerhub 85 | runs-on: ubuntu-latest 86 | needs: publish 87 | permissions: 88 | contents: write 89 | 90 | steps: 91 | - name: Checkout 92 | uses: actions/checkout@v4 93 | 94 | - name: Get Chart Version 95 | id: chartVersion 96 | run: | 97 | CHART_VERSION=$(grep '^version:' helm/Chart.yaml | awk '{print $2}') 98 | echo "CHART_VERSION=$CHART_VERSION" 99 | echo "chart_version=$CHART_VERSION" >> $GITHUB_OUTPUT 100 | 101 | - name: Get Chart Name 102 | id: chartName 103 | run: | 104 | CHART_NAME=$(grep '^name:' helm/Chart.yaml | awk '{print $2}') 105 | echo "CHART_NAME=$CHART_NAME" 106 | echo "chart_name=$CHART_NAME" >> $GITHUB_OUTPUT 107 | 108 | - name: Login to Docker Hub as OCI registry 109 | run: | 110 | echo "${{ secrets.DOCKERHUB_TOKEN }}" | helm registry login registry-1.docker.io \ 111 | --username "${{ secrets.DOCKERHUB_USERNAME }}" \ 112 | --password-stdin 113 | 114 | - name: Package Helm chart 115 | run: | 116 | helm package helm/ 117 | 118 | - name: Push Helm Chart 119 | run: | 120 | helm push "${{ steps.chartName.outputs.chart_name }}-${{ steps.chartVersion.outputs.chart_version }}.tgz" oci://registry-1.docker.io/traceloop 121 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release a New Version 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | jobs: 7 | release: 8 | name: Bump Versions 9 | runs-on: ubuntu-latest 10 | 11 | steps: 12 | - uses: actions/checkout@v4 13 | with: 14 | persist-credentials: false 15 | fetch-depth: 0 16 | 17 | - id: cz 18 | name: Bump Version, Create Tag and Changelog 19 | uses: commitizen-tools/commitizen-action@master 20 | with: 21 | github_token: ${{ secrets.GH_ACCESS_TOKEN }} 22 | 
changelog_increment_filename: body.md 23 | 24 | - name: Create Release 25 | uses: softprops/action-gh-release@v2 26 | with: 27 | body_path: "body.md" 28 | tag_name: ${{ env.REVISION }} 29 | env: 30 | GITHUB_TOKEN: ${{ secrets.GH_ACCESS_TOKEN }} 31 | 32 | - name: Print Version 33 | run: echo "Bumped to version ${{ steps.cz.outputs.version }}" 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | debug/ 4 | target/ 5 | 6 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 7 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 8 | Cargo.lock 9 | 10 | # These are backup files generated by rustfmt 11 | **/*.rs.bk 12 | 13 | # MSVC Windows builds of rustc generate these, which store debugging information 14 | *.pdb 15 | 16 | config.yaml 17 | 18 | # RustRover 19 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 20 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 21 | # and can be added to the global gitignore or merged into this file. For a more nuclear 22 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 23 | #.idea/ 24 | .vscode/ -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## v0.4.3 (2025-05-29) 2 | 3 | ### Fix 4 | 5 | - make general optional again (#43) 6 | 7 | ## v0.4.2 (2025-05-22) 8 | 9 | ### Fix 10 | 11 | - **tracing**: support disabling tracing of prompts and completions (#42) 12 | 13 | ## v0.4.1 (2025-05-20) 14 | 15 | ### Fix 16 | 17 | - **openai**: support custom base URL (#40) 18 | - **azure**: add support for custom base URL in AzureProvider endpoint (#41) 19 | 20 | ## v0.4.0 (2025-05-16) 21 | 22 | ### Feat 23 | 24 | - **provider**: add Google VertexAI support (#24) 25 | - support AWS bedrock base models (#25) 26 | - add max_completion_tokens to ChatCompletionRequest (#36) 27 | - support structured output (#33) 28 | 29 | ### Fix 30 | 31 | - replace eprintln with tracing info for API request errors in Azure and OpenAI providers (#37) 32 | - make optional json_schema field to ResponseFormat (#35) 33 | 34 | ## v0.3.0 (2025-03-04) 35 | 36 | ### Feat 37 | 38 | - add logprobs and top_logprobs options to ChatCompletionRequest (#27) 39 | 40 | ### Fix 41 | 42 | - **cd**: correct docker hub secrets (#31) 43 | - **azure**: embeddings structs improvement (#29) 44 | - add proper error logging for azure and openai calls (#18) 45 | - **anthropic**: separate system from messages (#17) 46 | 47 | ## v0.2.1 (2024-12-01) 48 | 49 | ### Fix 50 | 51 | - tool call support (#16) 52 | - restructure providers, separate request/response conversion (#15) 53 | 54 | ## v0.2.0 (2024-11-25) 55 | 56 | ### Feat 57 | 58 | - **openai**: support streaming (#10) 59 | - add prometheus metrics (#13) 60 | - **cd**: deploy to traceloop on workflow dispatch (#11) 61 | 62 | ### Fix 63 | 64 | - config file path from env var instead of command argument (#12) 65 | 66 | ## v0.1.0 (2024-11-16) 67 | 68 | ### Feat 69 | 70 | - otel support (#7) 71 | - implement pipeline steering logic (#5) 72 | - dynamic pipeline routing (#4) 73 | - azure openai provider (#3) 74 | - initial completions and embeddings routes 
with openai and anthropic providers (#1) 75 | 76 | ### Fix 77 | 78 | - dockerfile and release pipeline (#2) 79 | - make anthropic work (#8) 80 | - cleanups; lint warnings fail CI (#9) 81 | - missing model name in response; 404 for model not found (#6) 82 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "hub" 3 | version = "0.4.3" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | axum = "0.7" 10 | tokio = { version = "1.0", features = ["full"] } 11 | serde = { version = "1.0", features = ["derive"] } 12 | reqwest = { version = "0.12", features = ["json", "stream"] } 13 | serde_json = "1.0" 14 | tracing = "0.1" 15 | tracing-subscriber = "0.3" 16 | serde_yaml = "0.9" 17 | tower = { version = "0.5.1", features = ["full"] } 18 | anyhow = "1.0.92" 19 | chrono = "0.4.38" 20 | opentelemetry = { version = "0.27", default-features = false, features = [ 21 | "trace", 22 | ] } 23 | opentelemetry_sdk = { version = "0.27", default-features = false, features = [ 24 | "trace", 25 | "rt-tokio", 26 | ] } 27 | opentelemetry-semantic-conventions = { version = "0.27.0", features = [ 28 | "semconv_experimental", 29 | ] } 30 | opentelemetry-otlp = { version = "0.27.0", features = [ 31 | "http-proto", 32 | "reqwest-client", 33 | "reqwest-rustls", 34 | ] } 35 | axum-prometheus = "0.7.0" 36 | reqwest-streams = { version = "0.8.1", features = ["json"] } 37 | futures = "0.3.31" 38 | async-stream = "0.3.6" 39 | yup-oauth2 = "8.3.0" 40 | aws-sdk-bedrockruntime = "1.66.0" 41 | aws-config = "1.5.12" 42 | aws-credential-types = { version = "1.2.1", features = [ 43 | "hardcoded-credentials", 44 | ] } 45 | http = "1.1.0" 46 | aws-smithy-runtime = { version = "1.7.6", features = ["test-util"] } 47 | aws-smithy-types = "1.2.11" 48 | aws-types = "1.3.3" 49 | uuid = { version = "1.16.0", features = ["v4"] } 50 | 51 | [dev-dependencies] 52 | surf = "2.3.2" 53 | surf-vcr = "0.2.0" 54 | wiremock = "0.5" 55 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM rust:1.82-bookworm AS builder 2 | 3 | WORKDIR /app 4 | COPY . . 5 | RUN cargo build --release --bin hub 6 | 7 | FROM debian:bookworm-slim AS runtime 8 | RUN apt-get update && apt-get install -y openssl ca-certificates 9 | WORKDIR /app 10 | COPY --from=builder /app/target/release/hub /usr/local/bin 11 | WORKDIR /etc 12 | 13 | ENV PORT 3000 14 | EXPOSE 3000 15 | 16 | ENTRYPOINT ["/usr/local/bin/hub"] 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
[README header: Traceloop Hub logo (dark/light variants), tagline "Open-source, high-performance LLM gateway written in Rust. Connect to any LLM provider with a single API. Observability Included.", links: Get started » · Slack · Docs, and badges: Traceloop Hub is released under the Apache-2.0 License, git commit activity, PRs welcome!, Slack community channel, Traceloop Twitter]
44 | 45 | Hub is a next-generation smart proxy for LLM applications. It centralizes control and tracing of all your LLM calls. 46 | It's built in Rust, so it's fast and efficient. It's completely open-source and free to use. 47 | 48 | Built and maintained by Traceloop under the Apache 2.0 license. 49 | 50 | ## 🚀 Getting Started 51 | 52 | Make sure to copy a `config.yaml` file from `config-example.yaml` and set the correct values, following the [configuration](https://www.traceloop.com/docs/hub/configuration) instructions. 53 | 54 | You can then run the hub using the docker image: 55 | 56 | ``` 57 | docker run --rm -p 3000:3000 -v $(pwd)/config.yaml:/etc/hub/config.yaml:ro -e CONFIG_FILE_PATH='/etc/hub/config.yaml' -t traceloop/hub 58 | ``` 59 | 60 | You can also run it locally. Make sure you have Rust v1.82 or above installed and then run: 61 | 62 | ``` 63 | cargo run 64 | ``` 65 | 66 | Connect to the hub by using the OpenAI SDK in any language and setting the base URL to: 67 | 68 | ``` 69 | http://localhost:3000/api/v1 70 | ``` 71 | 72 | For example, in Python: 73 | 74 | ``` 75 | client = OpenAI( 76 | base_url="http://localhost:3000/api/v1", 77 | api_key=os.getenv("OPENAI_API_KEY"), 78 | # default_headers={"x-traceloop-pipeline": "azure-only"}, 79 | ) 80 | completion = client.chat.completions.create( 81 | model="claude-3-5-sonnet-20241022", 82 | messages=[{"role": "user", "content": "Tell me a joke about opentelemetry"}], 83 | max_tokens=1000, 84 | ) 85 | ``` 86 | 87 | ## 🌱 Contributing 88 | 89 | Whether big or small, we love contributions ❤️ Check out our guide to see how to [get started](https://traceloop.com/docs/hub/contributing/overview). 90 | 91 | Not sure where to get started? You can: 92 | 93 | - [Book a free pairing session with one of our teammates](mailto:nir@traceloop.com?subject=Pairing%20session&body=I'd%20like%20to%20do%20a%20pairing%20session!)! 94 | - Join our Slack, and ask us any questions there. 95 | 96 | ## 💚 Community & Support 97 | 98 | - [Slack](https://traceloop.com/slack) (For live discussion with the community and the Traceloop team) 99 | - [GitHub Discussions](https://github.com/traceloop/hub/discussions) (For help with building and deeper conversations about features) 100 | - [GitHub Issues](https://github.com/traceloop/hub/issues) (For any bugs and errors you encounter using Hub) 101 | - [Twitter](https://twitter.com/traceloopdev) (Get news fast) 102 | 103 | # Hub 104 | 105 | A unified API interface for routing LLM requests to various providers. 106 | 107 | ## Supported Providers 108 | 109 | - OpenAI 110 | - Anthropic 111 | - Azure OpenAI 112 | - Google VertexAI (Gemini) 113 | 114 | ## Configuration 115 | 116 | See `config-example.yaml` for a complete configuration example. 117 | 118 | ### Provider Configuration 119 | 120 | #### OpenAI 121 | 122 | ```yaml 123 | providers: 124 | - key: openai 125 | type: openai 126 | api_key: "" 127 | ``` 128 | 129 | #### Azure OpenAI 130 | 131 | ```yaml 132 | providers: 133 | - key: azure-openai 134 | type: azure 135 | api_key: "" 136 | resource_name: "" 137 | api_version: "" 138 | ``` 139 | 140 | #### Google VertexAI (Gemini) 141 | 142 | ```yaml 143 | providers: 144 | - key: vertexai 145 | type: vertexai 146 | api_key: "" 147 | project_id: "" 148 | location: "" 149 | credentials_path: "/path/to/service-account.json" 150 | ``` 151 | 152 | Authentication Methods: 153 | 1. API Key Authentication: 154 | - Set the `api_key` field with your GCP API key 155 | - Leave `credentials_path` empty 156 | 2. 
Service Account Authentication: 157 | - Set `credentials_path` to your service account JSON file path 158 | - Can also use `GOOGLE_APPLICATION_CREDENTIALS` environment variable 159 | - Leave `api_key` empty when using service account auth 160 | 161 | Supported Features: 162 | - Chat Completions (with Gemini models) 163 | - Text Completions 164 | - Embeddings 165 | - Streaming Support 166 | - Function/Tool Calling 167 | - Multi-modal Inputs (images + text) 168 | 169 | Example Model Configuration: 170 | ```yaml 171 | models: 172 | # Chat and Completion model 173 | - key: gemini-1.5-flash 174 | type: gemini-1.5-flash 175 | provider: vertexai 176 | 177 | # Embeddings model 178 | - key: textembedding-gecko 179 | type: textembedding-gecko 180 | provider: vertexai 181 | ``` 182 | 183 | Example Usage with OpenAI SDK: 184 | ```python 185 | from openai import OpenAI 186 | 187 | client = OpenAI( 188 | base_url="http://localhost:3000/api/v1", 189 | api_key="not-needed-for-vertexai" 190 | ) 191 | 192 | # Chat completion 193 | response = client.chat.completions.create( 194 | model="gemini-1.5-flash", 195 | messages=[{"role": "user", "content": "Tell me a joke"}] 196 | ) 197 | 198 | # Embeddings 199 | response = client.embeddings.create( 200 | model="textembedding-gecko", 201 | input="Sample text for embedding" 202 | ) 203 | ``` 204 | 205 | ### Pipeline Configuration 206 | 207 | ```yaml 208 | pipelines: 209 | - name: default 210 | type: chat 211 | plugins: 212 | - model-router: 213 | models: 214 | - gemini-pro 215 | ``` 216 | 217 | ## Development 218 | 219 | ### Running Tests 220 | 221 | The test suite uses recorded HTTP interactions (cassettes) to make tests reproducible without requiring actual API credentials. 222 | 223 | To run tests: 224 | ```bash 225 | cargo test 226 | ``` 227 | 228 | To record new test cassettes: 229 | 1. Set up your API credentials: 230 | - For service account auth: Set `VERTEXAI_CREDENTIALS_PATH` to your service account key file path 231 | - For API key auth: Use the test with API key (currently marked as ignored) 232 | 2. Delete the existing cassette files in `tests/cassettes/vertexai/` 233 | 3. Run the tests with recording enabled: 234 | ```bash 235 | RECORD_MODE=1 cargo test 236 | ``` 237 | 238 | Additional test configurations: 239 | - `RETRY_DELAY`: Set the delay in seconds between retries when hitting quota limits (default: 60) 240 | - Tests automatically retry up to 3 times when hitting quota limits 241 | 242 | Note: Some tests may be marked as `#[ignore]` if they require specific credentials or are not ready for general use. 243 | 244 | ## License 245 | 246 | See LICENSE file. 247 | -------------------------------------------------------------------------------- /config-example.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | trace_content_enabled: true # Optional, defaults to true, set to false to disable tracing of request and response content 3 | providers: 4 | # Azure OpenAI configuration 5 | - key: azure-openai 6 | type: azure 7 | api_key: "" 8 | resource_name: "" 9 | api_version: "" 10 | 11 | # OpenAI configuration 12 | - key: openai 13 | type: openai 14 | api_key: "" 15 | base_url: "optional base url. 
If not provided, defaults to https://api.openai.com/v1" 16 | - key: bedrock 17 | type: bedrock 18 | api_key: "" # Not used for AWS Bedrock 19 | region: "" # like "us-east-1" 20 | inference_profile_id: "" # like "us" 21 | AWS_ACCESS_KEY_ID: "" 22 | AWS_SECRET_ACCESS_KEY: "" 23 | AWS_SESSION_TOKEN: "" # Optional 24 | 25 | # Vertex AI configuration 26 | # Uses service account authentication 27 | - key: vertexai 28 | type: vertexai 29 | api_key: "" # Required field but not used with service account auth 30 | project_id: "" 31 | location: "" # e.g., us-central1 32 | credentials_path: "/path/to/service-account.json" # Path to your service account key file 33 | 34 | models: 35 | # OpenAI Models 36 | - key: gpt-4 37 | type: gpt-4 38 | provider: openai 39 | - key: gpt-3.5-turbo 40 | type: gpt-3.5-turbo 41 | provider: openai 42 | 43 | # Azure OpenAI Models 44 | - key: gpt-4-azure 45 | type: gpt-4 46 | provider: azure-openai 47 | deployment: "" 48 | - key: gpt-35-turbo-azure 49 | type: gpt-35-turbo 50 | provider: azure-openai 51 | deployment: "" 52 | 53 | # Bedrock Models 54 | - key: bedrock-model 55 | # some models are region-specific; it's a good idea to get the ARN from the cross-region reference tab 56 | type: "<model-id or Inference profile ARN or Inference profile ID>" 57 | provider: bedrock 58 | model_provider: "anthropic" # can be: ai21, titan, anthropic 59 | model_version: "v2:0" # optional, defaults to "v1:0" 60 | 61 | # Vertex AI Models 62 | # Chat and Completion model 63 | - key: gemini-1.5-flash 64 | type: gemini-1.5-flash # Supports both chat and completion endpoints 65 | provider: vertexai 66 | # Embeddings model 67 | - key: textembedding-gecko 68 | type: textembedding-gecko # Supports embeddings endpoint 69 | provider: vertexai 70 | deployment: "" 71 | 72 | pipelines: 73 | # Default pipeline for chat completions 74 | - name: default 75 | type: chat 76 | plugins: 77 | - logging: 78 | level: info # Supported levels: debug, info, warning, error 79 | - tracing: # Optional tracing configuration 80 | endpoint: "https://api.traceloop.com/v1/traces" 81 | api_key: "" 82 | - model-router: 83 | models: # List the models you want to use for chat 84 | - gpt-4 85 | - gpt-4-azure 86 | - gemini-1.5-flash 87 | 88 | # Pipeline for text completions 89 | - name: completions 90 | type: completion 91 | plugins: 92 | - model-router: 93 | models: # List the models you want to use for completions 94 | - gpt-3.5-turbo 95 | - gpt-35-turbo-azure 96 | - gemini-1.5-flash 97 | 98 | # Pipeline for embeddings 99 | - name: embeddings 100 | type: embeddings 101 | plugins: 102 | - model-router: 103 | models: # List the models you want to use for embeddings 104 | - textembedding-gecko 105 | -------------------------------------------------------------------------------- /helm/Chart.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v2 2 | name: helm-hub 3 | description: A Helm chart for Hub application 4 | type: application 5 | version: 0.3.0 6 | -------------------------------------------------------------------------------- /helm/templates/_helpers.tpl: -------------------------------------------------------------------------------- 1 | {{/* 2 | Expand the name of the chart. 3 | */}} 4 | {{- define "app.name" -}} 5 | {{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }} 6 | {{- end }} 7 | 8 | {{/* 9 | Create a default fully qualified app name. 10 | We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec). 
11 | If release name contains chart name it will be used as a full name. 12 | */}} 13 | {{- define "app.fullname" -}} 14 | {{- if .Values.fullnameOverride }} 15 | {{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }} 16 | {{- else }} 17 | {{- $name := default .Chart.Name .Values.nameOverride }} 18 | {{- if contains $name .Release.Name }} 19 | {{- .Release.Name | trunc 63 | trimSuffix "-" }} 20 | {{- else }} 21 | {{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }} 22 | {{- end }} 23 | {{- end }} 24 | {{- end }} 25 | 26 | {{/* 27 | Create chart name and version as used by the chart label. 28 | */}} 29 | {{- define "app.chart" -}} 30 | {{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} 31 | {{- end }} 32 | 33 | {{/* 34 | Common labels 35 | */}} 36 | {{- define "app.labels" -}} 37 | helm.sh/chart: {{ include "app.chart" . }} 38 | {{ include "app.selectorLabels" . }} 39 | {{- if .Chart.AppVersion }} 40 | app.kubernetes.io/version: {{ .Chart.Version | quote }} 41 | {{- end }} 42 | app.kubernetes.io/managed-by: {{ .Release.Service }} 43 | {{- end }} 44 | 45 | {{/* 46 | Selector labels 47 | */}} 48 | {{- define "app.selectorLabels" -}} 49 | app.kubernetes.io/name: {{ include "app.name" . }} 50 | app.kubernetes.io/instance: {{ .Release.Name }} 51 | {{- end }} 52 | -------------------------------------------------------------------------------- /helm/templates/configmap.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: ConfigMap 3 | metadata: 4 | name: {{ include "app.fullname" . }}-config 5 | labels: 6 | {{- include "app.labels" . | nindent 4 }} 7 | data: 8 | config.yaml: | 9 | providers: 10 | {{- range .Values.config.providers }} 11 | - {{ toYaml . | nindent 8 | trim }} 12 | {{- end }} 13 | models: 14 | {{- range .Values.config.models }} 15 | - key: {{ .key }} 16 | type: {{ .type }} 17 | provider: {{ .provider }} 18 | {{- if .deployment }} 19 | deployment: "{{ .deployment }}" 20 | {{- end }} 21 | {{- end }} 22 | 23 | pipelines: 24 | {{- range .Values.config.pipelines }} 25 | - name: {{ .name }} 26 | type: {{ .type }} 27 | plugins: 28 | {{- range .plugins }} 29 | - {{ toYaml . | nindent 12 | trim }} 30 | {{- end }} 31 | {{- end }} 32 | -------------------------------------------------------------------------------- /helm/templates/deployment.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: {{ include "app.fullname" . }} 5 | labels: 6 | {{- include "app.labels" . | nindent 4 }} 7 | spec: 8 | replicas: {{ .Values.replicaCount }} 9 | selector: 10 | matchLabels: 11 | {{- include "app.selectorLabels" . | nindent 6 }} 12 | template: 13 | metadata: 14 | labels: 15 | {{- include "app.selectorLabels" . | nindent 8 }} 16 | {{- with .Values.podAnnotations }} 17 | annotations: 18 | {{- toYaml . 
| nindent 8 }} 19 | {{- end }} 20 | spec: 21 | securityContext: 22 | {{- toYaml .Values.podSecurityContext | nindent 8 }} 23 | initContainers: 24 | - name: config-init 25 | image: busybox:latest 26 | command: ["sh", "-c"] 27 | args: 28 | - | 29 | cp /config-template/config.yaml /config/config.yaml 30 | # Dynamic substitution for all API keys 31 | {{- range $key, $value := .Values.secrets.keyMapping }} 32 | {{- $upperKey := upper (regexReplaceAll "ApiKey$" $key "") }} 33 | sed -i "s/{{ $upperKey }}_API_KEY_PLACEHOLDER/${{ $upperKey }}_API_KEY/g" /config/config.yaml 34 | {{- end }} 35 | env: 36 | {{- range $key, $value := .Values.secrets.keyMapping }} 37 | {{- $upperKey := upper (regexReplaceAll "ApiKey$" $key "") }} 38 | - name: {{ $upperKey }}_API_KEY 39 | valueFrom: 40 | secretKeyRef: 41 | name: {{ $.Values.secrets.existingSecretName }} 42 | key: {{ $value }} 43 | optional: {{ ne $key "openaiApiKey" }} 44 | {{- end }} 45 | volumeMounts: 46 | - name: config-template 47 | mountPath: /config-template 48 | - name: config-volume 49 | mountPath: /config 50 | containers: 51 | - name: {{ .Chart.Name }} 52 | securityContext: 53 | {{- toYaml .Values.securityContext | nindent 12 }} 54 | image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.Version }}" 55 | imagePullPolicy: {{ .Values.image.pullPolicy }} 56 | workingDir: /app 57 | env: 58 | - name: PORT 59 | value: {{ .Values.service.port | quote }} 60 | ports: 61 | - name: http 62 | containerPort: {{ .Values.service.port }} 63 | protocol: TCP 64 | livenessProbe: 65 | httpGet: 66 | path: /health 67 | port: http 68 | initialDelaySeconds: {{ .Values.probes.liveness.initialDelaySeconds | default 30 }} 69 | periodSeconds: {{ .Values.probes.liveness.periodSeconds | default 10 }} 70 | timeoutSeconds: {{ .Values.probes.liveness.timeoutSeconds | default 5 }} 71 | successThreshold: {{ .Values.probes.liveness.successThreshold | default 1 }} 72 | failureThreshold: {{ .Values.probes.liveness.failureThreshold | default 3 }} 73 | readinessProbe: 74 | httpGet: 75 | path: /health 76 | port: http 77 | initialDelaySeconds: {{ .Values.probes.readiness.initialDelaySeconds | default 10 }} 78 | periodSeconds: {{ .Values.probes.readiness.periodSeconds | default 10 }} 79 | timeoutSeconds: {{ .Values.probes.readiness.timeoutSeconds | default 5 }} 80 | successThreshold: {{ .Values.probes.readiness.successThreshold | default 1 }} 81 | failureThreshold: {{ .Values.probes.readiness.failureThreshold | default 3 }} 82 | volumeMounts: 83 | - name: config-volume 84 | mountPath: /app/config.yaml 85 | subPath: config.yaml 86 | resources: 87 | {{- toYaml .Values.resources | nindent 12 }} 88 | volumes: 89 | - name: config-template 90 | configMap: 91 | name: {{ include "app.fullname" . }}-config 92 | - name: config-volume 93 | emptyDir: {} 94 | {{- with .Values.nodeSelector }} 95 | nodeSelector: 96 | {{- toYaml . | nindent 8 }} 97 | {{- end }} 98 | {{- with .Values.affinity }} 99 | affinity: 100 | {{- toYaml . | nindent 8 }} 101 | {{- end }} 102 | {{- with .Values.tolerations }} 103 | tolerations: 104 | {{- toYaml . | nindent 8 }} 105 | {{- end }} 106 | -------------------------------------------------------------------------------- /helm/templates/service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: {{ include "app.fullname" . }} 5 | labels: 6 | {{- include "app.labels" . 
| nindent 4 }} 7 | spec: 8 | type: {{ .Values.service.type }} 9 | ports: 10 | - port: {{ .Values.service.port }} 11 | targetPort: http 12 | protocol: TCP 13 | name: http 14 | selector: 15 | {{- include "app.selectorLabels" . | nindent 4 }} 16 | -------------------------------------------------------------------------------- /helm/values.yaml: -------------------------------------------------------------------------------- 1 | # Default values for hub. 2 | # This is a YAML-formatted file. 3 | # Declare variables to be passed into your templates. 4 | 5 | replicaCount: 1 6 | 7 | image: 8 | repository: docker.io/traceloop/hub 9 | pullPolicy: IfNotPresent 10 | # Overrides the image tag whose default is the chart appVersion. 11 | # tag: "latest" 12 | # Optionally specify an array of imagePullSecrets. 13 | # imagePullSecrets: 14 | # - name: myRegistryKeySecretName 15 | 16 | nameOverride: "" 17 | fullnameOverride: "" 18 | 19 | podAnnotations: {} 20 | 21 | podSecurityContext: {} 22 | # fsGroup: 2000 23 | 24 | securityContext: {} 25 | # capabilities: 26 | # drop: 27 | # - ALL 28 | # readOnlyRootFilesystem: true 29 | # runAsNonRoot: true 30 | # runAsUser: 1000 31 | 32 | service: 33 | type: ClusterIP 34 | port: 3100 35 | 36 | # Kubernetes probe configuration 37 | probes: 38 | liveness: 39 | initialDelaySeconds: 30 40 | periodSeconds: 10 41 | timeoutSeconds: 5 42 | successThreshold: 1 43 | failureThreshold: 3 44 | readiness: 45 | initialDelaySeconds: 10 46 | periodSeconds: 10 47 | timeoutSeconds: 5 48 | successThreshold: 1 49 | failureThreshold: 3 50 | 51 | resources: 52 | limits: 53 | cpu: 300m 54 | memory: 400Mi 55 | requests: 56 | cpu: 100m 57 | memory: 200Mi 58 | 59 | nodeSelector: {} 60 | 61 | tolerations: [] 62 | 63 | affinity: {} 64 | 65 | # Environment variables to pass to the hub container 66 | env: [] 67 | # - name: DEBUG 68 | # value: "true" 69 | 70 | # Configuration for the hub 71 | config: 72 | # Define providers that will be available in the hub 73 | providers: 74 | - key: openai 75 | # Configuration for the OpenAI provider 76 | # Secret management configuration 77 | secrets: 78 | # Name of the existing Kubernetes secret containing API keys 79 | existingSecretName: "llm-api-keys" 80 | # Mapping of API keys in the secret 81 | keyMapping: 82 | # The key in the secret for OpenAI API key 83 | openaiApiKey: "OPENAI_API_KEY" 84 | # Uncomment to use Azure OpenAI API key 85 | # azureApiKey: "AZURE_API_KEY" 86 | # Uncomment to use Anthropic API key 87 | # anthropicApiKey: "ANTHROPIC_API_KEY" 88 | # Uncomment to use Traceloop API key for tracing 89 | # traceloopApiKey: "TRACELOOP_API_KEY" 90 | # Add mappings for any other provider API keys as needed 91 | -------------------------------------------------------------------------------- /img/logo-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/traceloop/hub/0beb6ad838abe80d32c6f924598f9be2b95747cf/img/logo-dark.png -------------------------------------------------------------------------------- /img/logo-light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/traceloop/hub/0beb6ad838abe80d32c6f924598f9be2b95747cf/img/logo-light.png -------------------------------------------------------------------------------- /src/ai_models/instance.rs: -------------------------------------------------------------------------------- 1 | use crate::config::models::ModelConfig; 2 | use 
crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse}; 3 | use crate::models::completion::{CompletionRequest, CompletionResponse}; 4 | use crate::models::embeddings::{EmbeddingsRequest, EmbeddingsResponse}; 5 | use crate::providers::provider::Provider; 6 | use axum::http::StatusCode; 7 | use std::sync::Arc; 8 | 9 | pub struct ModelInstance { 10 | pub name: String, 11 | pub model_type: String, 12 | pub provider: Arc<dyn Provider>, 13 | pub config: ModelConfig, 14 | } 15 | 16 | impl ModelInstance { 17 | pub async fn chat_completions( 18 | &self, 19 | mut payload: ChatCompletionRequest, 20 | ) -> Result<ChatCompletionResponse, StatusCode> { 21 | payload.model = self.model_type.clone(); 22 | self.provider.chat_completions(payload, &self.config).await 23 | } 24 | 25 | pub async fn completions( 26 | &self, 27 | mut payload: CompletionRequest, 28 | ) -> Result<CompletionResponse, StatusCode> { 29 | payload.model = self.model_type.clone(); 30 | 31 | self.provider.completions(payload, &self.config).await 32 | } 33 | 34 | pub async fn embeddings( 35 | &self, 36 | mut payload: EmbeddingsRequest, 37 | ) -> Result<EmbeddingsResponse, StatusCode> { 38 | payload.model = self.model_type.clone(); 39 | self.provider.embeddings(payload, &self.config).await 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /src/ai_models/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod instance; 2 | pub mod registry; 3 | -------------------------------------------------------------------------------- /src/ai_models/registry.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use std::collections::HashMap; 3 | use std::sync::Arc; 4 | 5 | use super::instance::ModelInstance; 6 | use crate::config::models::ModelConfig; 7 | use crate::providers::registry::ProviderRegistry; 8 | 9 | #[derive(Clone)] 10 | pub struct ModelRegistry { 11 | models: HashMap<String, Arc<ModelInstance>>, 12 | } 13 | 14 | impl ModelRegistry { 15 | pub fn new( 16 | model_configs: &[ModelConfig], 17 | provider_registry: Arc<ProviderRegistry>, 18 | ) -> Result<Self> { 19 | let mut models = HashMap::new(); 20 | 21 | for config in model_configs { 22 | if let Some(provider) = provider_registry.get(&config.provider) { 23 | let model = Arc::new(ModelInstance { 24 | name: config.key.clone(), 25 | model_type: config.r#type.clone(), 26 | provider, 27 | config: config.clone(), 28 | }); 29 | 30 | models.insert(config.key.clone(), model); 31 | } 32 | } 33 | 34 | Ok(Self { models }) 35 | } 36 | 37 | pub fn get(&self, name: &str) -> Option<Arc<ModelInstance>> { 38 | self.models.get(name).cloned() 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /src/config/constants.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | 3 | pub fn stream_buffer_size_bytes() -> usize { 4 | env::var("STREAM_BUFFER_SIZE_BYTES") 5 | .unwrap_or_else(|_| "1000".to_string()) 6 | .parse() 7 | .unwrap_or(1000) 8 | } 9 | 10 | pub fn default_max_tokens() -> u32 { 11 | env::var("DEFAULT_MAX_TOKENS") 12 | .unwrap_or_else(|_| "4096".to_string()) 13 | .parse() 14 | .unwrap_or(4096) 15 | } 16 | 17 | // Required field for the TitanEmbeddingRequest 18 | pub fn default_embedding_dimension() -> u32 { 19 | env::var("DEFAULT_EMBEDDING_DIMENSION") 20 | .unwrap_or_else(|_| "512".to_string()) 21 | .parse() 22 | .unwrap_or(512) 23 | } 24 | // Required field for the TitanEmbeddingRequest 25 | pub fn default_embedding_normalize() -> bool { 26 | env::var("DEFAULT_EMBEDDING_NORMALIZE") 27 | .unwrap_or_else(|_| 
"true".to_string()) 28 | .parse() 29 | .unwrap_or(true) 30 | } 31 | -------------------------------------------------------------------------------- /src/config/lib.rs: -------------------------------------------------------------------------------- 1 | use std::sync::OnceLock; 2 | 3 | use crate::config::models::Config; 4 | 5 | pub static TRACE_CONTENT_ENABLED: OnceLock = OnceLock::new(); 6 | 7 | pub fn load_config(path: &str) -> Result> { 8 | let contents = std::fs::read_to_string(path)?; 9 | let config: Config = serde_yaml::from_str(&contents)?; 10 | TRACE_CONTENT_ENABLED 11 | .set( 12 | config 13 | .general 14 | .as_ref() 15 | .is_none_or(|g| g.trace_content_enabled), 16 | ) 17 | .expect("Failed to set trace content enabled flag"); 18 | Ok(config) 19 | } 20 | 21 | pub fn get_trace_content_enabled() -> bool { 22 | *TRACE_CONTENT_ENABLED.get_or_init(|| true) 23 | } 24 | -------------------------------------------------------------------------------- /src/config/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod constants; 2 | pub mod lib; 3 | pub mod models; 4 | -------------------------------------------------------------------------------- /src/config/models.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use serde::{Deserialize, Serialize}; 4 | 5 | #[derive(Debug, Deserialize, Serialize, Clone)] 6 | pub struct Config { 7 | pub general: Option, 8 | pub providers: Vec, 9 | pub models: Vec, 10 | pub pipelines: Vec, 11 | } 12 | 13 | #[derive(Debug, Deserialize, Serialize, Clone, Default)] 14 | pub struct General { 15 | #[serde(default = "default_trace_content_enabled")] 16 | pub trace_content_enabled: bool, 17 | } 18 | 19 | #[derive(Debug, Deserialize, Serialize, Clone, Default)] 20 | pub struct Provider { 21 | pub key: String, 22 | pub r#type: String, 23 | #[serde(default = "no_api_key")] 24 | pub api_key: String, 25 | #[serde(flatten)] 26 | pub params: HashMap, 27 | } 28 | 29 | #[derive(Debug, Deserialize, Serialize, Clone, Default)] 30 | pub struct ModelConfig { 31 | pub key: String, 32 | pub r#type: String, 33 | pub provider: String, 34 | #[serde(flatten)] 35 | pub params: HashMap, 36 | } 37 | 38 | #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] 39 | #[serde(rename_all = "lowercase")] 40 | pub enum PipelineType { 41 | Chat, 42 | Completion, 43 | Embeddings, 44 | } 45 | 46 | #[derive(Debug, Deserialize, Serialize, Clone)] 47 | pub struct Pipeline { 48 | pub name: String, 49 | pub r#type: PipelineType, 50 | #[serde(with = "serde_yaml::with::singleton_map_recursive")] 51 | pub plugins: Vec, 52 | } 53 | 54 | #[derive(Debug, Deserialize, Serialize, Clone)] 55 | #[serde(rename_all = "kebab-case")] 56 | pub enum PluginConfig { 57 | Logging { 58 | #[serde(default = "default_log_level")] 59 | level: String, 60 | }, 61 | Tracing { 62 | endpoint: String, 63 | api_key: String, 64 | }, 65 | ModelRouter { 66 | models: Vec, 67 | }, 68 | } 69 | 70 | fn default_trace_content_enabled() -> bool { 71 | true 72 | } 73 | 74 | fn default_log_level() -> String { 75 | "warning".to_string() 76 | } 77 | 78 | fn no_api_key() -> String { 79 | "".to_string() 80 | } 81 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod ai_models; 2 | pub mod config; 3 | pub mod models; 4 | pub mod pipelines; 5 | pub mod providers; 6 | pub mod routes; 7 | pub 
mod state; 8 | 9 | pub use axum; 10 | pub use reqwest; 11 | pub use serde; 12 | pub use serde_json; 13 | -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- 1 | use hub::{config::lib::load_config, routes, state::AppState}; 2 | use std::sync::Arc; 3 | use tracing::info; 4 | 5 | #[tokio::main] 6 | async fn main() -> Result<(), anyhow::Error> { 7 | tracing_subscriber::fmt::init(); 8 | 9 | info!("Starting Traceloop Hub..."); 10 | 11 | let config_path = std::env::var("CONFIG_FILE_PATH").unwrap_or("config.yaml".to_string()); 12 | 13 | info!("Loading configuration from {}", config_path); 14 | let config = load_config(&config_path) 15 | .map_err(|e| anyhow::anyhow!("Failed to load configuration: {}", e))?; 16 | let state = Arc::new( 17 | AppState::new(config).map_err(|e| anyhow::anyhow!("Failed to create app state: {}", e))?, 18 | ); 19 | let app = routes::create_router(state); 20 | let port: String = std::env::var("PORT").unwrap_or("3000".to_string()); 21 | let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", port)) 22 | .await 23 | .unwrap(); 24 | 25 | info!("Server is running on port {}", port); 26 | axum::serve(listener, app).await.unwrap(); 27 | 28 | Ok(()) 29 | } 30 | -------------------------------------------------------------------------------- /src/models/chat.rs: -------------------------------------------------------------------------------- 1 | use futures::stream::BoxStream; 2 | use reqwest_streams::error::StreamBodyError; 3 | use serde::{Deserialize, Serialize}; 4 | use std::collections::HashMap; 5 | 6 | use super::content::ChatCompletionMessage; 7 | use super::logprob::LogProbs; 8 | use super::response_format::ResponseFormat; 9 | use super::streaming::ChatCompletionChunk; 10 | use super::tool_choice::ToolChoice; 11 | use super::tool_definition::ToolDefinition; 12 | use super::usage::Usage; 13 | 14 | #[derive(Deserialize, Serialize, Clone)] 15 | pub struct ChatCompletionRequest { 16 | pub model: String, 17 | pub messages: Vec<ChatCompletionMessage>, 18 | #[serde(skip_serializing_if = "Option::is_none")] 19 | pub temperature: Option<f32>, 20 | #[serde(skip_serializing_if = "Option::is_none")] 21 | pub top_p: Option<f32>, 22 | #[serde(skip_serializing_if = "Option::is_none")] 23 | pub n: Option<u32>, 24 | #[serde(skip_serializing_if = "Option::is_none")] 25 | pub stream: Option<bool>, 26 | #[serde(skip_serializing_if = "Option::is_none")] 27 | pub stop: Option<Vec<String>>, 28 | #[serde(skip_serializing_if = "Option::is_none")] 29 | pub max_tokens: Option<u32>, 30 | #[serde(skip_serializing_if = "Option::is_none")] 31 | pub max_completion_tokens: Option<u32>, 32 | #[serde(skip_serializing_if = "Option::is_none")] 33 | pub parallel_tool_calls: Option<bool>, 34 | #[serde(skip_serializing_if = "Option::is_none")] 35 | pub presence_penalty: Option<f32>, 36 | #[serde(skip_serializing_if = "Option::is_none")] 37 | pub frequency_penalty: Option<f32>, 38 | #[serde(skip_serializing_if = "Option::is_none")] 39 | pub logit_bias: Option<HashMap<String, f32>>, 40 | #[serde(skip_serializing_if = "Option::is_none")] 41 | pub tool_choice: Option<ToolChoice>, 42 | #[serde(skip_serializing_if = "Option::is_none")] 43 | pub tools: Option<Vec<ToolDefinition>>, 44 | #[serde(skip_serializing_if = "Option::is_none")] 45 | pub user: Option<String>, 46 | #[serde(skip_serializing_if = "Option::is_none")] 47 | pub logprobs: Option<bool>, 48 | #[serde(skip_serializing_if = "Option::is_none")] 49 | pub top_logprobs: Option<u32>, 50 | #[serde(skip_serializing_if = "Option::is_none")] 51 | pub response_format: Option<ResponseFormat>, 52 | } 53 | 
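// Illustrative note (added commentary, not part of the original source file):
// because every field above except `model` and `messages` is an `Option`
// marked `skip_serializing_if = "Option::is_none"`, the smallest JSON body
// this struct accepts is:
//
//     {"model": "gpt-4", "messages": [{"role": "user", "content": "Hi"}]}
//
// Deserializing that with `serde_json::from_str::<ChatCompletionRequest>(...)`
// leaves every tuning parameter as `None`, and re-serializing the value emits
// only the same two keys.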
54 | pub enum ChatCompletionResponse { 55 | Stream(BoxStream<'static, Result<ChatCompletionChunk, StreamBodyError>>), 56 | NonStream(ChatCompletion), 57 | } 58 | 59 | #[derive(Deserialize, Serialize, Clone)] 60 | pub struct ChatCompletion { 61 | pub id: String, 62 | #[serde(skip_serializing_if = "Option::is_none")] 63 | pub object: Option<String>, 64 | #[serde(skip_serializing_if = "Option::is_none")] 65 | pub created: Option<u64>, 66 | pub model: String, 67 | pub choices: Vec<ChatCompletionChoice>, 68 | pub usage: Usage, 69 | pub system_fingerprint: Option<String>, 70 | } 71 | 72 | #[derive(Deserialize, Serialize, Clone)] 73 | pub struct ChatCompletionChoice { 74 | pub index: u32, 75 | pub message: ChatCompletionMessage, 76 | #[serde(skip_serializing_if = "Option::is_none")] 77 | pub finish_reason: Option<String>, 78 | #[serde(skip_serializing_if = "Option::is_none")] 79 | pub logprobs: Option<LogProbs>, 80 | } 81 | -------------------------------------------------------------------------------- /src/models/completion.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | use std::collections::HashMap; 3 | 4 | use super::usage::Usage; 5 | 6 | #[derive(Deserialize, Serialize, Clone)] 7 | pub struct CompletionRequest { 8 | pub model: String, 9 | pub prompt: String, 10 | #[serde(skip_serializing_if = "Option::is_none")] 11 | pub suffix: Option<String>, 12 | #[serde(skip_serializing_if = "Option::is_none")] 13 | pub max_tokens: Option<u32>, 14 | #[serde(skip_serializing_if = "Option::is_none")] 15 | pub temperature: Option<f32>, 16 | #[serde(skip_serializing_if = "Option::is_none")] 17 | pub top_p: Option<f32>, 18 | #[serde(skip_serializing_if = "Option::is_none")] 19 | pub n: Option<u32>, 20 | #[serde(skip_serializing_if = "Option::is_none")] 21 | pub stream: Option<bool>, 22 | #[serde(skip_serializing_if = "Option::is_none")] 23 | pub logprobs: Option<u32>, 24 | #[serde(skip_serializing_if = "Option::is_none")] 25 | pub echo: Option<bool>, 26 | #[serde(skip_serializing_if = "Option::is_none")] 27 | pub stop: Option<Vec<String>>, 28 | #[serde(skip_serializing_if = "Option::is_none")] 29 | pub presence_penalty: Option<f32>, 30 | #[serde(skip_serializing_if = "Option::is_none")] 31 | pub frequency_penalty: Option<f32>, 32 | #[serde(skip_serializing_if = "Option::is_none")] 33 | pub best_of: Option<u32>, 34 | #[serde(skip_serializing_if = "Option::is_none")] 35 | pub logit_bias: Option<HashMap<String, f32>>, 36 | #[serde(skip_serializing_if = "Option::is_none")] 37 | pub user: Option<String>, 38 | } 39 | 40 | #[derive(Deserialize, Serialize, Clone)] 41 | pub struct CompletionResponse { 42 | pub id: String, 43 | pub object: String, 44 | pub created: u64, 45 | pub model: String, 46 | pub choices: Vec<CompletionChoice>, 47 | pub usage: Usage, 48 | } 49 | 50 | #[derive(Deserialize, Serialize, Clone)] 51 | pub struct CompletionChoice { 52 | pub text: String, 53 | pub index: u32, 54 | #[serde(skip_serializing_if = "Option::is_none")] 55 | pub logprobs: Option<LogProbs>, 56 | #[serde(skip_serializing_if = "Option::is_none")] 57 | pub finish_reason: Option<String>, 58 | } 59 | 60 | #[derive(Deserialize, Serialize, Clone)] 61 | pub struct LogProbs { 62 | pub tokens: Vec<String>, 63 | pub token_logprobs: Vec<f32>, 64 | pub top_logprobs: Vec<HashMap<String, f32>>, 65 | pub text_offset: Vec<u32>, 66 | } 67 | -------------------------------------------------------------------------------- /src/models/content.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | use super::tool_calls::ChatMessageToolCall; 4 | 5 | #[derive(Deserialize, Serialize, Clone)] 6 | #[serde(untagged)] 7 | pub enum ChatMessageContent { 8 | 
--------------------------------------------------------------------------------
/src/models/content.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 | 
3 | use super::tool_calls::ChatMessageToolCall;
4 | 
5 | #[derive(Deserialize, Serialize, Clone)]
6 | #[serde(untagged)]
7 | pub enum ChatMessageContent {
8 |     String(String),
9 |     Array(Vec<ChatMessageContentPart>),
10 | }
11 | 
12 | #[derive(Deserialize, Serialize, Clone)]
13 | pub struct ChatMessageContentPart {
14 |     #[serde(rename = "type")]
15 |     pub r#type: String,
16 |     pub text: String,
17 | }
18 | 
19 | #[derive(Deserialize, Serialize, Clone)]
20 | pub struct ChatCompletionMessage {
21 |     pub role: String,
22 |     #[serde(skip_serializing_if = "Option::is_none")]
23 |     pub content: Option<ChatMessageContent>,
24 |     #[serde(skip_serializing_if = "Option::is_none")]
25 |     pub name: Option<String>,
26 |     #[serde(skip_serializing_if = "Option::is_none")]
27 |     pub tool_calls: Option<Vec<ChatMessageToolCall>>,
28 |     #[serde(skip_serializing_if = "Option::is_none")]
29 |     pub refusal: Option<String>,
30 | }
31 | 
--------------------------------------------------------------------------------
/src/models/embeddings.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 | use serde_json::Value;
3 | 
4 | use super::usage::EmbeddingUsage;
5 | 
6 | #[derive(Deserialize, Serialize, Clone)]
7 | pub struct EmbeddingsRequest {
8 |     pub model: String,
9 |     pub input: EmbeddingsInput,
10 |     #[serde(skip_serializing_if = "Option::is_none")]
11 |     pub user: Option<String>,
12 |     #[serde(skip_serializing_if = "Option::is_none")]
13 |     pub encoding_format: Option<String>,
14 | }
15 | 
16 | #[derive(Deserialize, Serialize, Clone)]
17 | #[serde(untagged)]
18 | pub enum EmbeddingsInput {
19 |     Single(String),
20 |     Multiple(Vec<String>),
21 |     SingleTokenIds(Vec<i32>),
22 |     MultipleTokenIds(Vec<Vec<i32>>),
23 | }
24 | 
25 | #[derive(Deserialize, Serialize, Clone)]
26 | pub struct EmbeddingsResponse {
27 |     pub object: String,
28 |     pub data: Vec<Embeddings>,
29 |     pub model: String,
30 |     pub usage: EmbeddingUsage,
31 | }
32 | 
33 | #[derive(Deserialize, Serialize, Clone)]
34 | pub struct Embeddings {
35 |     pub object: String,
36 |     pub embedding: Embedding,
37 |     pub index: usize,
38 | }
39 | 
40 | #[derive(Deserialize, Serialize, Clone)]
41 | #[serde(untagged)]
42 | pub enum Embedding {
43 |     String(String),
44 |     Float(Vec<f32>),
45 |     Json(Value),
46 | }
47 | 
--------------------------------------------------------------------------------
/src/models/logprob.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 | 
3 | #[derive(Deserialize, Serialize, Clone, Debug)]
4 | pub struct LogProbs {
5 |     pub content: Vec<LogProbContent>,
6 | }
7 | 
8 | #[derive(Deserialize, Serialize, Clone, Debug)]
9 | pub struct LogProbContent {
10 |     pub token: String,
11 |     pub logprob: f32,
12 |     pub bytes: Vec<u8>,
13 |     pub top_logprobs: Vec<TopLogprob>,
14 | }
15 | 
16 | #[derive(Deserialize, Serialize, Clone, Debug)]
17 | pub struct TopLogprob {
18 |     pub token: String,
19 |     #[serde(skip_serializing_if = "Option::is_none")]
20 |     pub bytes: Option<Vec<u8>>,
21 |     pub logprob: f64,
22 | }
23 | 
24 | #[derive(Deserialize, Serialize, Clone, Debug)]
25 | pub struct ChatCompletionTokenLogprob {
26 |     pub token: String,
27 |     #[serde(skip_serializing_if = "Option::is_none")]
28 |     pub bytes: Option<Vec<u8>>,
29 |     pub logprob: f64,
30 |     pub top_logprobs: Vec<TopLogprob>,
31 | }
32 | 
33 | #[derive(Deserialize, Serialize, Clone, Debug)]
34 | pub struct ChoiceLogprobs {
35 |     #[serde(skip_serializing_if = "Option::is_none")]
36 |     pub content: Option<Vec<ChatCompletionTokenLogprob>>,
37 |     #[serde(skip_serializing_if = "Option::is_none")]
38 |     pub refusal: Option<Vec<ChatCompletionTokenLogprob>>,
39 | }
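Because EmbeddingsInput above is #[serde(untagged)], serde resolves the variant purely by shape, trying variants in declaration order. A small illustration (editor's sketch, not repository code; the model name is arbitrary):

use hub::models::embeddings::{EmbeddingsInput, EmbeddingsRequest};
use hub::serde_json;

fn main() {
    for body in [
        r#"{"model": "text-embedding-3-small", "input": "one string"}"#,
        r#"{"model": "text-embedding-3-small", "input": ["a", "b"]}"#,
        r#"{"model": "text-embedding-3-small", "input": [1, 2, 3]}"#,
        r#"{"model": "text-embedding-3-small", "input": [[1, 2], [3]]}"#,
    ] {
        let request: EmbeddingsRequest =
            serde_json::from_str(body).expect("each body should match one variant");
        // Untagged enums try variants in declaration order until one fits.
        match request.input {
            EmbeddingsInput::Single(_) => println!("single string"),
            EmbeddingsInput::Multiple(_) => println!("list of strings"),
            EmbeddingsInput::SingleTokenIds(_) => println!("one token-id sequence"),
            EmbeddingsInput::MultipleTokenIds(_) => println!("several token-id sequences"),
        }
    }
}
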
--------------------------------------------------------------------------------
/src/models/mod.rs:
--------------------------------------------------------------------------------
1 | pub mod chat;
2 | pub mod completion;
3 | pub mod content;
4 | pub mod embeddings;
5 | pub mod logprob;
6 | pub mod response_format;
7 | pub mod streaming;
8 | pub mod tool_calls;
9 | pub mod tool_choice;
10 | pub mod tool_definition;
11 | pub mod usage;
12 | 
--------------------------------------------------------------------------------
/src/models/response_format.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 | 
3 | #[derive(Deserialize, Serialize, Clone)]
4 | pub struct ResponseFormat {
5 |     #[serde(rename = "type")]
6 |     pub r#type: String,
7 |     #[serde(skip_serializing_if = "Option::is_none")]
8 |     pub json_schema: Option<JsonSchema>,
9 | }
10 | 
11 | #[derive(Deserialize, Serialize, Clone)]
12 | pub struct JsonSchema {
13 |     pub name: String,
14 |     #[serde(skip_serializing_if = "Option::is_none")]
15 |     pub description: Option<String>,
16 |     #[serde(skip_serializing_if = "Option::is_none")]
17 |     pub schema: Option<serde_json::Value>,
18 |     #[serde(skip_serializing_if = "Option::is_none")]
19 |     pub strict: Option<bool>,
20 | }
21 | 
--------------------------------------------------------------------------------
/src/models/streaming.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 | 
3 | use super::logprob::ChoiceLogprobs;
4 | use super::tool_calls::ChatMessageToolCall;
5 | use super::usage::Usage;
6 | 
7 | #[derive(Deserialize, Serialize, Clone, Debug, Default)]
8 | pub struct Delta {
9 |     pub role: Option<String>,
10 |     pub content: Option<String>,
11 |     pub function_call: Option<serde_json::Value>,
12 |     pub tool_calls: Option<Vec<ChatMessageToolCall>>,
13 | }
14 | 
15 | #[derive(Deserialize, Serialize, Clone, Debug)]
16 | pub struct ChoiceDelta {
17 |     #[serde(skip_serializing_if = "Option::is_none")]
18 |     pub content: Option<String>,
19 |     #[serde(skip_serializing_if = "Option::is_none")]
20 |     pub role: Option<String>,
21 |     #[serde(skip_serializing_if = "Option::is_none")]
22 |     pub tool_calls: Option<Vec<ChatMessageToolCall>>,
23 | }
24 | 
25 | #[derive(Deserialize, Serialize, Clone, Debug)]
26 | pub struct Choice {
27 |     pub delta: ChoiceDelta,
28 |     #[serde(skip_serializing_if = "Option::is_none")]
29 |     pub finish_reason: Option<String>,
30 |     pub index: u32,
31 |     #[serde(skip_serializing_if = "Option::is_none")]
32 |     pub logprobs: Option<ChoiceLogprobs>,
33 | }
34 | 
35 | #[derive(Deserialize, Serialize, Clone, Debug)]
36 | pub struct ChatCompletionChunk {
37 |     pub id: String,
38 |     pub choices: Vec<Choice>,
39 |     pub created: i64,
40 |     pub model: String,
41 |     #[serde(skip_serializing_if = "Option::is_none")]
42 |     pub service_tier: Option<String>,
43 |     #[serde(skip_serializing_if = "Option::is_none")]
44 |     pub system_fingerprint: Option<String>,
45 |     #[serde(skip_serializing_if = "Option::is_none")]
46 |     pub usage: Option<Usage>,
47 | }
48 | 
--------------------------------------------------------------------------------
/src/models/tool_calls.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 | 
3 | #[derive(Deserialize, Serialize, Clone, Debug)]
4 | pub struct FunctionCall {
5 |     pub arguments: String,
6 |     pub name: String,
7 | }
8 | 
9 | #[derive(Deserialize, Serialize, Clone, Debug)]
10 | pub struct ChatMessageToolCall {
11 |     pub id: String,
12 |     pub function: FunctionCall,
13 |     #[serde(rename = "type")]
14 |     pub r#type: String, // Using `function` as the only valid value
15 | }
16 | 
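streaming.rs above mirrors the OpenAI chunk format, so one SSE data line parses directly into ChatCompletionChunk. A sketch with a hand-written (not captured) payload, assuming the models module is public:

use hub::models::streaming::ChatCompletionChunk;
use hub::serde_json;

fn main() {
    // Representative chunk payload, written by hand for illustration.
    let data = r#"{
        "id": "chatcmpl-123",
        "created": 1700000000,
        "model": "gpt-4o",
        "choices": [{"index": 0, "delta": {"content": "Hel"}, "finish_reason": null}]
    }"#;
    let chunk: ChatCompletionChunk = serde_json::from_str(data).expect("chunk should parse");
    assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("Hel"));
}
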
--------------------------------------------------------------------------------
/src/models/tool_choice.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 | 
3 | #[derive(Debug, Clone, Serialize, Deserialize)]
4 | #[serde(untagged)]
5 | pub enum ToolChoice {
6 |     Simple(SimpleToolChoice),
7 |     Named(ChatCompletionNamedToolChoice),
8 | }
9 | 
10 | #[derive(Debug, Clone, Serialize, Deserialize)]
11 | #[serde(rename_all = "lowercase")]
12 | pub enum SimpleToolChoice {
13 |     None,
14 |     Auto,
15 |     Required,
16 | }
17 | 
18 | #[derive(Debug, Clone, Serialize, Deserialize)]
19 | pub struct ChatCompletionNamedToolChoice {
20 |     #[serde(rename = "type")]
21 |     pub tool_type: ToolType,
22 |     pub function: Function,
23 | }
24 | 
25 | #[derive(Debug, Clone, Serialize, Deserialize)]
26 | #[serde(rename_all = "lowercase")]
27 | pub enum ToolType {
28 |     Function,
29 | }
30 | 
31 | #[derive(Debug, Clone, Serialize, Deserialize)]
32 | pub struct Function {
33 |     pub name: String,
34 | }
35 | 
--------------------------------------------------------------------------------
/src/models/tool_definition.rs:
--------------------------------------------------------------------------------
1 | use std::collections::HashMap;
2 | 
3 | use serde::{Deserialize, Serialize};
4 | 
5 | #[derive(Debug, Serialize, Deserialize, Clone)]
6 | pub struct ToolDefinition {
7 |     pub function: FunctionDefinition,
8 | 
9 |     #[serde(rename = "type")]
10 |     pub tool_type: String, // Will only accept "function" value
11 | }
12 | 
13 | /// A definition of a function that can be called.
14 | #[derive(Debug, Serialize, Deserialize, Clone)]
15 | pub struct FunctionDefinition {
16 |     pub name: String,
17 |     #[serde(skip_serializing_if = "Option::is_none")]
18 |     pub description: Option<String>,
19 |     #[serde(skip_serializing_if = "Option::is_none")]
20 |     pub parameters: Option<HashMap<String, serde_json::Value>>,
21 |     #[serde(skip_serializing_if = "Option::is_none")]
22 |     pub strict: Option<bool>,
23 | }
24 | 
--------------------------------------------------------------------------------
/src/models/usage.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 | 
3 | #[derive(Deserialize, Serialize, Clone, Debug)]
4 | pub struct CompletionTokensDetails {
5 |     #[serde(skip_serializing_if = "Option::is_none")]
6 |     pub accepted_prediction_tokens: Option<u32>,
7 |     #[serde(skip_serializing_if = "Option::is_none")]
8 |     pub audio_tokens: Option<u32>,
9 |     #[serde(skip_serializing_if = "Option::is_none")]
10 |     pub reasoning_tokens: Option<u32>,
11 |     #[serde(skip_serializing_if = "Option::is_none")]
12 |     pub rejected_prediction_tokens: Option<u32>,
13 | }
14 | 
15 | #[derive(Deserialize, Serialize, Clone, Debug)]
16 | pub struct PromptTokensDetails {
17 |     #[serde(skip_serializing_if = "Option::is_none")]
18 |     pub audio_tokens: Option<u32>,
19 |     #[serde(skip_serializing_if = "Option::is_none")]
20 |     pub cached_tokens: Option<u32>,
21 | }
22 | 
23 | #[derive(Deserialize, Serialize, Clone, Debug, Default)]
24 | pub struct Usage {
25 |     pub prompt_tokens: u32,
26 |     pub completion_tokens: u32,
27 |     pub total_tokens: u32,
28 |     #[serde(skip_serializing_if = "Option::is_none")]
29 |     pub completion_tokens_details: Option<CompletionTokensDetails>,
30 |     #[serde(skip_serializing_if = "Option::is_none")]
31 |     pub prompt_tokens_details: Option<PromptTokensDetails>,
32 | }
33 | 
34 | #[derive(Deserialize, Serialize, Clone, Debug, Default)]
35 | pub struct EmbeddingUsage {
36 |     #[serde(skip_serializing_if = "Option::is_none")]
37 |     pub prompt_tokens: Option<u32>,
38 |     pub total_tokens: Option<u32>,
39 | }
40 | 
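The untagged ToolChoice above accepts both wire forms OpenAI defines: a bare string for the simple modes and an object for a named function. A short illustration (editor's sketch; the function name is made up):

use hub::models::tool_choice::{SimpleToolChoice, ToolChoice};
use hub::serde_json;

fn main() {
    // "none" | "auto" | "required" map onto the lowercase SimpleToolChoice variants...
    let simple: ToolChoice = serde_json::from_str(r#""auto""#).expect("simple form");
    assert!(matches!(simple, ToolChoice::Simple(SimpleToolChoice::Auto)));

    // ...while the object form pins a specific function by name.
    let named: ToolChoice =
        serde_json::from_str(r#"{"type": "function", "function": {"name": "get_weather"}}"#)
            .expect("named form");
    assert!(matches!(named, ToolChoice::Named(_)));
}
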
--------------------------------------------------------------------------------
/src/pipelines/mod.rs:
--------------------------------------------------------------------------------
1 | mod otel;
2 | pub mod pipeline;
3 | 
--------------------------------------------------------------------------------
/src/pipelines/otel.rs:
--------------------------------------------------------------------------------
1 | use crate::config::lib::get_trace_content_enabled;
2 | use crate::models::chat::{ChatCompletion, ChatCompletionChoice, ChatCompletionRequest};
3 | use crate::models::completion::{CompletionRequest, CompletionResponse};
4 | use crate::models::content::{ChatCompletionMessage, ChatMessageContent};
5 | use crate::models::embeddings::{EmbeddingsInput, EmbeddingsRequest, EmbeddingsResponse};
6 | use crate::models::streaming::ChatCompletionChunk;
7 | use crate::models::usage::{EmbeddingUsage, Usage};
8 | use opentelemetry::global::{BoxedSpan, ObjectSafeSpan};
9 | use opentelemetry::trace::{SpanKind, Status, Tracer};
10 | use opentelemetry::{global, KeyValue};
11 | use opentelemetry_otlp::{SpanExporter, WithExportConfig, WithHttpConfig};
12 | use opentelemetry_sdk::propagation::TraceContextPropagator;
13 | use opentelemetry_sdk::trace::TracerProvider;
14 | use opentelemetry_semantic_conventions::attribute::GEN_AI_REQUEST_MODEL;
15 | use opentelemetry_semantic_conventions::trace::*;
16 | use std::collections::HashMap;
17 | 
18 | pub trait RecordSpan {
19 |     fn record_span(&self, span: &mut BoxedSpan);
20 | }
21 | 
22 | pub struct OtelTracer {
23 |     span: BoxedSpan,
24 |     accumulated_completion: Option<ChatCompletion>,
25 | }
26 | 
27 | impl OtelTracer {
28 |     pub fn init(endpoint: String, api_key: String) {
29 |         global::set_text_map_propagator(TraceContextPropagator::new());
30 |         let mut headers = HashMap::new();
31 |         headers.insert("Authorization".to_string(), format!("Bearer {}", api_key));
32 | 
33 |         let exporter: SpanExporter = SpanExporter::builder()
34 |             .with_http()
35 |             .with_endpoint(endpoint)
36 |             .with_headers(headers)
37 |             .build()
38 |             .expect("Failed to initialize OpenTelemetry");
39 | 
40 |         let provider = TracerProvider::builder()
41 |             .with_batch_exporter(exporter, opentelemetry_sdk::runtime::Tokio)
42 |             .build();
43 | 
44 |         global::set_tracer_provider(provider);
45 |     }
46 | 
47 |     pub fn start<T: RecordSpan>(operation: &str, request: &T) -> Self {
48 |         let tracer = global::tracer("traceloop_hub");
49 |         let mut span = tracer
50 |             .span_builder(format!("traceloop_hub.{}", operation))
51 |             .with_kind(SpanKind::Client)
52 |             .start(&tracer);
53 | 
54 |         request.record_span(&mut span);
55 | 
56 |         Self {
57 |             span,
58 |             accumulated_completion: None,
59 |         }
60 |     }
61 | 
62 |     pub fn log_chunk(&mut self, chunk: &ChatCompletionChunk) {
63 |         if self.accumulated_completion.is_none() {
64 |             self.accumulated_completion = Some(ChatCompletion {
65 |                 id: chunk.id.clone(),
66 |                 object: None,
67 |                 created: None,
68 |                 model: chunk.model.clone(),
69 |                 choices: vec![],
70 |                 usage: Usage::default(),
71 |                 system_fingerprint: chunk.system_fingerprint.clone(),
72 |             });
73 |         }
74 | 
75 |         if let Some(completion) = &mut self.accumulated_completion {
76 |             for chunk_choice in &chunk.choices {
77 |                 if let Some(existing_choice) =
78 |                     completion.choices.get_mut(chunk_choice.index as usize)
79 |                 {
80 |                     if let Some(content) = &chunk_choice.delta.content {
81 |                         if let Some(ChatMessageContent::String(existing_content)) =
82 |                             &mut existing_choice.message.content
83 |                         {
84 |                             existing_content.push_str(content);
85 |                         }
86 |                     }
87 |                     if chunk_choice.finish_reason.is_some() {
88 |                         existing_choice.finish_reason = chunk_choice.finish_reason.clone();
89 |                     }
90 |                     if let Some(tool_calls) = &chunk_choice.delta.tool_calls {
91 |                         existing_choice.message.tool_calls = Some(tool_calls.clone());
92 |                     }
93 |                 } else {
94 |                     completion.choices.push(ChatCompletionChoice {
95 |                         index: chunk_choice.index,
96 |                         message: ChatCompletionMessage {
97 |                             name: None,
98 |                             role: chunk_choice
99 |                                 .delta
100 |                                 .role
101 |                                 .clone()
102 |                                 .unwrap_or_else(|| "assistant".to_string()),
103 |                             content: Some(ChatMessageContent::String(
104 |                                 chunk_choice.delta.content.clone().unwrap_or_default(),
105 |                             )),
106 |                             tool_calls: chunk_choice.delta.tool_calls.clone(),
107 |                             refusal: None,
108 |                         },
109 |                         finish_reason: chunk_choice.finish_reason.clone(),
110 |                         logprobs: None,
111 |                     });
112 |                 }
113 |             }
114 |         }
115 |     }
116 | 
117 |     pub fn streaming_end(&mut self) {
118 |         if let Some(completion) = self.accumulated_completion.take() {
119 |             completion.record_span(&mut self.span);
120 |             self.span.set_status(Status::Ok);
121 |         }
122 |     }
123 | 
124 |     pub fn log_success<T: RecordSpan>(&mut self, response: &T) {
125 |         response.record_span(&mut self.span);
126 |         self.span.set_status(Status::Ok);
127 |     }
128 | 
129 |     pub fn log_error(&mut self, description: String) {
130 |         self.span.set_status(Status::error(description));
131 |     }
132 | }
133 | 
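Tying the tracer's lifecycle together: from a sibling module inside the crate (the otel module is private to src/pipelines), the intended call sequence looks roughly like this (an editor's sketch, not repository code):

use crate::models::chat::ChatCompletionRequest;
use crate::models::streaming::ChatCompletionChunk;
use crate::pipelines::otel::OtelTracer;

// Sketch: trace a streamed chat interaction. Without a prior OtelTracer::init(...)
// the global tracer provider is a no-op, so spans are silently dropped rather than
// exported, which makes this safe to run in tests.
fn trace_stream_sketch(request: ChatCompletionRequest, chunks: Vec<ChatCompletionChunk>) {
    let mut tracer = OtelTracer::start("chat", &request);
    for chunk in &chunks {
        tracer.log_chunk(chunk); // folds per-choice deltas into one ChatCompletion
    }
    tracer.streaming_end(); // records the accumulated completion, sets status Ok
}
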
134 | impl RecordSpan for ChatCompletionRequest {
135 |     fn record_span(&self, span: &mut BoxedSpan) {
136 |         span.set_attribute(KeyValue::new("llm.request.type", "chat"));
137 |         span.set_attribute(KeyValue::new(GEN_AI_REQUEST_MODEL, self.model.clone()));
138 | 
139 |         if let Some(freq_penalty) = self.frequency_penalty {
140 |             span.set_attribute(KeyValue::new(
141 |                 GEN_AI_REQUEST_FREQUENCY_PENALTY,
142 |                 freq_penalty as f64,
143 |             ));
144 |         }
145 |         if let Some(pres_penalty) = self.presence_penalty {
146 |             span.set_attribute(KeyValue::new(
147 |                 GEN_AI_REQUEST_PRESENCE_PENALTY,
148 |                 pres_penalty as f64,
149 |             ));
150 |         }
151 |         if let Some(top_p) = self.top_p {
152 |             span.set_attribute(KeyValue::new(GEN_AI_REQUEST_TOP_P, top_p as f64));
153 |         }
154 |         if let Some(temp) = self.temperature {
155 |             span.set_attribute(KeyValue::new(GEN_AI_REQUEST_TEMPERATURE, temp as f64));
156 |         }
157 | 
158 |         if get_trace_content_enabled() {
159 |             for (i, message) in self.messages.iter().enumerate() {
160 |                 if let Some(content) = &message.content {
161 |                     span.set_attribute(KeyValue::new(
162 |                         format!("gen_ai.prompt.{}.role", i),
163 |                         message.role.clone(),
164 |                     ));
165 |                     span.set_attribute(KeyValue::new(
166 |                         format!("gen_ai.prompt.{}.content", i),
167 |                         match &content {
168 |                             ChatMessageContent::String(content) => content.clone(),
169 |                             ChatMessageContent::Array(content) => {
170 |                                 serde_json::to_string(content).unwrap_or_default()
171 |                             }
172 |                         },
173 |                     ));
174 |                 }
175 |             }
176 |         }
177 |     }
178 | }
179 | 
180 | impl RecordSpan for ChatCompletion {
181 |     fn record_span(&self, span: &mut BoxedSpan) {
182 |         span.set_attribute(KeyValue::new(GEN_AI_RESPONSE_MODEL, self.model.clone()));
183 |         span.set_attribute(KeyValue::new(GEN_AI_RESPONSE_ID, self.id.clone()));
184 | 
185 |         self.usage.record_span(span);
186 | 
187 |         if get_trace_content_enabled() {
188 |             for choice in &self.choices {
189 |                 if let Some(content) = &choice.message.content {
190 |                     span.set_attribute(KeyValue::new(
191 |                         format!("gen_ai.completion.{}.role", choice.index),
192 |                         choice.message.role.clone(),
193 |                     ));
194 |                     span.set_attribute(KeyValue::new(
195 |                         format!("gen_ai.completion.{}.content", choice.index),
196 |                         match &content {
197 |                             ChatMessageContent::String(content) => content.clone(),
198 |                             ChatMessageContent::Array(content) => {
199 |                                 serde_json::to_string(content).unwrap_or_default()
200 |                             }
201 |                         },
202 |                     ));
203 |                 }
204 |                 span.set_attribute(KeyValue::new(
205 |                     format!("gen_ai.completion.{}.finish_reason", choice.index),
206 |                     choice.finish_reason.clone().unwrap_or_default(),
207 |                 ));
208 |             }
209 |         }
210 |     }
211 | }
212 | 
213 | impl RecordSpan for CompletionRequest {
214 |     fn record_span(&self, span: &mut BoxedSpan) {
215 |         span.set_attribute(KeyValue::new("llm.request.type", "completion"));
216 |         span.set_attribute(KeyValue::new(GEN_AI_REQUEST_MODEL, self.model.clone()));
217 |         span.set_attribute(KeyValue::new("gen_ai.prompt", self.prompt.clone()));
218 | 
219 |         if let Some(freq_penalty) = self.frequency_penalty {
220 |             span.set_attribute(KeyValue::new(
221 |                 GEN_AI_REQUEST_FREQUENCY_PENALTY,
222 |                 freq_penalty as f64,
223 |             ));
224 |         }
225 |         if let Some(pres_penalty) = self.presence_penalty {
226 |             span.set_attribute(KeyValue::new(
227 |                 GEN_AI_REQUEST_PRESENCE_PENALTY,
228 |                 pres_penalty as f64,
229 |             ));
230 |         }
231 |         if let Some(top_p) = self.top_p {
232 |             span.set_attribute(KeyValue::new(GEN_AI_REQUEST_TOP_P, top_p as f64));
233 |         }
234 |         if let Some(temp) = self.temperature {
235 |             span.set_attribute(KeyValue::new(GEN_AI_REQUEST_TEMPERATURE, temp as f64));
236 |         }
237 |     }
238 | }
239 | 
240 | impl RecordSpan for CompletionResponse {
241 |     fn record_span(&self, span: &mut BoxedSpan) {
242 |         span.set_attribute(KeyValue::new(GEN_AI_RESPONSE_MODEL, self.model.clone()));
243 |         span.set_attribute(KeyValue::new(GEN_AI_RESPONSE_ID, self.id.clone()));
244 | 
245 |         self.usage.record_span(span);
246 | 
247 |         for choice in &self.choices {
248 |             span.set_attribute(KeyValue::new(
249 |                 format!("gen_ai.completion.{}.role", choice.index),
250 |                 "assistant".to_string(),
251 |             ));
252 |             span.set_attribute(KeyValue::new(
253 |                 format!("gen_ai.completion.{}.content", choice.index),
254 |                 choice.text.clone(),
255 |             ));
256 |             span.set_attribute(KeyValue::new(
257 |                 format!("gen_ai.completion.{}.finish_reason", choice.index),
258 |                 choice.finish_reason.clone().unwrap_or_default(),
259 |             ));
260 |         }
261 |     }
262 | }
263 | 
264 | impl RecordSpan for EmbeddingsRequest {
265 |     fn record_span(&self, span: &mut BoxedSpan) {
266 |         span.set_attribute(KeyValue::new("llm.request.type", "embeddings"));
267 |         span.set_attribute(KeyValue::new(GEN_AI_REQUEST_MODEL, self.model.clone()));
268 | 
269 |         if get_trace_content_enabled() {
270 |             match &self.input {
271 |                 EmbeddingsInput::Single(text) => {
272 |                     span.set_attribute(KeyValue::new("llm.prompt.0.content", text.clone()));
273 |                 }
274 |                 EmbeddingsInput::Multiple(texts) => {
275 |                     for (i, text) in texts.iter().enumerate() {
276 |                         span.set_attribute(KeyValue::new(
277 |                             format!("llm.prompt.{}.role", i),
278 |                             "user".to_string(),
279 |                         ));
280 |                         span.set_attribute(KeyValue::new(
281 |                             format!("llm.prompt.{}.content", i),
282 |                             text.clone(),
283 |                         ));
284 |                     }
285 |                 }
286 |                 EmbeddingsInput::SingleTokenIds(token_ids) => {
287 |                     span.set_attribute(KeyValue::new(
288 |                         "llm.prompt.0.content",
289 |                         format!("{:?}", token_ids),
290 |                     ));
291 |                 }
292 |                 EmbeddingsInput::MultipleTokenIds(token_ids) => {
293 |                     for (i, token_ids) in token_ids.iter().enumerate() {
294 |                         span.set_attribute(KeyValue::new(
295 |                             format!("llm.prompt.{}.role", i),
296 |                             "user".to_string(),
297 |                         ));
298 |                         span.set_attribute(KeyValue::new(
299 |                             format!("llm.prompt.{}.content", i),
300 |                             format!("{:?}", token_ids),
301 |                         ));
302 |                     }
303 |                 }
304 |             }
305 |         }
306 |     }
307 | }
308 | impl RecordSpan for EmbeddingsResponse {
309 |     fn record_span(&self, span: &mut BoxedSpan) {
310 |         span.set_attribute(KeyValue::new(GEN_AI_RESPONSE_MODEL, self.model.clone()));
311 | 
312 |         self.usage.record_span(span);
313 |     }
314 | }
315 | 
316 | impl RecordSpan for Usage {
317 |     fn record_span(&self, span: &mut BoxedSpan) {
318 |         span.set_attribute(KeyValue::new(
319 |             "gen_ai.usage.prompt_tokens",
320 |             self.prompt_tokens as i64,
321 |         ));
322 |         span.set_attribute(KeyValue::new(
323 |             "gen_ai.usage.completion_tokens",
324 |             self.completion_tokens as i64,
325 |         ));
326 |         span.set_attribute(KeyValue::new(
327 |             "gen_ai.usage.total_tokens",
328 |             self.total_tokens as i64,
329 |         ));
330 |     }
331 | }
332 | 
333 | impl RecordSpan for EmbeddingUsage {
334 |     fn record_span(&self, span: &mut BoxedSpan) {
335 |         span.set_attribute(KeyValue::new(
336 |             "gen_ai.usage.prompt_tokens",
337 |             self.prompt_tokens.unwrap_or(0) as i64,
338 |         ));
339 |         span.set_attribute(KeyValue::new(
340 |             "gen_ai.usage.total_tokens",
341 |             self.total_tokens.unwrap_or(0) as i64,
342 |         ));
343 |     }
344 | }
345 | 
--------------------------------------------------------------------------------
/src/pipelines/pipeline.rs:
--------------------------------------------------------------------------------
1 | use crate::config::models::PipelineType;
2 | use crate::models::chat::ChatCompletionResponse;
3 | use crate::models::completion::CompletionRequest;
4 | use crate::models::embeddings::EmbeddingsRequest;
5 | use crate::models::streaming::ChatCompletionChunk;
6 | use crate::pipelines::otel::OtelTracer;
7 | use crate::{
8 |     ai_models::registry::ModelRegistry,
9 |     config::models::{Pipeline, PluginConfig},
10 |     models::chat::ChatCompletionRequest,
11 | };
12 | use async_stream::stream;
13 | use axum::response::sse::{Event, KeepAlive};
14 | use axum::response::{IntoResponse, Sse};
15 | use axum::{extract::State, http::StatusCode, routing::post, Json, Router};
16 | use futures::stream::BoxStream;
17 | use futures::{Stream, StreamExt};
18 | use reqwest_streams::error::StreamBodyError;
19 | use std::sync::Arc;
20 | 
21 | pub fn create_pipeline(pipeline: &Pipeline, model_registry: &ModelRegistry) -> Router {
22 |     let mut router = Router::new();
23 | 
24 |     for plugin in pipeline.plugins.clone() {
25 |         router = match plugin {
26 |             PluginConfig::Tracing { endpoint, api_key } => {
27 |                 OtelTracer::init(endpoint, api_key);
28 |                 router
29 |             }
30 |             PluginConfig::ModelRouter { models } => match pipeline.r#type {
31 |                 PipelineType::Chat => router.route(
32 |                     "/chat/completions",
33 |                     post(move |state, payload| chat_completions(state, payload, models)),
34 |                 ),
35 |                 PipelineType::Completion => router.route(
36 |                     "/completions",
37 |                     post(move |state, payload| completions(state, payload, models)),
38 |                 ),
39 |                 PipelineType::Embeddings => router.route(
40 |                     "/embeddings",
41 |                     post(move |state, payload| embeddings(state, payload, models)),
42 |                 ),
43 |             },
44 |             _ => router,
45 |         };
46 |     }
47 | 
48 |     router.with_state(Arc::new(model_registry.clone()))
49 | }
50 | 
51 | fn trace_and_stream(
52 |     mut tracer: OtelTracer,
53 |     stream: BoxStream<'static, Result<ChatCompletionChunk, StreamBodyError>>,
54 | ) -> impl Stream<Item = Result<Event, axum::Error>> {
55 |     stream! {
56 |         let mut stream = stream;
57 |         while let Some(result) = stream.next().await {
58 |             yield match result {
59 |                 Ok(chunk) => {
60 |                     tracer.log_chunk(&chunk);
61 |                     Event::default().json_data(chunk)
62 |                 }
63 |                 Err(e) => {
64 |                     eprintln!("Error in stream: {:?}", e);
65 |                     tracer.log_error(e.to_string());
66 |                     Err(axum::Error::new(e))
67 |                 }
68 |             };
69 |         }
70 |         tracer.streaming_end();
71 |     }
72 | }
73 | 
74 | pub async fn chat_completions(
75 |     State(model_registry): State<Arc<ModelRegistry>>,
76 |     Json(payload): Json<ChatCompletionRequest>,
77 |     model_keys: Vec<String>,
78 | ) -> Result<impl IntoResponse, StatusCode> {
79 |     let mut tracer = OtelTracer::start("chat", &payload);
80 | 
81 |     for model_key in model_keys {
82 |         let model = model_registry.get(&model_key).unwrap();
83 | 
84 |         if payload.model == model.model_type {
85 |             let response = model
86 |                 .chat_completions(payload.clone())
87 |                 .await
88 |                 .inspect_err(|e| {
89 |                     eprintln!("Chat completion error for model {}: {:?}", model_key, e);
90 |                 })?;
91 | 
92 |             if let ChatCompletionResponse::NonStream(completion) = response {
93 |                 tracer.log_success(&completion);
94 |                 return Ok(Json(completion).into_response());
95 |             }
96 | 
97 |             if let ChatCompletionResponse::Stream(stream) = response {
98 |                 return Ok(Sse::new(trace_and_stream(tracer, stream))
99 |                     .keep_alive(KeepAlive::default())
100 |                     .into_response());
101 |             }
102 |         }
103 |     }
104 | 
105 |     tracer.log_error("No matching model found".to_string());
106 |     eprintln!("No matching model found for: {}", payload.model);
107 |     Err(StatusCode::NOT_FOUND)
108 | }
109 | 
110 | pub async fn completions(
111 |     State(model_registry): State<Arc<ModelRegistry>>,
112 |     Json(payload): Json<CompletionRequest>,
113 |     model_keys: Vec<String>,
114 | ) -> Result<impl IntoResponse, StatusCode> {
115 |     let mut tracer = OtelTracer::start("completion", &payload);
116 | 
117 |     for model_key in model_keys {
118 |         let model = model_registry.get(&model_key).unwrap();
119 | 
120 |         if payload.model == model.model_type {
121 |             let response = model.completions(payload.clone()).await.inspect_err(|e| {
122 |                 eprintln!("Completion error for model {}: {:?}", model_key, e);
123 |             })?;
124 |             tracer.log_success(&response);
125 |             return Ok(Json(response));
126 |         }
127 |     }
128 | 
129 |     tracer.log_error("No matching model found".to_string());
130 |     eprintln!("No matching model found for: {}", payload.model);
131 |     Err(StatusCode::NOT_FOUND)
132 | }
133 | 
134 | pub async fn embeddings(
135 |     State(model_registry): State<Arc<ModelRegistry>>,
136 |     Json(payload): Json<EmbeddingsRequest>,
137 |     model_keys: Vec<String>,
138 | ) -> Result<impl IntoResponse, StatusCode> {
139 |     let mut tracer = OtelTracer::start("embeddings", &payload);
140 | 
141 |     for model_key in model_keys {
142 |         let model = model_registry.get(&model_key).unwrap();
143 | 
144 |         if payload.model == model.model_type {
145 |             let response = model.embeddings(payload.clone()).await.inspect_err(|e| {
146 |                 eprintln!("Embeddings error for model {}: {:?}", model_key, e);
147 |             })?;
148 |             tracer.log_success(&response);
149 |             return Ok(Json(response));
150 |         }
151 |     }
152 | 
153 |     tracer.log_error("No matching model found".to_string());
154 |     eprintln!("No matching model found for: {}", payload.model);
155 |     Err(StatusCode::NOT_FOUND)
156 | }
157 | 
--------------------------------------------------------------------------------
/src/providers/anthropic/mod.rs:
--------------------------------------------------------------------------------
1 | pub(crate) mod models;
2 | mod provider;
3 | 
4 | pub use models::{AnthropicChatCompletionRequest, AnthropicChatCompletionResponse};
5 | pub use provider::AnthropicProvider;
6 | 
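End to end, a chat pipeline mounts POST /chat/completions, and chat_completions above walks model_keys in order, answering 404 when no registered model_type matches the requested model. A minimal external client sketch using the crate's reqwest and serde_json re-exports (and a tokio dependency); the port is main.rs's default, while the /api/v1 path prefix is a guess, since routes.rs is not included in this listing:

use hub::{reqwest, serde_json};

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let body = serde_json::json!({
        "model": "gpt-4o", // must equal some registered model's model_type
        "messages": [{ "role": "user", "content": "Hello" }]
    });

    let response = reqwest::Client::new()
        .post("http://localhost:3000/api/v1/chat/completions") // prefix is an assumption
        .json(&body)
        .send()
        .await?;

    // A 404 here means no ModelRouter entry matched the requested model.
    println!("{} {}", response.status(), response.text().await?);
    Ok(())
}
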
-------------------------------------------------------------------------------- /src/providers/anthropic/models.rs: -------------------------------------------------------------------------------- 1 | use crate::config::constants::default_max_tokens; 2 | use crate::models::chat::{ChatCompletion, ChatCompletionChoice, ChatCompletionRequest}; 3 | use crate::models::content::{ChatCompletionMessage, ChatMessageContent, ChatMessageContentPart}; 4 | use crate::models::tool_calls::{ChatMessageToolCall, FunctionCall}; 5 | use serde::{Deserialize, Serialize}; 6 | 7 | #[derive(Deserialize, Serialize, Clone)] 8 | pub struct AnthropicChatCompletionRequest { 9 | pub max_tokens: u32, 10 | pub model: String, 11 | pub messages: Vec, 12 | #[serde(skip_serializing_if = "Option::is_none")] 13 | pub temperature: Option, 14 | #[serde(skip_serializing_if = "Option::is_none")] 15 | pub tool_choice: Option, 16 | pub tools: Vec, 17 | #[serde(skip_serializing_if = "Option::is_none")] 18 | pub top_p: Option, 19 | #[serde(skip_serializing_if = "Option::is_none")] 20 | pub stream: Option, 21 | #[serde(skip_serializing_if = "Option::is_none")] 22 | pub system: Option, 23 | } 24 | 25 | #[derive(Deserialize, Serialize, Clone)] 26 | pub struct AnthropicChatCompletionResponse { 27 | pub id: String, 28 | pub model: String, 29 | pub content: Vec, 30 | pub usage: Usage, 31 | } 32 | 33 | #[derive(Deserialize, Serialize, Clone)] 34 | #[serde(tag = "type")] 35 | pub enum ContentBlock { 36 | #[serde(rename = "text")] 37 | Text { text: String }, 38 | #[serde(rename = "tool_use")] 39 | ToolUse { 40 | id: String, 41 | input: serde_json::Value, 42 | name: String, 43 | }, 44 | } 45 | 46 | #[derive(Deserialize, Serialize, Clone)] 47 | pub struct Usage { 48 | pub input_tokens: u32, 49 | pub output_tokens: u32, 50 | } 51 | 52 | #[derive(Deserialize, Serialize, Clone)] 53 | pub(crate) struct InputSchemaTyped { 54 | #[serde(rename = "type")] 55 | pub r#type: String, 56 | #[serde(skip_serializing_if = "Option::is_none")] 57 | pub properties: Option, 58 | } 59 | 60 | pub(crate) type InputSchema = serde_json::Value; 61 | 62 | #[derive(Deserialize, Serialize, Clone)] 63 | pub struct ToolParam { 64 | pub input_schema: InputSchema, 65 | pub name: String, 66 | #[serde(skip_serializing_if = "Option::is_none")] 67 | pub description: Option, 68 | } 69 | 70 | #[derive(Deserialize, Serialize, Clone)] 71 | #[serde(tag = "type")] 72 | pub enum ToolChoice { 73 | #[serde(rename = "auto")] 74 | Auto { disable_parallel_tool_use: bool }, 75 | #[serde(rename = "any")] 76 | Any { disable_parallel_tool_use: bool }, 77 | #[serde(rename = "tool")] 78 | Tool { 79 | name: String, 80 | disable_parallel_tool_use: bool, 81 | }, 82 | } 83 | 84 | impl From for AnthropicChatCompletionRequest { 85 | fn from(request: ChatCompletionRequest) -> Self { 86 | let should_include_tools = !matches!( 87 | request.tool_choice, 88 | Some(crate::models::tool_choice::ToolChoice::Simple( 89 | crate::models::tool_choice::SimpleToolChoice::None 90 | )) 91 | ); 92 | 93 | let system = request 94 | .messages 95 | .iter() 96 | .find(|msg| msg.role == "system") 97 | .and_then(|message| match &message.content { 98 | Some(ChatMessageContent::String(text)) => Some(text.clone()), 99 | Some(ChatMessageContent::Array(parts)) => parts 100 | .iter() 101 | .find(|part| part.r#type == "text") 102 | .map(|part| part.text.clone()), 103 | _ => None, 104 | }); 105 | 106 | let messages: Vec = request 107 | .messages 108 | .into_iter() 109 | .filter(|msg| msg.role != "system") 110 | .collect(); 111 | 112 | 
let max_tokens = match request.max_completion_tokens { 113 | Some(val) if val > 0 => val, 114 | _ => request.max_tokens.unwrap_or_else(default_max_tokens), 115 | }; 116 | 117 | AnthropicChatCompletionRequest { 118 | max_tokens, 119 | model: request.model, 120 | messages, 121 | temperature: request.temperature, 122 | top_p: request.top_p, 123 | stream: request.stream, 124 | system, 125 | tool_choice: request.tool_choice.map(|choice| match choice { 126 | crate::models::tool_choice::ToolChoice::Simple(simple) => match simple { 127 | crate::models::tool_choice::SimpleToolChoice::None 128 | | crate::models::tool_choice::SimpleToolChoice::Auto => ToolChoice::Auto { 129 | disable_parallel_tool_use: request.parallel_tool_calls.unwrap_or(false), 130 | }, 131 | crate::models::tool_choice::SimpleToolChoice::Required => ToolChoice::Any { 132 | disable_parallel_tool_use: request.parallel_tool_calls.unwrap_or(false), 133 | }, 134 | }, 135 | crate::models::tool_choice::ToolChoice::Named(named) => ToolChoice::Tool { 136 | name: named.function.name, 137 | disable_parallel_tool_use: request.parallel_tool_calls.unwrap_or(false), 138 | }, 139 | }), 140 | tools: if should_include_tools { 141 | request 142 | .tools 143 | .unwrap_or_default() 144 | .into_iter() 145 | .map(|tool| ToolParam { 146 | name: tool.function.name, 147 | description: tool.function.description, 148 | input_schema: serde_json::to_value(tool.function.parameters) 149 | .unwrap_or_default(), 150 | }) 151 | .collect() 152 | } else { 153 | Vec::new() 154 | }, 155 | } 156 | } 157 | } 158 | 159 | impl From> for ChatCompletionMessage { 160 | fn from(blocks: Vec) -> Self { 161 | let mut text_content = Vec::::new(); 162 | let mut tool_calls = Vec::::new(); 163 | 164 | for block in blocks { 165 | match block { 166 | ContentBlock::Text { text } => { 167 | text_content.push(ChatMessageContentPart { 168 | r#type: "text".to_string(), 169 | text, 170 | }); 171 | } 172 | ContentBlock::ToolUse { name, input, id } => { 173 | tool_calls.push(ChatMessageToolCall { 174 | id, 175 | function: FunctionCall { 176 | name, 177 | arguments: input.to_string(), 178 | }, 179 | r#type: "function".to_string(), 180 | }); 181 | } 182 | } 183 | } 184 | 185 | ChatCompletionMessage { 186 | role: "assistant".to_string(), 187 | content: Some(ChatMessageContent::Array(text_content)), 188 | name: None, 189 | refusal: None, 190 | tool_calls: if tool_calls.is_empty() { 191 | None 192 | } else { 193 | Some(tool_calls) 194 | }, 195 | } 196 | } 197 | } 198 | 199 | impl From for ChatCompletion { 200 | fn from(response: AnthropicChatCompletionResponse) -> Self { 201 | ChatCompletion { 202 | id: response.id, 203 | object: None, 204 | created: None, 205 | model: response.model, 206 | choices: vec![ChatCompletionChoice { 207 | index: 0, 208 | message: response.content.into(), 209 | finish_reason: Some("stop".to_string()), 210 | logprobs: None, 211 | }], 212 | usage: crate::models::usage::Usage { 213 | prompt_tokens: response.usage.input_tokens, 214 | completion_tokens: response.usage.output_tokens, 215 | total_tokens: response.usage.input_tokens + response.usage.output_tokens, 216 | completion_tokens_details: None, 217 | prompt_tokens_details: None, 218 | }, 219 | system_fingerprint: None, 220 | } 221 | } 222 | } 223 | -------------------------------------------------------------------------------- /src/providers/anthropic/provider.rs: -------------------------------------------------------------------------------- 1 | use axum::async_trait; 2 | use axum::http::StatusCode; 3 | use 
reqwest::Client; 4 | 5 | use super::models::{AnthropicChatCompletionRequest, AnthropicChatCompletionResponse}; 6 | use crate::config::models::{ModelConfig, Provider as ProviderConfig}; 7 | use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse}; 8 | use crate::models::completion::{CompletionRequest, CompletionResponse}; 9 | use crate::models::embeddings::{EmbeddingsRequest, EmbeddingsResponse}; 10 | use crate::providers::provider::Provider; 11 | 12 | pub struct AnthropicProvider { 13 | api_key: String, 14 | config: ProviderConfig, 15 | http_client: Client, 16 | } 17 | 18 | #[async_trait] 19 | impl Provider for AnthropicProvider { 20 | fn new(config: &ProviderConfig) -> Self { 21 | Self { 22 | api_key: config.api_key.clone(), 23 | config: config.clone(), 24 | http_client: Client::new(), 25 | } 26 | } 27 | 28 | fn key(&self) -> String { 29 | self.config.key.clone() 30 | } 31 | 32 | fn r#type(&self) -> String { 33 | "anthropic".to_string() 34 | } 35 | 36 | async fn chat_completions( 37 | &self, 38 | payload: ChatCompletionRequest, 39 | _model_config: &ModelConfig, 40 | ) -> Result { 41 | let request = AnthropicChatCompletionRequest::from(payload); 42 | let response = self 43 | .http_client 44 | .post("https://api.anthropic.com/v1/messages") 45 | .header("x-api-key", &self.api_key) 46 | .header("anthropic-version", "2023-06-01") 47 | .json(&request) 48 | .send() 49 | .await 50 | .map_err(|e| { 51 | eprintln!("Anthropic API request error: {}", e); 52 | StatusCode::INTERNAL_SERVER_ERROR 53 | })?; 54 | 55 | let status = response.status(); 56 | if status.is_success() { 57 | if request.stream.unwrap_or(false) { 58 | unimplemented!() 59 | } else { 60 | let anthropic_response: AnthropicChatCompletionResponse = response 61 | .json() 62 | .await 63 | .expect("Failed to parse Anthropic response"); 64 | Ok(ChatCompletionResponse::NonStream(anthropic_response.into())) 65 | } 66 | } else { 67 | eprintln!( 68 | "Anthropic API request error: {}", 69 | response.text().await.unwrap() 70 | ); 71 | Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)) 72 | } 73 | } 74 | 75 | async fn completions( 76 | &self, 77 | _payload: CompletionRequest, 78 | _model_config: &ModelConfig, 79 | ) -> Result { 80 | unimplemented!() 81 | } 82 | 83 | async fn embeddings( 84 | &self, 85 | _payload: EmbeddingsRequest, 86 | _model_config: &ModelConfig, 87 | ) -> Result { 88 | unimplemented!() 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /src/providers/azure/mod.rs: -------------------------------------------------------------------------------- 1 | mod provider; 2 | 3 | pub use provider::AzureProvider; 4 | -------------------------------------------------------------------------------- /src/providers/azure/provider.rs: -------------------------------------------------------------------------------- 1 | use axum::async_trait; 2 | use axum::http::StatusCode; 3 | use reqwest_streams::JsonStreamResponse; 4 | 5 | use crate::config::constants::stream_buffer_size_bytes; 6 | use crate::config::models::{ModelConfig, Provider as ProviderConfig}; 7 | use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse}; 8 | use crate::models::completion::{CompletionRequest, CompletionResponse}; 9 | use crate::models::embeddings::{EmbeddingsRequest, EmbeddingsResponse}; 10 | use crate::models::streaming::ChatCompletionChunk; 11 | use crate::providers::provider::Provider; 12 | use reqwest::Client; 13 | use tracing::info; 14 | 15 | pub 
struct AzureProvider { 16 | config: ProviderConfig, 17 | http_client: Client, 18 | } 19 | 20 | impl AzureProvider { 21 | fn endpoint(&self) -> String { 22 | if let Some(base_url) = self.config.params.get("base_url") { 23 | base_url.clone() 24 | } else { 25 | format!( 26 | "https://{}.openai.azure.com/openai/deployments", 27 | self.config.params.get("resource_name").unwrap(), 28 | ) 29 | } 30 | } 31 | fn api_version(&self) -> String { 32 | self.config.params.get("api_version").unwrap().clone() 33 | } 34 | } 35 | 36 | #[async_trait] 37 | impl Provider for AzureProvider { 38 | fn new(config: &ProviderConfig) -> Self { 39 | Self { 40 | config: config.clone(), 41 | http_client: Client::new(), 42 | } 43 | } 44 | 45 | fn key(&self) -> String { 46 | self.config.key.clone() 47 | } 48 | 49 | fn r#type(&self) -> String { 50 | "azure".to_string() 51 | } 52 | 53 | async fn chat_completions( 54 | &self, 55 | payload: ChatCompletionRequest, 56 | model_config: &ModelConfig, 57 | ) -> Result { 58 | let deployment = model_config.params.get("deployment").unwrap(); 59 | let api_version = self.api_version(); 60 | let url = format!( 61 | "{}/{}/chat/completions?api-version={}", 62 | self.endpoint(), 63 | deployment, 64 | api_version 65 | ); 66 | 67 | let response = self 68 | .http_client 69 | .post(&url) 70 | .header("api-key", &self.config.api_key) 71 | .json(&payload) 72 | .send() 73 | .await 74 | .map_err(|e| { 75 | eprintln!("Azure OpenAI API request error: {}", e); 76 | StatusCode::INTERNAL_SERVER_ERROR 77 | })?; 78 | 79 | let status = response.status(); 80 | if status.is_success() { 81 | if payload.stream.unwrap_or(false) { 82 | let stream = 83 | response.json_array_stream::(stream_buffer_size_bytes()); 84 | Ok(ChatCompletionResponse::Stream(stream)) 85 | } else { 86 | response 87 | .json() 88 | .await 89 | .map(ChatCompletionResponse::NonStream) 90 | .map_err(|e| { 91 | eprintln!("Azure OpenAI API response error: {}", e); 92 | StatusCode::INTERNAL_SERVER_ERROR 93 | }) 94 | } 95 | } else { 96 | info!( 97 | "Azure OpenAI API request error: {}", 98 | response.text().await.unwrap() 99 | ); 100 | Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)) 101 | } 102 | } 103 | 104 | async fn completions( 105 | &self, 106 | payload: CompletionRequest, 107 | model_config: &ModelConfig, 108 | ) -> Result { 109 | let deployment = model_config.params.get("deployment").unwrap(); 110 | let api_version = self.api_version(); 111 | let url = format!( 112 | "{}/{}/completions?api-version={}", 113 | self.endpoint(), 114 | deployment, 115 | api_version 116 | ); 117 | 118 | let response = self 119 | .http_client 120 | .post(&url) 121 | .header("api-key", &self.config.api_key) 122 | .json(&payload) 123 | .send() 124 | .await 125 | .map_err(|e| { 126 | eprintln!("Azure OpenAI API request error: {}", e); 127 | StatusCode::INTERNAL_SERVER_ERROR 128 | })?; 129 | 130 | let status = response.status(); 131 | if status.is_success() { 132 | response.json().await.map_err(|e| { 133 | eprintln!("Azure OpenAI API response error: {}", e); 134 | StatusCode::INTERNAL_SERVER_ERROR 135 | }) 136 | } else { 137 | eprintln!( 138 | "Azure OpenAI API request error: {}", 139 | response.text().await.unwrap() 140 | ); 141 | Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)) 142 | } 143 | } 144 | 145 | async fn embeddings( 146 | &self, 147 | payload: EmbeddingsRequest, 148 | model_config: &ModelConfig, 149 | ) -> Result { 150 | let deployment = 
model_config.params.get("deployment").unwrap(); 151 | let api_version = self.api_version(); 152 | 153 | let url = format!( 154 | "{}/{}/embeddings?api-version={}", 155 | self.endpoint(), 156 | deployment, 157 | api_version 158 | ); 159 | 160 | let response = self 161 | .http_client 162 | .post(&url) 163 | .header("api-key", &self.config.api_key) 164 | .json(&payload) 165 | .send() 166 | .await 167 | .map_err(|e| { 168 | eprintln!("Azure OpenAI API request error: {}", e); 169 | StatusCode::INTERNAL_SERVER_ERROR 170 | })?; 171 | 172 | let status = response.status(); 173 | if status.is_success() { 174 | response.json().await.map_err(|e| { 175 | eprintln!("Azure OpenAI Embeddings API response error: {}", e); 176 | StatusCode::INTERNAL_SERVER_ERROR 177 | }) 178 | } else { 179 | eprintln!( 180 | "Azure OpenAI Embeddings API request error: {}", 181 | response.text().await.unwrap() 182 | ); 183 | Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)) 184 | } 185 | } 186 | } 187 | -------------------------------------------------------------------------------- /src/providers/bedrock/logs/ai21_j2_mid_v1_completions.json: -------------------------------------------------------------------------------- 1 | {"id":1234,"prompt":{"text":"Tell me a joke","tokens":[{"generatedToken":{"token":"▁Tell▁me","logprob":-15.517871856689453,"raw_logprob":-15.517871856689453},"topTokens":null,"textRange":{"start":0,"end":7}},{"generatedToken":{"token":"▁a▁joke","logprob":-6.4708662033081055,"raw_logprob":-6.4708662033081055},"topTokens":null,"textRange":{"start":7,"end":14}}]},"completions":[{"data":{"text":"\nWell, if you break the internet, will you be held responsible? \n \nBecause... \n \ndownsizing is coming!","tokens":[{"generatedToken":{"token":"<|newline|>","logprob":-4.389514506328851E-4,"raw_logprob":-4.389514506328851E-4},"topTokens":null,"textRange":{"start":0,"end":1}},{"generatedToken":{"token":"▁Well","logprob":-4.107903003692627,"raw_logprob":-4.107903003692627},"topTokens":null,"textRange":{"start":1,"end":5}},{"generatedToken":{"token":",","logprob":-0.020943328738212585,"raw_logprob":-0.020943328738212585},"topTokens":null,"textRange":{"start":5,"end":6}},{"generatedToken":{"token":"▁if▁you","logprob":-1.9317010641098022,"raw_logprob":-1.9317010641098022},"topTokens":null,"textRange":{"start":6,"end":13}},{"generatedToken":{"token":"▁break","logprob":-1.2613706588745117,"raw_logprob":-1.2613706588745117},"topTokens":null,"textRange":{"start":13,"end":19}},{"generatedToken":{"token":"▁the▁internet","logprob":-0.8270013928413391,"raw_logprob":-0.8270013928413391},"topTokens":null,"textRange":{"start":19,"end":32}},{"generatedToken":{"token":",","logprob":-0.0465129017829895,"raw_logprob":-0.0465129017829895},"topTokens":null,"textRange":{"start":32,"end":33}},{"generatedToken":{"token":"▁will▁you","logprob":-5.168251991271973,"raw_logprob":-5.168251991271973},"topTokens":null,"textRange":{"start":33,"end":42}},{"generatedToken":{"token":"▁be▁held▁responsible","logprob":-1.7928266525268555,"raw_logprob":-1.7928266525268555},"topTokens":null,"textRange":{"start":42,"end":62}},{"generatedToken":{"token":"?","logprob":-0.38259002566337585,"raw_logprob":-0.38259002566337585},"topTokens":null,"textRange":{"start":62,"end":63}},{"generatedToken":{"token":"▁","logprob":-0.5640338063240051,"raw_logprob":-0.5640338063240051},"topTokens":null,"textRange":{"start":63,"end":64}},{"generatedToken":{"token":"<|newline|>","logprob":-0.8146487474441528,"raw_logprob":-0.814648747444152
8},"topTokens":null,"textRange":{"start":64,"end":65}},{"generatedToken":{"token":"▁▁","logprob":-1.8104121685028076,"raw_logprob":-1.8104121685028076},"topTokens":null,"textRange":{"start":65,"end":66}},{"generatedToken":{"token":"<|newline|>","logprob":-0.11257430166006088,"raw_logprob":-0.11257430166006088},"topTokens":null,"textRange":{"start":66,"end":67}},{"generatedToken":{"token":"▁Because","logprob":-0.9798339605331421,"raw_logprob":-0.9798339605331421},"topTokens":null,"textRange":{"start":67,"end":74}},{"generatedToken":{"token":"...","logprob":-3.552177906036377,"raw_logprob":-3.552177906036377},"topTokens":null,"textRange":{"start":74,"end":77}},{"generatedToken":{"token":"▁","logprob":-1.0605108737945557,"raw_logprob":-1.0605108737945557},"topTokens":null,"textRange":{"start":77,"end":78}},{"generatedToken":{"token":"<|newline|>","logprob":-0.04952247813344002,"raw_logprob":-0.04952247813344002},"topTokens":null,"textRange":{"start":78,"end":79}},{"generatedToken":{"token":"▁▁","logprob":-0.054540783166885376,"raw_logprob":-0.054540783166885376},"topTokens":null,"textRange":{"start":79,"end":80}},{"generatedToken":{"token":"<|newline|>","logprob":-3.5523738915799186E-5,"raw_logprob":-3.5523738915799186E-5},"topTokens":null,"textRange":{"start":80,"end":81}},{"generatedToken":{"token":"▁downsizing","logprob":-10.688655853271484,"raw_logprob":-10.688655853271484},"topTokens":null,"textRange":{"start":81,"end":91}},{"generatedToken":{"token":"▁is▁coming","logprob":-3.518221378326416,"raw_logprob":-3.518221378326416},"topTokens":null,"textRange":{"start":91,"end":101}},{"generatedToken":{"token":"!","logprob":-0.8572316765785217,"raw_logprob":-0.8572316765785217},"topTokens":null,"textRange":{"start":101,"end":102}},{"generatedToken":{"token":"<|endoftext|>","logprob":-0.002994698006659746,"raw_logprob":-0.002994698006659746},"topTokens":null,"textRange":{"start":102,"end":102}}]},"finishReason":{"reason":"endoftext"}}]} -------------------------------------------------------------------------------- /src/providers/bedrock/logs/ai21_jamba_1_5_mini_v1_0_chat_completions.json: -------------------------------------------------------------------------------- 1 | {"id":"chatcmpl-7c961099b52f4798bede41e6bb087b37","choices":[{"index":0,"message":{"role":"assistant","content":" Why don't skeletons fight each other?\n\nThey don't have the guts.","tool_calls":null},"finish_reason":"stop"}],"usage":{"prompt_tokens":15,"completion_tokens":21,"total_tokens":36},"meta":{"requestDurationMillis":173},"model":"jamba-1.5-mini"} -------------------------------------------------------------------------------- /src/providers/bedrock/logs/anthropic_claude_3_haiku_20240307_v1_0_chat_completion.json: -------------------------------------------------------------------------------- 1 | {"id":"msg_bdrk_01EUKDUQgXoPmpVJR79cggfA","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[{"type":"text","text":"Here's a short joke for you:\n\nWhy can't a bicycle stand up on its own? 
Because it's two-tired!"}],"stop_reason":"end_turn","stop_sequence":null,"usage":{"input_tokens":12,"output_tokens":30}} -------------------------------------------------------------------------------- /src/providers/bedrock/logs/us_amazon_nova_lite_v1_0_chat_completion.json: -------------------------------------------------------------------------------- 1 | {"output":{"message":{"content":[{"text":"Paris"}],"role":"assistant"}},"stopReason":"end_turn","usage":{"inputTokens":12,"outputTokens":1,"totalTokens":13,"cacheReadInputTokenCount":0,"cacheWriteInputTokenCount":0}} -------------------------------------------------------------------------------- /src/providers/bedrock/mod.rs: -------------------------------------------------------------------------------- 1 | mod models; 2 | 3 | mod provider; 4 | #[cfg(test)] 5 | mod test; 6 | 7 | pub use provider::BedrockProvider; 8 | -------------------------------------------------------------------------------- /src/providers/bedrock/models.rs: -------------------------------------------------------------------------------- 1 | use crate::config::constants::{ 2 | default_embedding_dimension, default_embedding_normalize, default_max_tokens, 3 | }; 4 | use crate::models::chat::{ChatCompletion, ChatCompletionChoice, ChatCompletionRequest}; 5 | use crate::models::completion::{ 6 | CompletionChoice, CompletionRequest, CompletionResponse, LogProbs, 7 | }; 8 | use crate::models::content::{ChatCompletionMessage, ChatMessageContent}; 9 | use crate::models::embeddings::{ 10 | Embedding, Embeddings, EmbeddingsInput, EmbeddingsRequest, EmbeddingsResponse, 11 | }; 12 | use crate::models::usage::{EmbeddingUsage, Usage}; 13 | use serde::{Deserialize, Serialize}; 14 | 15 | /** 16 | * Titan models 17 | */ 18 | 19 | #[derive(Serialize, Deserialize, Clone)] 20 | pub struct TitanMessageContent { 21 | pub text: String, 22 | } 23 | 24 | #[derive(Serialize, Deserialize, Clone)] 25 | pub struct TitanMessage { 26 | pub role: String, 27 | pub content: Vec, 28 | } 29 | 30 | #[derive(Serialize, Deserialize, Clone)] 31 | pub struct TitanInferenceConfig { 32 | pub max_new_tokens: u32, 33 | } 34 | 35 | #[derive(Serialize, Deserialize, Clone)] 36 | pub struct TitanChatCompletionRequest { 37 | #[serde(rename = "inferenceConfig")] 38 | pub inference_config: TitanInferenceConfig, 39 | pub messages: Vec, 40 | } 41 | 42 | #[derive(Deserialize, Serialize)] 43 | pub struct TitanChatCompletionResponse { 44 | pub output: TitanOutput, 45 | #[serde(rename = "stopReason")] 46 | pub stop_reason: String, 47 | pub usage: TitanUsage, 48 | } 49 | 50 | #[derive(Deserialize, Serialize)] 51 | pub struct TitanOutput { 52 | pub message: TitanMessage, 53 | } 54 | 55 | #[derive(Deserialize, Serialize)] 56 | pub struct TitanUsage { 57 | #[serde(rename = "inputTokens")] 58 | pub input_tokens: u32, 59 | #[serde(rename = "outputTokens")] 60 | pub output_tokens: u32, 61 | #[serde(rename = "totalTokens")] 62 | pub total_tokens: u32, 63 | } 64 | 65 | impl From for TitanChatCompletionRequest { 66 | fn from(request: ChatCompletionRequest) -> Self { 67 | let messages = request 68 | .messages 69 | .into_iter() 70 | .map(|msg| { 71 | let content_text = match msg.content { 72 | Some(ChatMessageContent::String(text)) => text, 73 | Some(ChatMessageContent::Array(parts)) => parts 74 | .into_iter() 75 | .filter(|part| part.r#type == "text") 76 | .map(|part| part.text) 77 | .collect::>() 78 | .join(" "), 79 | None => String::new(), 80 | }; 81 | 82 | TitanMessage { 83 | role: msg.role, 84 | content: 
                    vec![TitanMessageContent { text: content_text }],
                }
            })
            .collect();

        TitanChatCompletionRequest {
            inference_config: TitanInferenceConfig {
                max_new_tokens: request.max_tokens.unwrap_or(default_max_tokens()),
            },
            messages,
        }
    }
}

impl From<TitanChatCompletionResponse> for ChatCompletion {
    fn from(response: TitanChatCompletionResponse) -> Self {
        let message = ChatCompletionMessage {
            role: response.output.message.role,
            content: Some(ChatMessageContent::String(
                response
                    .output
                    .message
                    .content
                    .into_iter()
                    .map(|c| c.text)
                    .collect::<Vec<String>>()
                    .join(" "),
            )),
            name: None,
            tool_calls: None,
            refusal: None, // not returned by Titan as of 1/04/2025
        };

        ChatCompletion {
            id: "".to_string(), // response.id is private in the AWS SDK, so it cannot be accessed here
            object: None,
            created: None,
            model: "".to_string(),
            choices: vec![ChatCompletionChoice {
                index: 0,
                message,
                finish_reason: Some(response.stop_reason),
                logprobs: None,
            }],
            usage: Usage {
                prompt_tokens: response.usage.input_tokens,
                completion_tokens: response.usage.output_tokens,
                total_tokens: response.usage.total_tokens,
                completion_tokens_details: None,
                prompt_tokens_details: None,
            },
            system_fingerprint: None,
        }
    }
}

#[derive(Debug, Serialize, Deserialize)]
pub struct TitanEmbeddingRequest {
    #[serde(rename = "inputText")]
    pub input_text: String,
    pub dimensions: u32,
    pub normalize: bool,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct TitanEmbeddingResponse {
    pub embedding: Vec<f32>,
    #[serde(rename = "embeddingsByType")]
    pub embeddings_by_type: EmbeddingsByType,
    #[serde(rename = "inputTextTokenCount")]
    pub input_text_token_count: u32,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingsByType {
    pub float: Vec<f32>,
}

impl From<EmbeddingsRequest> for TitanEmbeddingRequest {
    fn from(request: EmbeddingsRequest) -> Self {
        let input_text = match request.input {
            EmbeddingsInput::Single(text) => text,
            EmbeddingsInput::Multiple(texts) => {
                texts.first().map(|s| s.to_string()).unwrap_or_default()
            }
            EmbeddingsInput::SingleTokenIds(token_ids) => token_ids
                .iter()
                .map(|id| id.to_string())
                .collect::<Vec<String>>()
                .join(" "),
            EmbeddingsInput::MultipleTokenIds(all_token_ids) => all_token_ids
                .first()
                .map(|token_ids| {
                    token_ids
                        .iter()
                        .map(|id| id.to_string())
                        .collect::<Vec<String>>()
                        .join(" ")
                })
                .unwrap_or_default(),
        };

        TitanEmbeddingRequest {
            input_text,
            dimensions: default_embedding_dimension(),
            normalize: default_embedding_normalize(),
        }
    }
}

impl From<TitanEmbeddingResponse> for EmbeddingsResponse {
    fn from(response: TitanEmbeddingResponse) -> Self {
        EmbeddingsResponse {
            object: "list".to_string(),
            data: vec![Embeddings {
                object: "embedding".to_string(),
                embedding: Embedding::Float(response.embedding),
                index: 0,
            }],
            model: "".to_string(),
            usage: EmbeddingUsage {
                prompt_tokens: Some(response.input_text_token_count),
                total_tokens: Some(response.input_text_token_count),
            },
        }
    }
}

/*
Ai21 models
*/

#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct Ai21Message {
    pub role: String,
    pub content: String,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct Ai21ChatCompletionRequest {
    pub messages: Vec<Ai21Message>,
    pub max_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
}

#[derive(Deserialize, Serialize, Clone)]
pub struct Ai21ChatCompletionResponse {
    pub id: String,
    pub choices: Vec<Ai21Choice>,
    pub model: String,
    pub usage: Ai21Usage,
    pub meta: Ai21Meta,
}

#[derive(Deserialize, Serialize, Clone)]
pub struct Ai21Choice {
    pub finish_reason: String,
    pub index: u32,
    pub message: Ai21Message,
}

#[derive(Deserialize, Serialize, Clone)]
pub struct Ai21Meta {
    #[serde(rename = "requestDurationMillis")]
    pub request_duration_millis: u64,
}

#[derive(Deserialize, Serialize, Clone)]
pub struct Ai21Usage {
    pub completion_tokens: u32,
    pub prompt_tokens: u32,
    pub total_tokens: u32,
}

impl From<ChatCompletionRequest> for Ai21ChatCompletionRequest {
    fn from(request: ChatCompletionRequest) -> Self {
        let messages = request
            .messages
            .into_iter()
            .map(|msg| {
                let content = match msg.content {
                    Some(ChatMessageContent::String(text)) => text,
                    Some(ChatMessageContent::Array(parts)) => parts
                        .into_iter()
                        .filter(|part| part.r#type == "text")
                        .map(|part| part.text)
                        .collect::<Vec<String>>()
                        .join(" "),
                    None => String::new(),
                };

                Ai21Message {
                    role: msg.role,
                    content,
                }
            })
            .collect();

        Ai21ChatCompletionRequest {
            messages,
            max_tokens: request.max_tokens.unwrap_or(default_max_tokens()),
            temperature: request.temperature,
            top_p: request.top_p,
        }
    }
}

impl From<Ai21ChatCompletionResponse> for ChatCompletion {
    fn from(response: Ai21ChatCompletionResponse) -> Self {
        ChatCompletion {
            id: response.id,
            object: None,
            created: None,
            model: response.model,
            choices: response
                .choices
                .into_iter()
                .map(|choice| ChatCompletionChoice {
                    index: choice.index,
                    message: ChatCompletionMessage {
                        role: choice.message.role,
                        content: Some(ChatMessageContent::String(choice.message.content)),
                        name: None,
                        tool_calls: None,
                        refusal: None, // not returned by AI21 as of 1/04/2025
                    },
                    finish_reason: Some(choice.finish_reason),
                    logprobs: None,
                })
                .collect(),
            usage: Usage {
                prompt_tokens: response.usage.prompt_tokens,
                completion_tokens: response.usage.completion_tokens,
                total_tokens: response.usage.total_tokens,
                completion_tokens_details: None,
                prompt_tokens_details: None,
            },
            system_fingerprint: None,
        }
    }
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Ai21CompletionsRequest {
    pub prompt: String,
    #[serde(rename = "maxTokens")]
    pub max_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(rename = "topP", skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(rename = "stopSequences")]
    pub stop_sequences: Vec<String>,
    #[serde(rename = "countPenalty")]
    pub count_penalty: PenaltyConfig,
    #[serde(rename = "presencePenalty")]
    pub presence_penalty: PenaltyConfig,
    #[serde(rename = "frequencyPenalty")]
    pub frequency_penalty: PenaltyConfig,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PenaltyConfig {
    pub scale: i32,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Ai21CompletionsResponse {
    pub id: i64,
    pub prompt: Ai21Prompt,
    pub completions: Vec<Ai21CompletionWrapper>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Ai21CompletionWrapper {
    pub data: Ai21CompletionData,
    #[serde(rename = "finishReason")]
    pub finish_reason: Ai21FinishReason,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Ai21Prompt {
    pub text: String,
    pub tokens: Vec<Ai21Token>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Ai21CompletionData {
    pub text: String,
    pub tokens: Vec<Ai21Token>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Ai21Token {
    #[serde(rename = "generatedToken")]
    pub generated_token: Option<GeneratedToken>,
    #[serde(rename = "textRange")]
    pub text_range: TextRange,
    #[serde(rename = "topTokens")]
    pub top_tokens: Option<Vec<TopToken>>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GeneratedToken {
    pub token: String,
    #[serde(rename = "logprob")]
    pub log_prob: f64,
    #[serde(rename = "raw_logprob")]
    pub raw_log_prob: f64,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TextRange {
    pub start: i32,
    pub end: i32,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TopToken {
    pub token: String,
    pub logprob: f64,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Ai21FinishReason {
    pub reason: String,
}

impl From<CompletionRequest> for Ai21CompletionsRequest {
    fn from(request: CompletionRequest) -> Self {
        Self {
            prompt: request.prompt,
            max_tokens: request.max_tokens.unwrap_or(default_max_tokens()),
            temperature: request.temperature,
            top_p: request.top_p,
            stop_sequences: request.stop.unwrap_or_default(),
            count_penalty: PenaltyConfig { scale: 0 },
            presence_penalty: PenaltyConfig {
                scale: if let Some(penalty) = request.presence_penalty {
                    penalty as i32
                } else {
                    0
                },
            },
            frequency_penalty: PenaltyConfig {
                scale: if let Some(penalty) = request.frequency_penalty {
                    penalty as i32
                } else {
                    0
                },
            },
        }
    }
}

impl From<Ai21CompletionsResponse> for CompletionResponse {
    fn from(response: Ai21CompletionsResponse) -> Self {
        let total_prompt_tokens = response.prompt.tokens.len() as u32;
        let total_completion_tokens: u32 = response
            .completions
            .iter()
            .map(|c| c.data.tokens.len() as u32)
            .sum();

        CompletionResponse {
            id: response.id.to_string(),
            object: "".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: "".to_string(),
            choices: response
                .completions
                .into_iter()
                .enumerate()
                .map(|(index, completion)| CompletionChoice {
                    text: completion.data.text,
                    index: index as u32,
                    logprobs: Some(LogProbs {
                        tokens: completion
                            .data
                            .tokens
                            .iter()
                            .filter_map(|t| t.generated_token.as_ref().map(|gt| gt.token.clone()))
                            .collect(),
                        token_logprobs: completion
                            .data
                            .tokens
                            .iter()
                            .filter_map(|t| t.generated_token.as_ref().map(|gt| gt.log_prob as f32))
                            .collect(),
                        top_logprobs: completion
                            .data
                            .tokens
                            .iter()
                            .map(|t| {
                                t.top_tokens
                                    .clone()
                                    .map(|tt| {
                                        tt.into_iter()
                                            .map(|top| (top.token, top.logprob as f32))
                                            .collect()
                                    })
                                    .unwrap_or_default()
                            })
                            .collect(),
                        text_offset: completion
                            .data
                            .tokens
                            .iter()
                            .map(|t| t.text_range.start as usize)
                            .collect(),
                    }),
                    finish_reason: Some(completion.finish_reason.reason),
                })
                .collect(),
            usage: Usage {
                prompt_tokens: total_prompt_tokens,
                completion_tokens: total_completion_tokens,
                total_tokens: total_prompt_tokens + total_completion_tokens,
                completion_tokens_details: None,
                prompt_tokens_details: None,
            },
        }
    }
}
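// Hypothetical shape check (not in the original file): Bedrock's Titan
// endpoints expect camelCase keys, which come entirely from the
// #[serde(rename = "...")] attributes above. Field values here are
// illustrative; the real defaults come from default_embedding_dimension()
// and default_embedding_normalize().
#[cfg(test)]
mod titan_shape_tests {
    use super::*;

    #[test]
    fn titan_embedding_request_serializes_to_camel_case() {
        let request = TitanEmbeddingRequest {
            input_text: "hello".to_string(),
            dimensions: 1024,
            normalize: true,
        };

        let json = serde_json::to_value(&request).unwrap();

        // `inputText` is renamed; `dimensions` and `normalize` keep their names.
        assert_eq!(json["inputText"], "hello");
        assert_eq!(json["dimensions"], 1024);
        assert_eq!(json["normalize"], true);
    }
}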
--------------------------------------------------------------------------------
/src/providers/bedrock/provider.rs:
--------------------------------------------------------------------------------
use axum::async_trait;
use axum::http::StatusCode;
use std::error::Error;

use aws_sdk_bedrockruntime::Client as BedrockRuntimeClient;

use crate::config::models::{ModelConfig, Provider as ProviderConfig};
use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse};
use crate::models::completion::{CompletionRequest, CompletionResponse};
use crate::models::embeddings::{EmbeddingsRequest, EmbeddingsResponse};
use crate::providers::provider::Provider;

use crate::providers::anthropic::{
    AnthropicChatCompletionRequest, AnthropicChatCompletionResponse,
};
use crate::providers::bedrock::models::{
    Ai21ChatCompletionRequest, Ai21ChatCompletionResponse, Ai21CompletionsRequest,
    Ai21CompletionsResponse, TitanChatCompletionRequest, TitanChatCompletionResponse,
    TitanEmbeddingRequest, TitanEmbeddingResponse,
};
use aws_sdk_bedrockruntime::primitives::Blob;

struct AI21Implementation;
struct TitanImplementation;
struct AnthropicImplementation;

pub struct BedrockProvider {
    pub(crate) config: ProviderConfig,
}

pub trait ClientProvider {
    async fn create_client(&self) -> Result<BedrockRuntimeClient, Box<dyn Error>>;
}

#[cfg(not(test))]
impl ClientProvider for BedrockProvider {
    async fn create_client(&self) -> Result<BedrockRuntimeClient, Box<dyn Error>> {
        use aws_config::BehaviorVersion;
        use aws_config::Region;
        use aws_credential_types::Credentials;

        let region = self.config.params.get("region").unwrap().clone();
        let access_key_id = self.config.params.get("AWS_ACCESS_KEY_ID").unwrap().clone();
        let secret_access_key = self
            .config
            .params
            .get("AWS_SECRET_ACCESS_KEY")
            .unwrap()
            .clone();
        let session_token = self.config.params.get("AWS_SESSION_TOKEN").cloned();

        let credentials = Credentials::from_keys(access_key_id, secret_access_key, session_token);

        let sdk_config = aws_config::defaults(BehaviorVersion::latest())
            .region(Region::new(region))
            .credentials_provider(credentials)
            .load()
            .await;

        Ok(BedrockRuntimeClient::new(&sdk_config))
    }
}

impl BedrockProvider {
    fn get_provider_implementation(
        &self,
        model_config: &ModelConfig,
    ) -> Box<dyn BedrockModelImplementation> {
        let bedrock_model_provider = model_config.params.get("model_provider").unwrap();

        let provider_implementation: Box<dyn BedrockModelImplementation> =
            match bedrock_model_provider.as_str() {
                "ai21" => Box::new(AI21Implementation),
                "titan" => Box::new(TitanImplementation),
                "anthropic" => Box::new(AnthropicImplementation),
                _ => panic!("Invalid bedrock model provider"),
            };

        provider_implementation
    }
}

#[async_trait]
impl Provider for BedrockProvider {
    fn new(config: &ProviderConfig) -> Self {
        Self {
            config: config.clone(),
        }
    }

    fn key(&self) -> String {
        self.config.key.clone()
    }

    fn r#type(&self) -> String {
        "bedrock".to_string()
    }

    async fn chat_completions(
        &self,
        payload: ChatCompletionRequest,
        model_config: &ModelConfig,
    ) -> Result<ChatCompletionResponse, StatusCode> {
        let client = self.create_client().await.map_err(|e| {
            eprintln!("Failed to create Bedrock client: {}", e);
            StatusCode::INTERNAL_SERVER_ERROR
        })?;

        // Transform the model name into a full Bedrock model ID with a provider prefix
        let model_provider = model_config.params.get("model_provider").unwrap();
        let inference_profile_id = self.config.params.get("inference_profile_id");
        let mut transformed_payload = payload;
        let model_version = model_config
            .params
            .get("model_version")
            .map_or("v1:0", |s| &**s);
        transformed_payload.model = if let Some(profile_id) = inference_profile_id {
            format!(
                "{}.{}.{}-{}",
                profile_id, model_provider, transformed_payload.model, model_version
            )
        } else {
            format!(
                "{}.{}-{}",
                model_provider, transformed_payload.model, model_version
            )
        };

        self.get_provider_implementation(model_config)
            .chat_completion(&client, transformed_payload)
            .await
    }

    async fn completions(
        &self,
        payload: CompletionRequest,
        model_config: &ModelConfig,
    ) -> Result<CompletionResponse, StatusCode> {
        let client = self.create_client().await.map_err(|e| {
            eprintln!("Failed to create Bedrock client: {}", e);
            StatusCode::INTERNAL_SERVER_ERROR
        })?;

        self.get_provider_implementation(model_config)
            .completion(&client, payload)
            .await
    }

    async fn embeddings(
        &self,
        payload: EmbeddingsRequest,
        model_config: &ModelConfig,
    ) -> Result<EmbeddingsResponse, StatusCode> {
        let client = self.create_client().await.map_err(|e| {
            eprintln!("Failed to create Bedrock client: {}", e);
            StatusCode::INTERNAL_SERVER_ERROR
        })?;

        self.get_provider_implementation(model_config)
            .embedding(&client, payload)
            .await
    }
}
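// Hypothetical sketch (not in the original file): the transformation in
// `chat_completions` above composes Bedrock model IDs as
// `{model_provider}.{model}-{model_version}`, optionally prefixed with an
// inference profile. This pins the two shapes with illustrative values.
#[cfg(test)]
mod model_id_shape_tests {
    #[test]
    fn composes_bedrock_model_ids() {
        let (model_provider, model, model_version) =
            ("anthropic", "claude-3-haiku-20240307", "v1:0");

        // Without an inference profile: "{provider}.{model}-{version}"
        assert_eq!(
            format!("{}.{}-{}", model_provider, model, model_version),
            "anthropic.claude-3-haiku-20240307-v1:0"
        );

        // With the cross-region profile "us": "{profile}.{provider}.{model}-{version}"
        let profile_id = "us";
        assert_eq!(
            format!("{}.{}.{}-{}", profile_id, model_provider, model, model_version),
            "us.anthropic.claude-3-haiku-20240307-v1:0"
        );
    }
}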
/**
BEDROCK IMPLEMENTATION TEMPLATE - SERVES AS THE LAYOUT FOR OTHER IMPLEMENTATIONS
*/

#[async_trait]
trait BedrockModelImplementation: Send + Sync {
    async fn chat_completion(
        &self,
        client: &BedrockRuntimeClient,
        payload: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, StatusCode>;

    async fn completion(
        &self,
        _client: &BedrockRuntimeClient,
        _payload: CompletionRequest,
    ) -> Result<CompletionResponse, StatusCode> {
        Err(StatusCode::NOT_IMPLEMENTED)
    }

    async fn embedding(
        &self,
        _client: &BedrockRuntimeClient,
        _payload: EmbeddingsRequest,
    ) -> Result<EmbeddingsResponse, StatusCode> {
        Err(StatusCode::NOT_IMPLEMENTED)
    }
}

trait BedrockRequestHandler {
    async fn handle_bedrock_request<T, U>(
        &self,
        client: &BedrockRuntimeClient,
        model_id: &str,
        request: T,
        error_context: &str,
    ) -> Result<U, StatusCode>
    where
        T: serde::Serialize + std::marker::Send,
        U: for<'de> serde::Deserialize<'de>,
    {
        // Serialize the request
        let request_json = serde_json::to_vec(&request).map_err(|e| {
            eprintln!("Failed to serialize {}: {}", error_context, e);
            StatusCode::INTERNAL_SERVER_ERROR
        })?;

        // Make the API call
        let response = client
            .invoke_model()
            .body(Blob::new(request_json))
            .model_id(model_id)
            .send()
            .await
            .map_err(|e| {
                eprintln!("Bedrock API error for {}: {:?}", error_context, e);
                eprintln!(
                    "Error details - Source: {}, Raw error: {:?}",
                    e.source().unwrap_or(&e),
                    e.raw_response()
                );
                StatusCode::INTERNAL_SERVER_ERROR
            })?;

        // Deserialize the response
        serde_json::from_slice(&response.body.into_inner()).map_err(|e| {
            eprintln!("Failed to deserialize {} response: {}", error_context, e);
            StatusCode::INTERNAL_SERVER_ERROR
        })
    }
}

impl BedrockRequestHandler for AI21Implementation {}
impl BedrockRequestHandler for TitanImplementation {}
impl BedrockRequestHandler for AnthropicImplementation {}
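// Hypothetical sketch of how a new Bedrock model family would plug into the
// template above. `CohereImplementation` and the `CohereChat*` types are
// illustrative and do not exist in this crate; a real addition would also
// need a `"cohere" => Box::new(CohereImplementation)` arm in
// `get_provider_implementation`.
//
//     struct CohereImplementation;
//
//     impl BedrockRequestHandler for CohereImplementation {}
//
//     #[async_trait]
//     impl BedrockModelImplementation for CohereImplementation {
//         async fn chat_completion(
//             &self,
//             client: &BedrockRuntimeClient,
//             payload: ChatCompletionRequest,
//         ) -> Result<ChatCompletionResponse, StatusCode> {
//             let request = CohereChatRequest::from(payload.clone());
//             let response: CohereChatResponse = self
//                 .handle_bedrock_request(client, &payload.model, request, "Cohere chat completion")
//                 .await?;
//             Ok(ChatCompletionResponse::NonStream(response.into()))
//         }
//     }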
/**
AI21 IMPLEMENTATION
*/

#[async_trait]
impl BedrockModelImplementation for AI21Implementation {
    async fn chat_completion(
        &self,
        client: &BedrockRuntimeClient,
        payload: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, StatusCode> {
        let ai21_request = Ai21ChatCompletionRequest::from(payload.clone());
        let ai21_response: Ai21ChatCompletionResponse = self
            .handle_bedrock_request(client, &payload.model, ai21_request, "AI21 chat completion")
            .await?;

        Ok(ChatCompletionResponse::NonStream(ai21_response.into()))
    }

    async fn completion(
        &self,
        client: &BedrockRuntimeClient,
        payload: CompletionRequest,
    ) -> Result<CompletionResponse, StatusCode> {
        // Bedrock's legacy AI21 models support text completions, similar to OpenAI's legacy API
        let ai21_request = Ai21CompletionsRequest::from(payload.clone());
        let ai21_response: Ai21CompletionsResponse = self
            .handle_bedrock_request(client, &payload.model, ai21_request, "AI21 completion")
            .await?;

        Ok(CompletionResponse::from(ai21_response))
    }
}

/**
TITAN IMPLEMENTATION
*/

#[async_trait]
impl BedrockModelImplementation for TitanImplementation {
    async fn chat_completion(
        &self,
        client: &BedrockRuntimeClient,
        payload: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, StatusCode> {
        let titan_request = TitanChatCompletionRequest::from(payload.clone());
        let titan_response: TitanChatCompletionResponse = self
            .handle_bedrock_request(
                client,
                &payload.model,
                titan_request,
                "Titan chat completion",
            )
            .await?;

        Ok(ChatCompletionResponse::NonStream(titan_response.into()))
    }

    async fn embedding(
        &self,
        client: &BedrockRuntimeClient,
        payload: EmbeddingsRequest,
    ) -> Result<EmbeddingsResponse, StatusCode> {
        let titan_request = TitanEmbeddingRequest::from(payload.clone());
        let titan_response: TitanEmbeddingResponse = self
            .handle_bedrock_request(client, &payload.model, titan_request, "Titan embedding")
            .await?;

        Ok(EmbeddingsResponse::from(titan_response))
    }
}

/**
ANTHROPIC IMPLEMENTATION
*/

#[async_trait]
impl BedrockModelImplementation for AnthropicImplementation {
    async fn chat_completion(
        &self,
        client: &BedrockRuntimeClient,
        payload: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, StatusCode> {
        let anthropic_request = AnthropicChatCompletionRequest::from(payload.clone());

        // Convert to Value for Bedrock-specific modifications
        let mut request_value = serde_json::to_value(&anthropic_request).map_err(|e| {
            eprintln!("Failed to serialize Anthropic request: {}", e);
            StatusCode::INTERNAL_SERVER_ERROR
        })?;

        if let serde_json::Value::Object(ref mut map) = request_value {
            map.remove("model");
            map.insert(
                "anthropic_version".to_string(),
                serde_json::Value::String("bedrock-2023-05-31".to_string()),
            );
        }

        let anthropic_response: AnthropicChatCompletionResponse = self
            .handle_bedrock_request(
                client,
                &payload.model,
                request_value,
                "Anthropic chat completion",
            )
            .await?;

        Ok(ChatCompletionResponse::NonStream(anthropic_response.into()))
    }
}
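// Hypothetical shape check (not in the original file): the mutation in
// `AnthropicImplementation` above should yield a body with no "model" field
// (the model is addressed via the invoke URL instead) and a pinned
// "anthropic_version". The payload here is illustrative.
#[cfg(test)]
mod anthropic_body_shape_tests {
    #[test]
    fn bedrock_anthropic_body_drops_model_and_pins_version() {
        let mut body = serde_json::json!({
            "model": "claude-3-haiku-20240307",
            "max_tokens": 256,
            "messages": [{ "role": "user", "content": "Hello" }]
        });

        if let serde_json::Value::Object(ref mut map) = body {
            map.remove("model");
            map.insert(
                "anthropic_version".to_string(),
                serde_json::Value::String("bedrock-2023-05-31".to_string()),
            );
        }

        assert!(body.get("model").is_none());
        assert_eq!(body["anthropic_version"], "bedrock-2023-05-31");
    }
}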
--------------------------------------------------------------------------------
/src/providers/bedrock/test.rs:
--------------------------------------------------------------------------------
#[cfg(test)]
impl crate::providers::bedrock::provider::ClientProvider
    for crate::providers::bedrock::BedrockProvider
{
    // COMMENT OUT THIS BLOCK TO RUN AGAINST ACTUAL AWS SERVICES,
    // OR CHANGE YOUR ENVIRONMENT FROM TEST TO PROD
    async fn create_client(
        &self,
    ) -> Result<aws_sdk_bedrockruntime::Client, Box<dyn std::error::Error>> {
        let handler = self
            .config
            .params
            .get("test_response_handler")
            .map(|s| s.as_str());
        let mock_responses = match handler {
            Some("anthropic_chat_completion") => vec![dummy_anthropic_chat_completion_response()],
            Some("ai21_chat_completion") => vec![dummy_ai21_chat_completion_response()],
            Some("ai21_completion") => vec![dummy_ai21_completion_response()],
            Some("titan_chat_completion") => vec![dummy_titan_chat_completion_response()],
            Some("titan_embedding") => vec![dummy_titan_embedding_response()],
            _ => vec![],
        };
        let test_client = create_test_bedrock_client(mock_responses).await;
        Ok(test_client)
    }
}

#[cfg(test)]
fn get_test_provider_config(
    region: &str,
    test_response_handler: &str,
) -> crate::config::models::Provider {
    use std::collections::HashMap;

    let mut params = HashMap::new();
    params.insert("region".to_string(), region.to_string());

    let aws_access_key_id = std::env::var("AWS_ACCESS_KEY_ID").unwrap_or("test_id".to_string());
    let aws_secret_access_key =
        std::env::var("AWS_SECRET_ACCESS_KEY").unwrap_or("test_key".to_string());

    params.insert("AWS_ACCESS_KEY_ID".to_string(), aws_access_key_id);
    params.insert("AWS_SECRET_ACCESS_KEY".to_string(), aws_secret_access_key);

    params.insert(
        "test_response_handler".to_string(),
        test_response_handler.to_string(),
    );

    crate::config::models::Provider {
        key: "test_key".to_string(),
        r#type: "".to_string(),
        api_key: "".to_string(),
        params,
    }
}

#[cfg(test)]
fn get_test_model_config(
    model_type: &str,
    provider_type: &str,
) -> crate::config::models::ModelConfig {
    use std::collections::HashMap;

    let mut params = HashMap::new();
    params.insert("model_provider".to_string(), provider_type.to_string());

    crate::config::models::ModelConfig {
        key: "test-model".to_string(),
        r#type: model_type.to_string(),
        provider: "bedrock".to_string(),
        params,
    }
}

#[cfg(test)]
mod anthropic_tests {
    use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse};
    use crate::models::content::{ChatCompletionMessage, ChatMessageContent};
    use crate::providers::bedrock::test::{get_test_model_config, get_test_provider_config};
    use crate::providers::bedrock::BedrockProvider;
    use crate::providers::provider::Provider;

    #[test]
    fn test_bedrock_provider_new() {
        let config = get_test_provider_config("us-east-1", "");
        let provider = BedrockProvider::new(&config);

        assert_eq!(provider.key(), "test_key");
        assert_eq!(provider.r#type(), "bedrock");
    }

    #[tokio::test]
    async fn test_bedrock_provider_chat_completions() {
        let config = get_test_provider_config("us-east-2", "anthropic_chat_completion");
        let provider = BedrockProvider::new(&config);

        let model_config =
            get_test_model_config("us.anthropic.claude-3-haiku-20240307-v1:0", "anthropic");

        let payload = ChatCompletionRequest {
            model: "us.anthropic.claude-3-haiku-20240307-v1:0".to_string(),
            messages: vec![ChatCompletionMessage {
                role: "user".to_string(),
                content: Some(ChatMessageContent::String(
                    "Tell me a short joke".to_string(),
                )),
                name: None,
                tool_calls: None,
                refusal: None,
            }],
            temperature: None,
            top_p: None,
            n: None,
            stream: None,
            stop: None,
            max_tokens: None,
            max_completion_tokens: None,
            parallel_tool_calls: None,
            presence_penalty: None,
            frequency_penalty: None,
            logit_bias: None,
            tool_choice: None,
            tools: None,
            user: None,
            logprobs: None,
            top_logprobs: None,
            response_format: None,
        };

        let result = provider.chat_completions(payload, &model_config).await;

        assert!(result.is_ok(), "Chat completion failed: {:?}", result.err());

        if let Ok(ChatCompletionResponse::NonStream(completion)) = result {
            assert!(!completion.choices.is_empty(), "Expected non-empty choices");
            assert!(
                completion.usage.total_tokens > 0,
                "Expected non-zero token usage"
            );

            let first_choice = &completion.choices[0];
            assert!(
                first_choice.message.content.is_some(),
                "Expected message content"
            );
            assert_eq!(
                first_choice.message.role, "assistant",
                "Expected assistant role"
            );
        }
    }
}

#[cfg(test)]
mod titan_tests {
    use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse};
    use crate::models::content::{ChatCompletionMessage, ChatMessageContent};
    use crate::models::embeddings::EmbeddingsInput::Single;
    use crate::models::embeddings::{Embedding, EmbeddingsRequest};
    use crate::providers::bedrock::test::{get_test_model_config, get_test_provider_config};
    use crate::providers::bedrock::BedrockProvider;
    use crate::providers::provider::Provider;

    #[test]
    fn test_titan_provider_new() {
        let config = get_test_provider_config("us-east-2", "");
        let provider = BedrockProvider::new(&config);

        assert_eq!(provider.key(), "test_key");
        assert_eq!(provider.r#type(), "bedrock");
    }

    #[tokio::test]
    async fn test_embeddings() {
        let config = get_test_provider_config("us-east-2", "titan_embedding");
        let provider = BedrockProvider::new(&config);
        let model_config = get_test_model_config("amazon.titan-embed-text-v2:0", "titan");

        let payload = EmbeddingsRequest {
            model: "amazon.titan-embed-text-v2:0".to_string(),
            user: None,
            input: Single("this is where you place your input text".to_string()),
            encoding_format: None,
        };

        let result = provider.embeddings(payload, &model_config).await;
        assert!(
            result.is_ok(),
            "Titan Embeddings generation failed: {:?}",
            result.err()
        );
        let response = result.unwrap();
        assert!(
            !response.data.is_empty(),
            "Expected non-empty embeddings data"
        );
        assert!(
            matches!(&response.data[0].embedding, Embedding::Float(vec) if !vec.is_empty()),
            "Expected non-empty Float embedding vector",
        );
        assert!(
            response.usage.prompt_tokens > Some(0),
            "Expected non-zero token usage"
        );
    }

    #[tokio::test]
    async fn test_chat_completions() {
        let config = get_test_provider_config("us-east-2", "titan_chat_completion");
        let provider = BedrockProvider::new(&config);

        // The model type string here is not used for routing; only the
        // "model_provider" param matters to BedrockProvider.
        let model_config = get_test_model_config("amazon.titan-embed-text-v2:0", "titan");

        let payload = ChatCompletionRequest {
            model: "us.amazon.nova-lite-v1:0".to_string(),
            messages: vec![ChatCompletionMessage {
                role: "user".to_string(),
                content: Some(ChatMessageContent::String(
                    "What is the capital of France? Answer in one word.".to_string(),
                )),
                name: None,
                tool_calls: None,
                refusal: None,
            }],
            temperature: None,
            top_p: None,
            n: None,
            stream: None,
            stop: None,
            max_tokens: None,
            max_completion_tokens: None,
            parallel_tool_calls: None,
            presence_penalty: None,
            frequency_penalty: None,
            logit_bias: None,
            tool_choice: None,
            tools: None,
            user: None,
            logprobs: None,
            top_logprobs: None,
            response_format: None,
        };

        let result = provider.chat_completions(payload, &model_config).await;
        assert!(result.is_ok(), "Chat completion failed: {:?}", result.err());

        if let Ok(ChatCompletionResponse::NonStream(completion)) = result {
            assert!(!completion.choices.is_empty(), "Expected non-empty choices");
            assert!(
                completion.usage.total_tokens > 0,
                "Expected non-zero token usage"
            );

            let first_choice = &completion.choices[0];
            assert!(
                first_choice.message.content.is_some(),
                "Expected message content"
            );
            assert_eq!(
                first_choice.message.role, "assistant",
                "Expected assistant role"
            );
        }
    }
}

#[cfg(test)]
mod ai21_tests {
    use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse};
    use crate::models::completion::CompletionRequest;
    use crate::models::content::{ChatCompletionMessage, ChatMessageContent};
    use crate::providers::bedrock::test::{get_test_model_config, get_test_provider_config};
    use crate::providers::bedrock::BedrockProvider;
    use crate::providers::provider::Provider;

    #[test]
    fn test_ai21_provider_new() {
        let config = get_test_provider_config("us-east-1", "");
        let provider = BedrockProvider::new(&config);

        assert_eq!(provider.key(), "test_key");
        assert_eq!(provider.r#type(), "bedrock");
    }

    #[tokio::test]
    async fn test_ai21_provider_completions() {
        let config = get_test_provider_config("us-east-1", "ai21_completion");
        let provider = BedrockProvider::new(&config);

        let model_config = get_test_model_config("ai21.j2-mid-v1", "ai21");

        let payload = CompletionRequest {
            model: "ai21.j2-mid-v1".to_string(),
            prompt: "Tell me a joke".to_string(),
            suffix: None,
            max_tokens: Some(400),
            temperature: None,
            top_p: None,
            n: None,
            stream: None,
            logprobs: None,
            echo: None,
            stop: None,
            presence_penalty: None,
            frequency_penalty: None,
            best_of: None,
            logit_bias: None,
            user: None,
        };

        let result = provider.completions(payload, &model_config).await;
        assert!(result.is_ok(), "Completion failed: {:?}", result.err());

        let response = result.unwrap();
        assert!(!response.choices.is_empty(), "Expected non-empty choices");
        assert!(
            response.usage.total_tokens > 0,
            "Expected non-zero token usage"
        );

        let first_choice = &response.choices[0];
        assert!(
            !first_choice.text.is_empty(),
            "Expected non-empty completion text"
        );
        assert!(
            first_choice.logprobs.is_some(),
            "Expected logprobs to be present"
        );
    }

    #[tokio::test]
    async fn test_ai21_provider_chat_completions() {
        let config = get_test_provider_config("us-east-1", "ai21_chat_completion");
        let provider = BedrockProvider::new(&config);

        let model_config = get_test_model_config("ai21.jamba-1-5-mini-v1:0", "ai21");

        let payload = ChatCompletionRequest {
            model: "ai21.jamba-1-5-mini-v1:0".to_string(),
            messages: vec![ChatCompletionMessage {
                role: "user".to_string(),
                content: Some(ChatMessageContent::String(
                    "Tell me a short joke".to_string(),
                )),
                name: None,
                tool_calls: None,
                refusal: None,
            }],
            temperature: Some(0.8),
            top_p: Some(0.8),
            n: None,
            stream: None,
            stop: None,
            max_tokens: None,
            max_completion_tokens: None,
            parallel_tool_calls: None,
            presence_penalty: None,
            frequency_penalty: None,
            logit_bias: None,
            tool_choice: None,
            tools: None,
            user: None,
            logprobs: None,
            top_logprobs: None,
            response_format: None,
        };

        let result = provider.chat_completions(payload, &model_config).await;
        assert!(result.is_ok(), "Chat completion failed: {:?}", result.err());

        if let Ok(ChatCompletionResponse::NonStream(completion)) = result {
            assert!(!completion.choices.is_empty(), "Expected non-empty choices");
            assert!(
                completion.usage.total_tokens > 0,
                "Expected non-zero token usage"
            );

            let first_choice = &completion.choices[0];
            assert!(
                first_choice.message.content.is_some(),
                "Expected message content"
            );
            assert_eq!(
                first_choice.message.role, "assistant",
                "Expected assistant role"
            );
        }
    }
}

/**

Helper functions for creating test clients and mock responses

*/
#[cfg(test)]
async fn create_test_bedrock_client(
    mock_responses: Vec<aws_smithy_runtime::client::http::test_util::ReplayEvent>,
) -> aws_sdk_bedrockruntime::Client {
    use aws_config::BehaviorVersion;
    use aws_credential_types::provider::SharedCredentialsProvider;
    use aws_credential_types::Credentials;
    use aws_smithy_runtime::client::http::test_util::StaticReplayClient;
    use aws_types::region::Region;

    let replay_client = StaticReplayClient::new(mock_responses);

    let credentials = Credentials::new("test-key", "test-secret", None, None, "testing");
    let credentials_provider = SharedCredentialsProvider::new(credentials);

    let config = aws_config::SdkConfig::builder()
        .behavior_version(BehaviorVersion::latest())
        .region(Region::new("us-east-1".to_string()))
        .credentials_provider(credentials_provider)
        .http_client(replay_client)
        .build();

    aws_sdk_bedrockruntime::Client::new(&config)
}

#[cfg(test)]
fn read_response_file(filename: &str) -> Result<String, std::io::Error> {
    use std::fs;
    use std::io::Read;
    use std::path::Path;

    let log_dir = Path::new("src/providers/bedrock/logs");
    let file_path = log_dir.join(filename);

    let mut file = fs::File::open(file_path)?;
    let mut contents = String::new();
    file.read_to_string(&mut contents)?;

    Ok(contents)
}

/**

Mock responses for the Bedrock API

*/
#[cfg(test)]
fn dummy_anthropic_chat_completion_response(
) -> aws_smithy_runtime::client::http::test_util::ReplayEvent {
    use aws_smithy_types::body::SdkBody;

    aws_smithy_runtime::client::http::test_util::ReplayEvent::new(
        http::Request::builder()
            .method("POST")
            .uri("https://bedrock-runtime.us-east-2.amazonaws.com/model/us.anthropic.claude-3-haiku-20240307-v1:0/invoke")
            .body(SdkBody::empty())
            .unwrap(),
        http::Response::builder()
            .status(200)
            .body(SdkBody::from(read_response_file("anthropic_claude_3_haiku_20240307_v1_0_chat_completion.json").unwrap()))
            .unwrap(),
    )
}

#[cfg(test)]
fn dummy_ai21_chat_completion_response() -> aws_smithy_runtime::client::http::test_util::ReplayEvent
{
    use aws_smithy_types::body::SdkBody;

    aws_smithy_runtime::client::http::test_util::ReplayEvent::new(
        http::Request::builder()
            .method("POST")
            .uri("https://bedrock-runtime.us-east-1.amazonaws.com/model/ai21.jamba-1-5-mini-v1:0/invoke")
            .body(SdkBody::empty())
            .unwrap(),
        http::Response::builder()
            .status(200)
            .body(SdkBody::from(read_response_file("ai21_jamba_1_5_mini_v1_0_chat_completions.json").unwrap()))
            .unwrap(),
    )
}

#[cfg(test)]
fn dummy_ai21_completion_response() -> aws_smithy_runtime::client::http::test_util::ReplayEvent {
    use aws_smithy_types::body::SdkBody;

    aws_smithy_runtime::client::http::test_util::ReplayEvent::new(
        http::Request::builder()
            .method("POST")
            .uri("https://bedrock-runtime.us-east-1.amazonaws.com/model/ai21.j2-mid-v1/invoke")
            .body(SdkBody::empty())
            .unwrap(),
        http::Response::builder()
            .status(200)
            .body(SdkBody::from(
                read_response_file("ai21_j2_mid_v1_completions.json").unwrap(),
            ))
            .unwrap(),
    )
}

#[cfg(test)]
fn dummy_titan_embedding_response() -> aws_smithy_runtime::client::http::test_util::ReplayEvent {
    use aws_smithy_types::body::SdkBody;

    aws_smithy_runtime::client::http::test_util::ReplayEvent::new(
        http::Request::builder()
            .method("POST")
            .uri("https://bedrock-runtime.us-east-2.amazonaws.com/model/amazon.titan-embed-text-v2:0/invoke")
            .body(SdkBody::empty())
            .unwrap(),
        http::Response::builder()
            .status(200)
            .body(SdkBody::from(read_response_file("amazon_titan_embed_text_v2_0_embeddings.json").unwrap()))
            .unwrap(),
    )
}

#[cfg(test)]
fn dummy_titan_chat_completion_response() -> aws_smithy_runtime::client::http::test_util::ReplayEvent
{
    use aws_smithy_types::body::SdkBody;

    aws_smithy_runtime::client::http::test_util::ReplayEvent::new(
        http::Request::builder()
            .method("POST")
            .uri("https://bedrock-runtime.us-east-2.amazonaws.com/model/us.amazon.nova-lite-v1:0/invoke")
            .body(SdkBody::empty())
            .unwrap(),
        http::Response::builder()
            .status(200)
            .body(SdkBody::from(read_response_file("us_amazon_nova_lite_v1_0_chat_completion.json").unwrap()))
            .unwrap(),
    )
}
--------------------------------------------------------------------------------
/src/providers/mod.rs:
--------------------------------------------------------------------------------
pub mod anthropic;
pub mod azure;
pub mod bedrock;
pub mod openai;
pub mod provider;
pub mod registry;
pub mod vertexai;
--------------------------------------------------------------------------------
/src/providers/openai/mod.rs:
--------------------------------------------------------------------------------
mod provider;

pub use provider::OpenAIProvider;
--------------------------------------------------------------------------------
/src/providers/openai/provider.rs:
--------------------------------------------------------------------------------
use crate::config::constants::stream_buffer_size_bytes;
use crate::config::models::{ModelConfig, Provider as ProviderConfig};
use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse};
use crate::models::completion::{CompletionRequest, CompletionResponse};
use crate::models::embeddings::{EmbeddingsRequest, EmbeddingsResponse};
use crate::models::streaming::ChatCompletionChunk;
use crate::providers::provider::Provider;
use axum::async_trait;
use axum::http::StatusCode;
use reqwest::Client;
use reqwest_streams::*;
use tracing::info;

pub struct OpenAIProvider {
    config: ProviderConfig,
    http_client: Client,
}

impl OpenAIProvider {
    fn base_url(&self) -> String {
        self.config
            .params
            .get("base_url")
            .unwrap_or(&String::from("https://api.openai.com/v1"))
            .to_string()
    }
}
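// Hypothetical sketch (not in the original file): because `base_url` falls
// back to https://api.openai.com/v1 only when the param is absent, this
// provider can front any OpenAI-compatible endpoint. The localhost URL and
// key below are illustrative.
#[cfg(test)]
mod base_url_tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn base_url_can_be_overridden_per_provider() {
        let mut params = HashMap::new();
        params.insert(
            "base_url".to_string(),
            "http://localhost:11434/v1".to_string(),
        );

        let provider = OpenAIProvider::new(&ProviderConfig {
            key: "local-openai".to_string(),
            r#type: "openai".to_string(),
            api_key: "illustrative-key".to_string(),
            params,
        });

        assert_eq!(provider.base_url(), "http://localhost:11434/v1");
    }
}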
#[async_trait]
impl Provider for OpenAIProvider {
    fn new(config: &ProviderConfig) -> Self {
        Self {
            config: config.clone(),
            http_client: Client::new(),
        }
    }

    fn key(&self) -> String {
        self.config.key.clone()
    }

    fn r#type(&self) -> String {
        "openai".to_string()
    }

    async fn chat_completions(
        &self,
        payload: ChatCompletionRequest,
        _model_config: &ModelConfig,
    ) -> Result<ChatCompletionResponse, StatusCode> {
        let response = self
            .http_client
            .post(format!("{}/chat/completions", self.base_url()))
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .json(&payload)
            .send()
            .await
            .map_err(|e| {
                eprintln!("OpenAI API request error: {}", e);
                StatusCode::INTERNAL_SERVER_ERROR
            })?;

        let status = response.status();
        if status.is_success() {
            if payload.stream.unwrap_or(false) {
                let stream =
                    response.json_array_stream::<ChatCompletionChunk>(stream_buffer_size_bytes());
                Ok(ChatCompletionResponse::Stream(stream))
            } else {
                response
                    .json()
                    .await
                    .map(ChatCompletionResponse::NonStream)
                    .map_err(|e| {
                        eprintln!("OpenAI API response error: {}", e);
                        StatusCode::INTERNAL_SERVER_ERROR
                    })
            }
        } else {
            info!(
                "OpenAI API request error: {}",
                response.text().await.unwrap()
            );
            Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
        }
    }

    async fn completions(
        &self,
        payload: CompletionRequest,
        _model_config: &ModelConfig,
    ) -> Result<CompletionResponse, StatusCode> {
        let response = self
            .http_client
            .post(format!("{}/completions", self.base_url()))
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .json(&payload)
            .send()
            .await
            .map_err(|e| {
                eprintln!("OpenAI API request error: {}", e);
                StatusCode::INTERNAL_SERVER_ERROR
            })?;

        let status = response.status();
        if status.is_success() {
            response.json().await.map_err(|e| {
                eprintln!("OpenAI API response error: {}", e);
                StatusCode::INTERNAL_SERVER_ERROR
            })
        } else {
            eprintln!(
                "OpenAI API request error: {}",
                response.text().await.unwrap()
            );
            Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
        }
    }

    async fn embeddings(
        &self,
        payload: EmbeddingsRequest,
        _model_config: &ModelConfig,
    ) -> Result<EmbeddingsResponse, StatusCode> {
        let response = self
            .http_client
            .post(format!("{}/embeddings", self.base_url()))
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .json(&payload)
            .send()
            .await
            .map_err(|e| {
                eprintln!("OpenAI API request error: {}", e);
                StatusCode::INTERNAL_SERVER_ERROR
            })?;

        let status = response.status();
        if status.is_success() {
            response.json().await.map_err(|e| {
                eprintln!("OpenAI API response error: {}", e);
                StatusCode::INTERNAL_SERVER_ERROR
            })
        } else {
            eprintln!(
                "OpenAI API request error: {}",
                response.text().await.unwrap()
            );
            Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
        }
    }
}
--------------------------------------------------------------------------------
/src/providers/provider.rs:
--------------------------------------------------------------------------------
use axum::async_trait;
use axum::http::StatusCode;

use crate::config::models::{ModelConfig, Provider as ProviderConfig};
use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse};
use crate::models::completion::{CompletionRequest, CompletionResponse};
use crate::models::embeddings::{EmbeddingsRequest, EmbeddingsResponse};

#[async_trait]
pub trait Provider: Send + Sync {
    fn new(config: &ProviderConfig) -> Self
    where
        Self: Sized;
    fn key(&self) -> String;
    fn r#type(&self) -> String;

    async fn chat_completions(
        &self,
        payload: ChatCompletionRequest,
        model_config: &ModelConfig,
    ) -> Result<ChatCompletionResponse, StatusCode>;

    async fn completions(
        &self,
        payload: CompletionRequest,
        model_config: &ModelConfig,
    ) -> Result<CompletionResponse, StatusCode>;

    async fn embeddings(
        &self,
        payload: EmbeddingsRequest,
        model_config: &ModelConfig,
    ) -> Result<EmbeddingsResponse, StatusCode>;
}
--------------------------------------------------------------------------------
/src/providers/registry.rs:
--------------------------------------------------------------------------------
use anyhow::Result;
use std::collections::HashMap;
use std::sync::Arc;

use crate::config::models::Provider as ProviderConfig;
use crate::providers::{
    anthropic::AnthropicProvider, azure::AzureProvider, bedrock::BedrockProvider,
    openai::OpenAIProvider, provider::Provider, vertexai::VertexAIProvider,
};

pub struct ProviderRegistry {
    providers: HashMap<String, Arc<dyn Provider>>,
}

impl ProviderRegistry {
    pub fn new(provider_configs: &[ProviderConfig]) -> Result<Self> {
        let mut providers = HashMap::new();

        for config in provider_configs {
            let provider: Arc<dyn Provider> = match config.r#type.as_str() {
                "openai" => Arc::new(OpenAIProvider::new(config)),
                "anthropic" => Arc::new(AnthropicProvider::new(config)),
                "azure" => Arc::new(AzureProvider::new(config)),
                "bedrock" => Arc::new(BedrockProvider::new(config)),
                "vertexai" => Arc::new(VertexAIProvider::new(config)),
                _ => continue,
            };
            providers.insert(config.key.clone(), provider);
        }

        Ok(Self { providers })
    }

    pub fn get(&self, name: &str) -> Option<Arc<dyn Provider>> {
        self.providers.get(name).cloned()
    }
}
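// Hypothetical usage sketch (not in the original file): lookups are keyed by
// the provider's `key`, not its `type`, and unknown types are silently
// skipped by the `_ => continue` arm above. Config values are illustrative;
// the real service loads them from its YAML config.
#[cfg(test)]
mod registry_usage_tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn registers_by_key_and_skips_unknown_types() {
        let configs = vec![
            ProviderConfig {
                key: "openai-main".to_string(),
                r#type: "openai".to_string(),
                api_key: "illustrative-key".to_string(),
                params: HashMap::new(),
            },
            ProviderConfig {
                key: "mystery".to_string(),
                r#type: "not-a-real-provider".to_string(),
                api_key: "".to_string(),
                params: HashMap::new(),
            },
        ];

        let registry = ProviderRegistry::new(&configs).unwrap();

        assert!(registry.get("openai-main").is_some());
        assert_eq!(registry.get("openai-main").unwrap().r#type(), "openai");
        assert!(registry.get("mystery").is_none());
        assert!(registry.get("openai").is_none()); // keyed by `key`, not `type`
    }
}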
--------------------------------------------------------------------------------
/src/providers/vertexai/mod.rs:
--------------------------------------------------------------------------------
pub mod models;
pub mod provider;
#[cfg(test)]
mod tests;

pub use provider::VertexAIProvider;
--------------------------------------------------------------------------------
/src/providers/vertexai/models.rs:
--------------------------------------------------------------------------------
use serde::{Deserialize, Serialize};
use serde_json::Value;

use crate::models::chat::{ChatCompletion, ChatCompletionChoice, ChatCompletionRequest};
use crate::models::content::{ChatCompletionMessage, ChatMessageContent};
use crate::models::streaming::{ChatCompletionChunk, Choice, ChoiceDelta};
use crate::models::tool_calls::{ChatMessageToolCall, FunctionCall};
use crate::models::tool_choice::{SimpleToolChoice, ToolChoice};
use crate::models::usage::Usage;

#[derive(Debug, Serialize, Deserialize)]
pub struct GeminiChatRequest {
    pub contents: Vec<GeminiContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generation_config: Option<GenerationConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_settings: Option<Vec<SafetySetting>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<GeminiTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<GeminiToolChoice>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct GeminiTool {
    pub function_declarations: Vec<GeminiFunctionDeclaration>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct GeminiFunctionDeclaration {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Value,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum GeminiToolChoice {
    None,
    Auto,
    Function(GeminiFunctionChoice),
}

#[derive(Debug, Serialize, Deserialize)]
pub struct GeminiFunctionChoice {
    pub name: String,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct GeminiContent {
    pub role: String,
    pub parts: Vec<ContentPart>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ContentPart {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    #[serde(rename = "functionCall", skip_serializing_if = "Option::is_none")]
    pub function_call: Option<GeminiFunctionCall>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct GenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct SafetySetting {
    pub category: String,
    pub threshold: String,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct GeminiChatResponse {
    pub candidates: Vec<GeminiCandidate>,
    pub usage_metadata: Option<UsageMetadata>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct GeminiCandidate {
    pub content: GeminiContent,
    pub finish_reason: Option<String>,
    pub safety_ratings: Option<Vec<SafetyRating>>,
    pub tool_calls: Option<Vec<GeminiToolCall>>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GeminiToolCall {
    pub function: GeminiFunctionCall,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GeminiFunctionCall {
    pub name: String,
    pub args: Value,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct SafetyRating {
    pub category: String,
    pub probability: String,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct UsageMetadata {
    pub prompt_token_count: i32,
    pub candidates_token_count: i32,
    pub total_token_count: i32,
}

#[derive(Debug, Deserialize)]
pub struct VertexAIStreamChunk {
    pub candidates: Vec<GeminiCandidate>,
    pub usage_metadata: Option<UsageMetadata>,
}

impl From<ChatCompletionRequest> for GeminiChatRequest {
    fn from(req: ChatCompletionRequest) -> Self {
        let contents = req
            .messages
            .into_iter()
            .map(|msg| GeminiContent {
                role: match msg.role.as_str() {
                    "assistant" => "model".to_string(),
                    role => role.to_string(),
                },
                parts: vec![ContentPart {
                    text: match msg.content {
                        Some(content) => match content {
                            ChatMessageContent::String(text) => Some(text),
                            ChatMessageContent::Array(parts) => Some(
                                parts
                                    .into_iter()
                                    .map(|p| p.text)
                                    .collect::<Vec<String>>()
                                    .join(" "),
                            ),
                        },
                        None => None,
                    },
                    function_call: None,
                }],
            })
            .collect();

        let generation_config = Some(GenerationConfig {
            temperature: req.temperature,
            top_p: req.top_p,
            top_k: None,
            max_output_tokens: req.max_tokens,
            stop_sequences: req.stop,
        });

        let tools = req.tools.map(|tools| {
            vec![GeminiTool {
                function_declarations: tools
                    .into_iter()
                    .map(|tool| GeminiFunctionDeclaration {
                        name: tool.function.name,
                        description: tool.function.description,
                        parameters: serde_json::to_value(tool.function.parameters)
                            .unwrap_or_default(),
                    })
                    .collect(),
            }]
        });

        let tool_choice = req.tool_choice.map(|choice| match choice {
            ToolChoice::Simple(SimpleToolChoice::None) => GeminiToolChoice::None,
            ToolChoice::Simple(SimpleToolChoice::Auto) => GeminiToolChoice::Auto,
            ToolChoice::Named(named) => GeminiToolChoice::Function(GeminiFunctionChoice {
                name: named.function.name,
            }),
            _ => GeminiToolChoice::None,
        });

        Self {
            contents,
            generation_config,
            safety_settings: None,
            tools,
            tool_choice,
        }
    }
}

impl GeminiChatResponse {
    pub fn to_openai(self, model: String) -> ChatCompletion {
        let choices = self
            .candidates
            .into_iter()
            .enumerate()
            .map(|(i, candidate)| {
                let mut message_text = String::new();
                let mut tool_calls = Vec::new();

                for part in candidate.content.parts {
                    if let Some(text) = part.text {
                        message_text.push_str(&text);
                    }
                    if let Some(fc) = part.function_call {
                        tool_calls.push(ChatMessageToolCall {
                            id: format!("call_{}", uuid::Uuid::new_v4()),
                            r#type: "function".to_string(),
                            function: FunctionCall {
                                name: fc.name,
                                arguments: serde_json::to_string(&fc.args)
                                    .unwrap_or_else(|_| "{}".to_string()),
                            },
                        });
                    }
                }

                ChatCompletionChoice {
                    index: i as u32,
                    message: ChatCompletionMessage {
                        role: "assistant".to_string(),
                        content: if message_text.is_empty() {
                            None
                        } else {
                            Some(ChatMessageContent::String(message_text))
                        },
                        tool_calls: if tool_calls.is_empty() {
                            None
                        } else {
                            Some(tool_calls)
                        },
                        name: None,
                        refusal: None,
                    },
                    finish_reason: candidate.finish_reason,
                    logprobs: None,
                }
            })
            .collect();

        let usage = self.usage_metadata.map_or_else(
            || Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
                completion_tokens_details: None,
                prompt_tokens_details: None,
            },
            |meta| Usage {
                prompt_tokens: meta.prompt_token_count as u32,
                completion_tokens: meta.candidates_token_count as u32,
                total_tokens: meta.total_token_count as u32,
                completion_tokens_details: None,
                prompt_tokens_details: None,
            },
        );

        ChatCompletion {
            id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
            object: Some("chat.completion".to_string()),
            created: Some(chrono::Utc::now().timestamp() as u64),
            model,
            choices,
            usage,
            system_fingerprint: None,
        }
    }
}

impl From<VertexAIStreamChunk> for ChatCompletionChunk {
    fn from(chunk: VertexAIStreamChunk) -> Self {
        let first_candidate = chunk.candidates.first();

        Self {
            id: uuid::Uuid::new_v4().to_string(),
            service_tier: None,
            system_fingerprint: None,
            created: chrono::Utc::now().timestamp(),
            model: String::new(),
            choices: vec![Choice {
                index: 0,
                logprobs: None,
                delta: ChoiceDelta {
                    role: None,
                    content: first_candidate
                        .and_then(|c| c.content.parts.first())
                        .map(|p| p.text.clone().unwrap_or_default()),
                    tool_calls: first_candidate
                        .and_then(|c| c.tool_calls.clone())
                        .map(|calls| {
                            calls
                                .into_iter()
                                .map(|call| ChatMessageToolCall {
                                    id: format!("call_{}", uuid::Uuid::new_v4()),
                                    r#type: "function".to_string(),
                                    function: FunctionCall {
                                        name: call.function.name,
                                        arguments: serde_json::to_string(&call.function.args)
                                            .unwrap_or_else(|_| "{}".to_string()),
                                    },
                                })
                                .collect()
                        }),
                },
                finish_reason: first_candidate.and_then(|c| c.finish_reason.clone()),
            }],
            usage: None,
        }
    }
}
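// Hypothetical shape check (not in the original file): the "assistant" ->
// "model" rename in the From impl above matters because Vertex AI expects the
// "model" role, and skip_serializing_if keeps None fields off the wire. The
// request below is hand-built with illustrative values.
#[cfg(test)]
mod gemini_shape_tests {
    use super::*;

    #[test]
    fn gemini_content_uses_model_role_and_skips_none_fields() {
        let request = GeminiChatRequest {
            contents: vec![GeminiContent {
                role: "model".to_string(), // what an OpenAI "assistant" turn becomes
                parts: vec![ContentPart {
                    text: Some("Hi there!".to_string()),
                    function_call: None,
                }],
            }],
            generation_config: None,
            safety_settings: None,
            tools: None,
            tool_choice: None,
        };

        let json = serde_json::to_value(&request).unwrap();
        assert_eq!(json["contents"][0]["role"], "model");
        // None fields are omitted entirely thanks to skip_serializing_if.
        assert!(json.get("generation_config").is_none());
        assert!(json.get("tools").is_none());
    }
}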
pub struct VertexAIProvider { 24 | config: ProviderConfig, 25 | http_client: Client, 26 | project_id: String, 27 | location: String, 28 | } 29 | 30 | impl VertexAIProvider { 31 | async fn get_auth_token(&self) -> Result { 32 | debug!("Getting auth token..."); 33 | 34 | // Special case for tests - return dummy token when in test mode 35 | if self 36 | .config 37 | .params 38 | .get("use_test_auth") 39 | .is_some_and(|v| v == "true") 40 | { 41 | debug!("Using test auth mode, returning dummy token"); 42 | return Ok("test-token-for-vertex-ai".to_string()); 43 | } 44 | 45 | if !self.config.api_key.is_empty() { 46 | debug!("Using API key authentication"); 47 | Ok(self.config.api_key.clone()) 48 | } else { 49 | debug!("Using service account authentication"); 50 | let key_path = self.config 51 | .params 52 | .get("credentials_path") 53 | .map(|p| p.to_string()) 54 | .or_else(|| std::env::var("GOOGLE_APPLICATION_CREDENTIALS").ok()) 55 | .expect("Either api_key, credentials_path in config, or GOOGLE_APPLICATION_CREDENTIALS environment variable must be set"); 56 | 57 | debug!("Reading service account key from: {}", key_path); 58 | let key_json = 59 | std::fs::read_to_string(key_path).expect("Failed to read service account key file"); 60 | 61 | debug!( 62 | "Service account key file content length: {}", 63 | key_json.len() 64 | ); 65 | let sa_key: ServiceAccountKey = 66 | serde_json::from_str(&key_json).expect("Failed to parse service account key"); 67 | 68 | debug!("Successfully parsed service account key"); 69 | let auth = ServiceAccountAuthenticator::builder(sa_key) 70 | .build() 71 | .await 72 | .expect("Failed to create authenticator"); 73 | 74 | debug!("Created authenticator, requesting token..."); 75 | let scopes = &["https://www.googleapis.com/auth/cloud-platform"]; 76 | let token = auth.token(scopes).await.map_err(|e| { 77 | error!("Failed to get access token: {}", e); 78 | StatusCode::INTERNAL_SERVER_ERROR 79 | })?; 80 | 81 | debug!("Successfully obtained token"); 82 | Ok(token.token().unwrap_or_default().to_string()) 83 | } 84 | } 85 | 86 | pub fn validate_location(location: &str) -> Result { 87 | let sanitized = location 88 | .chars() 89 | .filter(|c| c.is_alphanumeric() || *c == '-') 90 | .collect::(); 91 | 92 | if sanitized.is_empty() || sanitized != location { 93 | Err(format!( 94 | "Invalid location provided: '{}'. 
Location must contain only alphanumeric characters and hyphens.", 95 | location 96 | )) 97 | } else { 98 | Ok(sanitized) 99 | } 100 | } 101 | } 102 | 103 | #[async_trait] 104 | impl Provider for VertexAIProvider { 105 | fn new(config: &ProviderConfig) -> Self { 106 | let project_id = config 107 | .params 108 | .get("project_id") 109 | .expect("project_id is required for VertexAI provider") 110 | .to_string(); 111 | let location_str = config 112 | .params 113 | .get("location") 114 | .expect("location is required for VertexAI provider") 115 | .to_string(); 116 | 117 | let location = Self::validate_location(&location_str) 118 | .expect("Invalid location provided in configuration"); 119 | 120 | Self { 121 | config: config.clone(), 122 | http_client: Client::new(), 123 | project_id, 124 | location, 125 | } 126 | } 127 | 128 | fn key(&self) -> String { 129 | self.config.key.clone() 130 | } 131 | 132 | fn r#type(&self) -> String { 133 | "vertexai".to_string() 134 | } 135 | 136 | async fn chat_completions( 137 | &self, 138 | payload: ChatCompletionRequest, 139 | _model_config: &ModelConfig, 140 | ) -> Result { 141 | let auth_token = self.get_auth_token().await?; 142 | let endpoint_suffix = if payload.stream.unwrap_or(false) { 143 | "streamGenerateContent" 144 | } else { 145 | "generateContent" 146 | }; 147 | 148 | // Determine if we're in test mode 149 | let is_test_mode = self 150 | .config 151 | .params 152 | .get("use_test_auth") 153 | .map_or(false, |v| v == "true"); 154 | 155 | let endpoint = if is_test_mode { 156 | // In test mode, use the mock server endpoint 157 | let test_endpoint = std::env::var("VERTEXAI_TEST_ENDPOINT") 158 | .unwrap_or_else(|_| "http://localhost:8080".to_string()); 159 | debug!("Using test endpoint: {}", test_endpoint); 160 | test_endpoint 161 | } else { 162 | // Normal mode, use the real endpoint 163 | let service_endpoint = format!("{}-aiplatform.googleapis.com", self.location); 164 | let full_model_path = format!( 165 | "projects/{}/locations/{}/publishers/google/models/{}", 166 | self.project_id, self.location, payload.model 167 | ); 168 | format!( 169 | "https://{}/v1/{}:{}", 170 | service_endpoint, full_model_path, endpoint_suffix 171 | ) 172 | }; 173 | 174 | let request_body = GeminiChatRequest::from(payload.clone()); 175 | debug!("Sending request to endpoint: {}", endpoint); 176 | debug!( 177 | "Request Body: {}", 178 | serde_json::to_string(&request_body) 179 | .unwrap_or_else(|e| format!("Failed to serialize request: {}", e)) 180 | ); 181 | 182 | let response_result = self 183 | .http_client 184 | .post(&endpoint) 185 | .bearer_auth(auth_token) 186 | .json(&request_body) 187 | .send() 188 | .await; 189 | 190 | let response = match response_result { 191 | Ok(resp) => resp, 192 | Err(e) => { 193 | error!("VertexAI API request failed before getting response: {}", e); 194 | return Err(StatusCode::INTERNAL_SERVER_ERROR); 195 | } 196 | }; 197 | 198 | let status = response.status(); 199 | debug!("Response status: {}", status); 200 | 201 | if status.is_success() { 202 | if payload.stream.unwrap_or(false) { 203 | let model = payload.model.clone(); 204 | let stream = response 205 | .json_array_stream::(STREAM_BUFFER_SIZE) 206 | .map(move |result| { 207 | result 208 | .map(|chunk| { 209 | let mut completion_chunk: ChatCompletionChunk = chunk.into(); 210 | completion_chunk.model = model.clone(); 211 | completion_chunk 212 | }) 213 | .map_err(|e| { 214 | StreamBodyError::new( 215 | StreamBodyKind::CodecError, 216 | Some(Box::new(e)), 217 | None, 218 | ) 219 | }) 220 | 
221 | 
222 |                 Ok(ChatCompletionResponse::Stream(Box::pin(stream)))
223 |             } else {
224 |                 let response_text = response.text().await.map_err(|e| {
225 |                     error!("Failed to get response text: {}", e);
226 |                     StatusCode::INTERNAL_SERVER_ERROR
227 |                 })?;
228 |                 debug!("Raw VertexAI Response Body: {}", response_text);
229 | 
230 |                 // In test mode, we may be getting an array directly from the mock server
231 |                 // since we saved multiple interactions in a single array
232 |                 if is_test_mode && response_text.trim().starts_with('[') {
233 |                     debug!("Test mode detected array response, extracting first item");
234 |                     let array: Vec<serde_json::Value> = serde_json::from_str(&response_text)
235 |                         .map_err(|e| {
236 |                             error!("Failed to parse test response as array: {}", e);
237 |                             StatusCode::INTERNAL_SERVER_ERROR
238 |                         })?;
239 | 
240 |                     if let Some(first_item) = array.first() {
241 |                         // Convert the first item back to JSON string
242 |                         let item_str = serde_json::to_string(first_item).unwrap_or_default();
243 |                         debug!("Using first item from array: {}", item_str);
244 | 
245 |                         // Parse as GeminiChatResponse
246 |                         let gemini_response: GeminiChatResponse = serde_json::from_str(&item_str)
247 |                             .map_err(|e| {
248 |                                 error!("Failed to parse test item as GeminiChatResponse: {}", e);
249 |                                 StatusCode::INTERNAL_SERVER_ERROR
250 |                             })?;
251 | 
252 |                         return Ok(ChatCompletionResponse::NonStream(
253 |                             gemini_response.to_openai(payload.model),
254 |                         ));
255 |                     }
256 |                 }
257 | 
258 |                 // Regular parsing for normal API responses
259 |                 let gemini_response: GeminiChatResponse = serde_json::from_str(&response_text)
260 |                     .map_err(|e| {
261 |                         error!(
262 |                             "Failed to parse response as GeminiChatResponse. Error: {}, Raw Response: {}",
263 |                             e,
264 |                             response_text
265 |                         );
266 |                         StatusCode::INTERNAL_SERVER_ERROR
267 |                     })?;
268 | 
269 |                 Ok(ChatCompletionResponse::NonStream(
270 |                     gemini_response.to_openai(payload.model),
271 |                 ))
272 |             }
273 |         } else {
274 |             let error_text = response.text().await.unwrap_or_default();
275 |             error!(
276 |                 "VertexAI API request failed with status {}. Error body: {}",
277 |                 status, error_text
278 |             );
279 |             Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
280 |         }
281 |     }
282 | 
283 |     async fn completions(
284 |         &self,
285 |         _payload: CompletionRequest,
286 |         _model_config: &ModelConfig,
287 |     ) -> Result<CompletionResponse, StatusCode> {
288 |         unimplemented!(
289 |             "Text completions are not supported for Vertex AI. Use chat_completions instead."
290 |         )
291 |     }
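
For concreteness, this is what the endpoint construction in `chat_completions` above resolves to. The project, location, and model names here are made up for illustration; only the URL shape comes from the code:

    fn main() {
        let (project_id, location, model) = ("my-project", "us-central1", "gemini-2.0-flash-exp");
        let service_endpoint = format!("{}-aiplatform.googleapis.com", location);
        let full_model_path = format!(
            "projects/{}/locations/{}/publishers/google/models/{}",
            project_id, location, model
        );
        // "generateContent" becomes "streamGenerateContent" when payload.stream is true.
        let endpoint = format!("https://{}/v1/{}:{}", service_endpoint, full_model_path, "generateContent");
        assert_eq!(
            endpoint,
            "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash-exp:generateContent"
        );
    }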
292 | 
293 |     async fn embeddings(
294 |         &self,
295 |         payload: EmbeddingsRequest,
296 |         _model_config: &ModelConfig,
297 |     ) -> Result<EmbeddingsResponse, StatusCode> {
298 |         let auth_token = self.get_auth_token().await?;
299 | 
300 |         // Determine if we're in test mode
301 |         let is_test_mode = self
302 |             .config
303 |             .params
304 |             .get("use_test_auth")
305 |             .is_some_and(|v| v == "true");
306 | 
307 |         let endpoint = if is_test_mode {
308 |             // In test mode, use the mock server endpoint
309 |             let test_endpoint = std::env::var("VERTEXAI_TEST_ENDPOINT")
310 |                 .unwrap_or_else(|_| "http://localhost:8080".to_string());
311 |             debug!("Using test endpoint for embeddings: {}", test_endpoint);
312 |             test_endpoint
313 |         } else {
314 |             // Normal mode, use the real endpoint
315 |             format!(
316 |                 "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:predict",
317 |                 self.location, self.project_id, self.location, payload.model
318 |             )
319 |         };
320 | 
321 |         let response = self
322 |             .http_client
323 |             .post(&endpoint)
324 |             .bearer_auth(auth_token)
325 |             .json(&json!({
326 |                 "instances": match payload.input {
327 |                     EmbeddingsInput::Single(text) => vec![json!({"content": text})],
328 |                     EmbeddingsInput::Multiple(texts) => texts.into_iter()
329 |                         .map(|text| json!({"content": text}))
330 |                         .collect::<Vec<_>>(),
331 |                     EmbeddingsInput::SingleTokenIds(tokens) => vec![json!({"content": tokens.iter().map(|t| t.to_string()).collect::<Vec<_>>().join(" ")})],
332 |                     EmbeddingsInput::MultipleTokenIds(token_arrays) => token_arrays.into_iter()
333 |                         .map(|tokens| json!({"content": tokens.iter().map(|t| t.to_string()).collect::<Vec<_>>().join(" ")}))
334 |                         .collect::<Vec<_>>(),
335 |                 },
336 |                 "parameters": {
337 |                     "autoTruncate": true
338 |                 }
339 |             }))
340 |             .send()
341 |             .await
342 |             .map_err(|e| {
343 |                 error!("VertexAI API request error: {}", e);
344 |                 StatusCode::INTERNAL_SERVER_ERROR
345 |             })?;
346 | 
347 |         let status = response.status();
348 |         debug!("Embeddings response status: {}", status);
349 | 
350 |         if status.is_success() {
351 |             let response_text = response.text().await.map_err(|e| {
352 |                 error!("Failed to get response text: {}", e);
353 |                 StatusCode::INTERNAL_SERVER_ERROR
354 |             })?;
355 |             debug!("Embeddings response body: {}", response_text);
356 | 
357 |             // In test mode, we may be getting an array directly from the mock server
358 |             // since we saved multiple interactions in a single array
359 |             if is_test_mode && response_text.trim().starts_with('[') {
360 |                 debug!("Test mode detected array response for embeddings, extracting first item");
361 |                 let array: Vec<serde_json::Value> =
362 |                     serde_json::from_str(&response_text).map_err(|e| {
363 |                         error!("Failed to parse test response as array: {}", e);
364 |                         StatusCode::INTERNAL_SERVER_ERROR
365 |                     })?;
366 | 
367 |                 if let Some(first_item) = array.first() {
368 |                     // Use the first item from the array as the response
369 |                     return Ok(EmbeddingsResponse {
370 |                         object: "list".to_string(),
371 |                         data: first_item["data"]
372 |                             .as_array()
373 |                             .unwrap_or(&vec![])
374 |                             .iter()
375 |                             .enumerate()
376 |                             .map(|(i, emb)| Embeddings {
377 |                                 object: "embedding".to_string(),
378 |                                 embedding: Embedding::Float(
379 |                                     emb["embedding"]
380 |                                         .as_array()
381 |                                         .unwrap_or(&vec![])
382 |                                         .iter()
383 |                                         .filter_map(|v| v.as_f64().map(|f| f as f32))
384 |                                         .collect::<Vec<_>>(),
385 |                                 ),
386 |                                 index: i,
387 |                             })
388 |                             .collect(),
389 |                         model: payload.model,
390 |                         usage: EmbeddingUsage {
391 |                             prompt_tokens: Some(0),
392 |                             total_tokens: Some(0),
393 |                         },
394 |                     });
395 |                 }
396 |             }
397 | 
398 |             // Normal processing for regular API responses
399 |             let gemini_response: serde_json::Value =
400 |                 serde_json::from_str(&response_text).map_err(|e| {
401 |                     error!("Failed to parse response as JSON: {}", e);
402 |                     StatusCode::INTERNAL_SERVER_ERROR
403 |                 })?;
404 | 
405 |             // Extract embeddings from updated response format
406 |             let embeddings = gemini_response["predictions"]
407 |                 .as_array()
408 |                 .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?
409 |                 .iter()
410 |                 .enumerate()
411 |                 .map(|(i, pred)| Embeddings {
412 |                     object: "embedding".to_string(),
413 |                     embedding: Embedding::Float(
414 |                         pred["embeddings"]["values"]
415 |                             .as_array()
416 |                             .unwrap_or(&vec![])
417 |                             .iter()
418 |                             .filter_map(|v| v.as_f64().map(|f| f as f32))
419 |                             .collect::<Vec<_>>(),
420 |                     ),
421 |                     index: i,
422 |                 })
423 |                 .collect();
424 | 
425 |             Ok(EmbeddingsResponse {
426 |                 object: "list".to_string(),
427 |                 data: embeddings,
428 |                 model: payload.model,
429 |                 usage: EmbeddingUsage {
430 |                     prompt_tokens: Some(0),
431 |                     total_tokens: Some(0),
432 |                 },
433 |             })
434 |         } else {
435 |             let error_text = response.text().await.unwrap_or_default();
436 |             error!("VertexAI API request error: {}", error_text);
437 |             Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
438 |         }
439 |     }
440 | }
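
A note on the embeddings wire format before the test-only impl below: the request built above wraps each input in an `instances` entry, and the success path reads vectors back out of `predictions[i].embeddings.values`. A minimal sketch with made-up numbers; only the JSON shapes come from the code (requires the serde_json crate):

    use serde_json::json;

    fn main() {
        // Request body produced for EmbeddingsInput::Multiple(vec!["a".into(), "b".into()]):
        let _body = json!({
            "instances": [{ "content": "a" }, { "content": "b" }],
            "parameters": { "autoTruncate": true }
        });
        // Hypothetical predict response; real values come from the API.
        let response = json!({
            "predictions": [{ "embeddings": { "values": [0.1, -0.2, 0.3] } }]
        });
        let values: Vec<f32> = response["predictions"][0]["embeddings"]["values"]
            .as_array()
            .unwrap()
            .iter()
            .filter_map(|v| v.as_f64().map(|f| f as f32))
            .collect();
        assert_eq!(values.len(), 3);
    }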
441 | 
442 | #[cfg(test)]
443 | impl VertexAIProvider {
444 |     pub fn with_test_client(config: &ProviderConfig, client: reqwest::Client) -> Self {
445 |         let project_id = config
446 |             .params
447 |             .get("project_id")
448 |             .cloned()
449 |             .unwrap_or_else(|| "".to_string());
450 | 
451 |         let location_str = config
452 |             .params
453 |             .get("location")
454 |             .cloned()
455 |             .unwrap_or_else(|| "".to_string());
456 | 
457 |         let location = Self::validate_location(&location_str)
458 |             .expect("Invalid location provided for test client configuration");
459 | 
460 |         Self {
461 |             config: config.clone(),
462 |             http_client: client,
463 |             project_id,
464 |             location,
465 |         }
466 |     }
467 | }
468 | 
-------------------------------------------------------------------------------- /src/routes.rs: --------------------------------------------------------------------------------
1 | use crate::{pipelines::pipeline::create_pipeline, state::AppState};
2 | use axum::{extract::Request, routing::get, Router};
3 | use axum_prometheus::PrometheusMetricLayerBuilder;
4 | use std::collections::HashMap;
5 | use std::sync::Arc;
6 | use tower::steer::Steer;
7 | 
8 | pub fn create_router(state: Arc<AppState>) -> Router {
9 |     let (prometheus_layer, metric_handle) = PrometheusMetricLayerBuilder::new()
10 |         .with_ignore_patterns(&["/metrics", "/health"])
11 |         .with_prefix("traceloop_hub")
12 |         .with_default_metrics()
13 |         .build_pair();
14 | 
15 |     let mut pipeline_idxs = HashMap::new();
16 |     let mut routers = Vec::new();
17 | 
18 |     // Sort pipelines to ensure default is first
19 |     let mut sorted_pipelines: Vec<_> = state.config.pipelines.clone();
20 |     sorted_pipelines.sort_by_key(|p| p.name != "default"); // "default" will come first since false < true
21 | 
22 |     for pipeline in sorted_pipelines {
23 |         let name = pipeline.name.clone();
24 |         pipeline_idxs.insert(name, routers.len());
25 |         routers.push(create_pipeline(&pipeline, &state.model_registry));
26 |     }
27 | 
28 |     let pipeline_router = Steer::new(routers, move |req: &Request, _services: &[_]| {
29 |         *req.headers()
30 |             .get("x-traceloop-pipeline")
31 |             .and_then(|h| h.to_str().ok())
32 |             .and_then(|name| pipeline_idxs.get(name))
33 |             .unwrap_or(&0)
34 |     });
35 | 
36 |     Router::new()
37 |         .nest_service("/api/v1", pipeline_router)
38 |         .route("/health", get(|| async { "Working!" }))
39 |         .route("/metrics", get(|| async move { metric_handle.render() }))
40 |         .layer(prometheus_layer)
41 |         .with_state(state)
42 | }
43 | 
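Routing note: `Steer` picks a sub-router by index, so a request carrying an `x-traceloop-pipeline` header that matches a configured pipeline name is steered to that pipeline, and anything else falls back to index 0, which the sort above guarantees is the pipeline named "default". A client-side sketch; the port, pipeline name, model, and the /chat/completions sub-path mounted by create_pipeline are all assumptions for illustration (requires the reqwest with its "json" feature, tokio, and serde_json crates):

    // Hypothetical client; adjust the URL to wherever the hub is listening.
    #[tokio::main]
    async fn main() -> Result<(), reqwest::Error> {
        let resp = reqwest::Client::new()
            .post("http://localhost:3000/api/v1/chat/completions") // assumed port and sub-path
            .header("x-traceloop-pipeline", "my-pipeline") // omit to use the "default" pipeline
            .json(&serde_json::json!({
                "model": "gemini-2.0-flash-exp",
                "messages": [{ "role": "user", "content": "Hello!" }]
            }))
            .send()
            .await?;
        println!("{}", resp.text().await?);
        Ok(())
    }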
-------------------------------------------------------------------------------- /src/state.rs: --------------------------------------------------------------------------------
1 | use crate::ai_models::registry::ModelRegistry;
2 | use crate::config::models::Config;
3 | use crate::providers::registry::ProviderRegistry;
4 | use anyhow::Result;
5 | use std::sync::Arc;
6 | 
7 | #[derive(Clone)]
8 | pub struct AppState {
9 |     pub config: Arc<Config>,
10 |     pub provider_registry: Arc<ProviderRegistry>,
11 |     pub model_registry: Arc<ModelRegistry>,
12 | }
13 | 
14 | impl AppState {
15 |     pub fn new(config: Config) -> Result<Self> {
16 |         let provider_registry = Arc::new(ProviderRegistry::new(&config.providers)?);
17 |         let model_registry = Arc::new(ModelRegistry::new(
18 |             &config.models,
19 |             provider_registry.clone(),
20 |         )?);
21 | 
22 |         Ok(Self {
23 |             config: Arc::new(config),
24 |             provider_registry,
25 |             model_registry,
26 |         })
27 |     }
28 | }
29 | 
-------------------------------------------------------------------------------- /tests/cassettes/vertexai/chat_completions.json: --------------------------------------------------------------------------------
1 | [
2 |   {
3 |     "id": "chatcmpl-0f8acb38-a442-4ada-aa51-6ad8c5f91dc9",
4 |     "object": "chat.completion",
5 |     "created": 1747062082,
6 |     "model": "gemini-2.0-flash-exp",
7 |     "choices": [
8 |       {
9 |         "index": 0,
10 |         "message": {
11 |           "role": "assistant",
12 |           "content": "I am doing well, thank you for asking! As a large language model, I don't experience emotions or feelings like humans do, but I am functioning optimally and ready to assist you. How can I help you today?\n"
13 |         }
14 |       }
15 |     ],
16 |     "usage": {
17 |       "prompt_tokens": 0,
18 |       "completion_tokens": 0,
19 |       "total_tokens": 0
20 |     },
21 |     "system_fingerprint": null
22 |   },
23 |   {
24 |     "id": "chatcmpl-45af8b98-b836-4ba1-9d91-7a636f40da89",
25 |     "object": "chat.completion",
26 |     "created": 1747062191,
27 |     "model": "gemini-2.0-flash-exp",
28 |     "choices": [
29 |       {
30 |         "index": 0,
31 |         "message": {
32 |           "role": "assistant",
33 |           "content": "I am doing well, thank you for asking! As a large language model, I don't experience feelings or emotions in the same way humans do, but I am functioning optimally and ready to assist you with your requests. How can I help you today?\n"
34 |         }
35 |       }
36 |     ],
37 |     "usage": {
38 |       "prompt_tokens": 0,
39 |       "completion_tokens": 0,
40 |       "total_tokens": 0
41 |     },
42 |     "system_fingerprint": null
43 |   },
44 |   {
45 |     "id": "chatcmpl-5ead557a-f2ee-4062-a029-f302d6313b70",
46 |     "object": "chat.completion",
47 |     "created": 1747064691,
48 |     "model": "gemini-2.0-flash-exp",
49 |     "choices": [
50 |       {
51 |         "index": 0,
52 |         "message": {
53 |           "role": "assistant",
54 |           "content": "I am doing well, thank you for asking! As a large language model, I don't experience emotions or feelings like humans do, but I am functioning optimally and ready to assist you. How can I help you today?\n"
55 |         }
56 |       }
57 |     ],
58 |     "usage": {
59 |       "prompt_tokens": 0,
60 |       "completion_tokens": 0,
61 |       "total_tokens": 0
62 |     },
63 |     "system_fingerprint": null
64 |   },
65 |   {
66 |     "id": "chatcmpl-0f3f7206-21cb-4832-8936-3af45bbfb9a1",
67 |     "object": "chat.completion",
68 |     "created": 1747064863,
69 |     "model": "gemini-2.0-flash-exp",
70 |     "choices": [
71 |       {
72 |         "index": 0,
73 |         "message": {
74 |           "role": "assistant",
75 |           "content": "I am doing well, thank you for asking! As a large language model, I don't experience emotions or feelings in the same way humans do, but I am functioning optimally and ready to assist you. How can I help you today?\n"
76 |         }
77 |       }
78 |     ],
79 |     "usage": {
80 |       "prompt_tokens": 0,
81 |       "completion_tokens": 0,
82 |       "total_tokens": 0
83 |     },
84 |     "system_fingerprint": null
85 |   },
86 |   {
87 |     "id": "chatcmpl-f56cd4f4-f4a9-4eb1-8c35-cd61a7fb2e7d",
88 |     "object": "chat.completion",
89 |     "created": 1747065205,
90 |     "model": "gemini-2.0-flash-exp",
91 |     "choices": [
92 |       {
93 |         "index": 0,
94 |         "message": {
95 |           "role": "assistant",
96 |           "content": "I am doing well, thank you for asking! How are you today?\n"
97 |         }
98 |       }
99 |     ],
100 |     "usage": {
101 |       "prompt_tokens": 0,
102 |       "completion_tokens": 0,
103 |       "total_tokens": 0
104 |     },
105 |     "system_fingerprint": null
106 |   }
107 | ]
-------------------------------------------------------------------------------- /tests/cassettes/vertexai/chat_completions_with_tools.json: --------------------------------------------------------------------------------
1 | [
2 |   {
3 |     "id": "chatcmpl-de1cfeec-6d9f-475e-ae6d-7dead4cd260e",
4 |     "object": "chat.completion",
5 |     "created": 1747064690,
6 |     "model": "gemini-2.0-flash-exp",
7 |     "choices": [
8 |       {
9 |         "index": 0,
10 |         "message": {
11 |           "role": "assistant",
12 |           "tool_calls": [
13 |             {
14 |               "id": "call_3b2927f2-e7a3-4a18-b9ed-ce1efc23d3fb",
15 |               "function": {
16 |                 "arguments": "{\"location\":\"San Francisco\"}",
17 |                 "name": "get_weather"
18 |               },
19 |               "type": "function"
20 |             }
21 |           ]
22 |         }
23 |       }
24 |     ],
25 |     "usage": {
26 |       "prompt_tokens": 0,
27 |       "completion_tokens": 0,
28 |       "total_tokens": 0
29 |     },
30 |     "system_fingerprint": null
31 |   },
32 |   {
33 |     "id": "chatcmpl-2d15eaf9-36da-43b1-9e83-75c465d6938c",
34 |     "object": "chat.completion",
35 |     "created": 1747064802,
36 |     "model": "gemini-2.0-flash-exp",
37 |     "choices": [
38 |       {
39 |         "index": 0,
40 |         "message": {
41 |           "role": "assistant",
42 |           "tool_calls": [
43 |             {
44 |               "id": "call_8cf02207-a5b9-4cc5-b240-8932f380f8df",
45 |               "function": {
46 |                 "arguments": "{\"location\":\"San Francisco\"}",
47 |                 "name": "get_weather"
48 |               },
49 |               "type": "function"
50 |             }
51 |           ]
52 |         }
53 |       }
54 |     ],
55 |     "usage": {
56 |       "prompt_tokens": 0,
57 |       "completion_tokens": 0,
58 |       "total_tokens": 0
59 |     },
60 |     "system_fingerprint": null
61 |   },
62 |   {
63 |     "id": "chatcmpl-4df1d154-9b2c-40cc-882a-83f845429014",
64 |     "object": "chat.completion",
65 |     "created": 1747065205,
66 |     "model": "gemini-2.0-flash-exp",
67 |     "choices": [
68 |       {
69 |         "index": 0,
70 |         "message": {
71 |           "role": "assistant",
72 |           "tool_calls": [
73 |             {
74 |               "id": "call_e069a236-6ea2-4660-a189-2fe1a0439222",
75 |               "function": {
76 |                 "arguments": "{\"location\":\"San Francisco\"}",
77 |                 "name": "get_weather"
78 |               },
79 |               "type": "function"
80 |             }
81 |           ]
82 |         }
83 |       }
84 |     ],
85 |     "usage": {
86 |       "prompt_tokens": 0,
87 |       "completion_tokens": 0,
88 |       "total_tokens": 0
89 |     },
90 |     "system_fingerprint": null
91 |   }
92 | ]
-------------------------------------------------------------------------------- /tests/cassettes/vertexai/completions.json: --------------------------------------------------------------------------------
1 | [
2 |   {
3 |     "choices": [
4 |       {
5 |         "finish_reason": "STOP",
6 |         "index": 0,
7 |         "text": "Once upon a time, in a village nestled between rolling hills and a whispering river, lived a young girl named Elara. Elara wasn't like the other children. While they played with dolls and chased butterflies, Elara spent her days lost in the pages of dusty old books, her imagination soaring with tales of faraway lands and fantastical creatures. \n\nWhat would you like to happen next? \n"
8 |       }
9 |     ],
10 |     "created": 1735633515,
11 |     "id": "chatcmpl-99881c18-33ac-430b-b788-4a5cea4dce3e",
12 |     "model": "gemini-1.5-flash",
13 |     "object": "text_completion",
14 |     "usage": {
15 |       "completion_tokens": 0,
16 |       "prompt_tokens": 0,
17 |       "total_tokens": 0
18 |     }
19 |   }
20 | ]
--------------------------------------------------------------------------------