├── .cz.toml
├── .github
│   └── workflows
│       ├── ci.yml
│       ├── docker.yml
│       └── release.yml
├── .gitignore
├── CHANGELOG.md
├── Cargo.lock
├── Cargo.toml
├── Dockerfile
├── LICENSE
├── README.md
├── config-example.yaml
├── helm
│   ├── Chart.yaml
│   ├── templates
│   │   ├── _helpers.tpl
│   │   ├── configmap.yaml
│   │   ├── deployment.yaml
│   │   └── service.yaml
│   └── values.yaml
├── img
│   ├── logo-dark.png
│   └── logo-light.png
├── src
│   ├── ai_models
│   │   ├── instance.rs
│   │   ├── mod.rs
│   │   └── registry.rs
│   ├── config
│   │   ├── constants.rs
│   │   ├── lib.rs
│   │   ├── mod.rs
│   │   └── models.rs
│   ├── lib.rs
│   ├── main.rs
│   ├── models
│   │   ├── chat.rs
│   │   ├── completion.rs
│   │   ├── content.rs
│   │   ├── embeddings.rs
│   │   ├── logprob.rs
│   │   ├── mod.rs
│   │   ├── response_format.rs
│   │   ├── streaming.rs
│   │   ├── tool_calls.rs
│   │   ├── tool_choice.rs
│   │   ├── tool_definition.rs
│   │   └── usage.rs
│   ├── pipelines
│   │   ├── mod.rs
│   │   ├── otel.rs
│   │   └── pipeline.rs
│   ├── providers
│   │   ├── anthropic
│   │   │   ├── mod.rs
│   │   │   ├── models.rs
│   │   │   └── provider.rs
│   │   ├── azure
│   │   │   ├── mod.rs
│   │   │   └── provider.rs
│   │   ├── bedrock
│   │   │   ├── logs
│   │   │   │   ├── ai21_j2_mid_v1_completions.json
│   │   │   │   ├── ai21_jamba_1_5_mini_v1_0_chat_completions.json
│   │   │   │   ├── amazon_titan_embed_text_v2_0_embeddings.json
│   │   │   │   ├── anthropic_claude_3_haiku_20240307_v1_0_chat_completion.json
│   │   │   │   └── us_amazon_nova_lite_v1_0_chat_completion.json
│   │   │   ├── mod.rs
│   │   │   ├── models.rs
│   │   │   ├── provider.rs
│   │   │   └── test.rs
│   │   ├── mod.rs
│   │   ├── openai
│   │   │   ├── mod.rs
│   │   │   └── provider.rs
│   │   ├── provider.rs
│   │   ├── registry.rs
│   │   └── vertexai
│   │       ├── mod.rs
│   │       ├── models.rs
│   │       ├── provider.rs
│   │       └── tests.rs
│   ├── routes.rs
│   └── state.rs
└── tests
    └── cassettes
        └── vertexai
            ├── chat_completions.json
            ├── chat_completions_with_tools.json
            ├── completions.json
            └── embeddings.json
/.cz.toml:
--------------------------------------------------------------------------------
1 | [tool.commitizen]
2 | name = "cz_conventional_commits"
3 | tag_format = "v$version"
4 | major_version_zero = true
5 | update_changelog_on_bump = true
6 | version = "0.0.0"
7 | version_files = ["Cargo.toml"]
8 | version_provider = "cargo"
9 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: Build and test
2 |
3 | on:
4 | push:
5 | branches: ["main"]
6 | pull_request:
7 | branches: ["main"]
8 |
9 | env:
10 | CARGO_TERM_COLOR: always
11 |
12 | jobs:
13 | build:
14 | runs-on: ubuntu-latest
15 |
16 | steps:
17 | - uses: actions/checkout@v4
18 | - name: Check formatting
19 | run: cargo fmt --check
20 | - name: Check clippy
21 | run: cargo clippy -- -D warnings
22 | - name: Build
23 | run: cargo build --verbose
24 | - name: Run tests
25 | run: cargo test --verbose
26 |
--------------------------------------------------------------------------------
/.github/workflows/docker.yml:
--------------------------------------------------------------------------------
1 | name: Publish Docker Images
2 |
3 | on:
4 | push:
5 | tags:
6 | - "v*"
7 | workflow_dispatch:
8 |
9 | jobs:
10 | publish:
11 | name: Publish Docker Images
12 | runs-on: ubuntu-latest
13 | permissions:
14 | id-token: write
15 | contents: read
16 | packages: write
17 |
18 | steps:
19 | - name: Checkout
20 | uses: actions/checkout@v4
21 |
22 | - name: Set up QEMU
23 | uses: docker/setup-qemu-action@v3
24 | - name: Set up Docker Buildx
25 | uses: docker/setup-buildx-action@v3
26 |
27 | - name: Log in to the GitHub Container registry
28 | uses: docker/login-action@v2
29 | with:
30 | registry: ghcr.io
31 | username: ${{ github.actor }}
32 | password: ${{ secrets.GITHUB_TOKEN }}
33 |
34 | - name: Login to Docker Hub
35 | uses: docker/login-action@v2
36 | with:
37 | username: ${{ secrets.DOCKERHUB_USERNAME }}
38 | password: ${{ secrets.DOCKERHUB_TOKEN }}
39 |
40 | - name: Extract metadata for Docker
41 | id: docker-metadata
42 | uses: docker/metadata-action@v4
43 | with:
44 | images: |
45 | ghcr.io/traceloop/hub # GitHub
46 | traceloop/hub # Docker Hub
47 | tags: |
48 | type=sha
49 | type=semver,pattern={{version}}
50 | type=semver,pattern={{major}}.{{minor}}
51 | - name: Build and push Docker image
52 | uses: docker/build-push-action@v4
53 | with:
54 | context: .
55 | push: true
56 | tags: ${{ steps.docker-metadata.outputs.tags }}
57 | labels: ${{ steps.docker-metadata.outputs.labels }}
58 | platforms: |
59 | linux/amd64, linux/arm64
60 | deploy:
61 | name: Deploy to Traceloop
62 | runs-on: ubuntu-latest
63 | needs: publish
64 | steps:
65 | - name: Install Octopus CLI
66 | uses: OctopusDeploy/install-octopus-cli-action@v3
67 | with:
68 | version: "*"
69 |
70 | - name: Create Octopus Release
71 | env:
72 | OCTOPUS_API_KEY: ${{ secrets.OCTOPUS_API_KEY }}
73 | OCTOPUS_URL: ${{ secrets.OCTOPUS_URL }}
74 | OCTOPUS_SPACE: ${{ secrets.OCTOPUS_SPACE }}
75 | run: octopus release create --project hub --version=sha-${GITHUB_SHA::7} --packageVersion=sha-${GITHUB_SHA::7} --no-prompt
76 |
77 | - name: Deploy Octopus release
78 | env:
79 | OCTOPUS_API_KEY: ${{ secrets.OCTOPUS_API_KEY }}
80 | OCTOPUS_URL: ${{ secrets.OCTOPUS_URL }}
81 | OCTOPUS_SPACE: ${{ secrets.OCTOPUS_SPACE }}
82 | run: octopus release deploy --project hub --version=sha-${GITHUB_SHA::7} --environment Staging --no-prompt
83 | push-helm-chart:
84 | name: Push Helm Chart to Dockerhub
85 | runs-on: ubuntu-latest
86 | needs: publish
87 | permissions:
88 | contents: write
89 |
90 | steps:
91 | - name: Checkout
92 | uses: actions/checkout@v4
93 |
94 | - name: Get Chart Version
95 | id: chartVersion
96 | run: |
97 | CHART_VERSION=$(grep '^version:' helm/Chart.yaml | awk '{print $2}')
98 | echo "CHART_VERSION=$CHART_VERSION"
99 | echo "chart_version=$CHART_VERSION" >> $GITHUB_OUTPUT
100 |
101 | - name: Get Chart Name
102 | id: chartName
103 | run: |
104 | CHART_NAME=$(grep '^name:' helm/Chart.yaml | awk '{print $2}')
105 | echo "CHART_NAME=$CHART_NAME"
106 | echo "chart_name=$CHART_NAME" >> $GITHUB_OUTPUT
107 |
108 | - name: Login to Docker Hub as OCI registry
109 | run: |
110 | echo "${{ secrets.DOCKERHUB_TOKEN }}" | helm registry login registry-1.docker.io \
111 | --username "${{ secrets.DOCKERHUB_USERNAME }}" \
112 | --password-stdin
113 |
114 | - name: Package Helm chart
115 | run: |
116 | helm package helm/
117 |
118 | - name: Push Helm Chart
119 | run: |
120 | helm push "${{ steps.chartName.outputs.chart_name }}-${{ steps.chartVersion.outputs.chart_version }}.tgz" oci://registry-1.docker.io/traceloop
121 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Release a New Version
2 |
3 | on:
4 | workflow_dispatch:
5 |
6 | jobs:
7 | release:
8 | name: Bump Versions
9 | runs-on: ubuntu-latest
10 |
11 | steps:
12 | - uses: actions/checkout@v4
13 | with:
14 | persist-credentials: false
15 | fetch-depth: 0
16 |
17 | - id: cz
18 | name: Bump Version, Create Tag and Changelog
19 | uses: commitizen-tools/commitizen-action@master
20 | with:
21 | github_token: ${{ secrets.GH_ACCESS_TOKEN }}
22 | changelog_increment_filename: body.md
23 |
24 | - name: Create Release
25 | uses: softprops/action-gh-release@v2
26 | with:
27 | body_path: "body.md"
28 | tag_name: ${{ env.REVISION }}
29 | env:
30 | GITHUB_TOKEN: ${{ secrets.GH_ACCESS_TOKEN }}
31 |
32 | - name: Print Version
33 | run: echo "Bumped to version ${{ steps.cz.outputs.version }}"
34 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Generated by Cargo
2 | # will have compiled files and executables
3 | debug/
4 | target/
5 |
6 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
7 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
8 | Cargo.lock
9 |
10 | # These are backup files generated by rustfmt
11 | **/*.rs.bk
12 |
13 | # MSVC Windows builds of rustc generate these, which store debugging information
14 | *.pdb
15 |
16 | config.yaml
17 |
18 | # RustRover
19 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
20 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
21 | # and can be added to the global gitignore or merged into this file. For a more nuclear
22 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
23 | #.idea/
24 | .vscode/
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | ## v0.4.3 (2025-05-29)
2 |
3 | ### Fix
4 |
5 | - make general optional again (#43)
6 |
7 | ## v0.4.2 (2025-05-22)
8 |
9 | ### Fix
10 |
11 | - **tracing**: support disabling tracing of prompts and completions (#42)
12 |
13 | ## v0.4.1 (2025-05-20)
14 |
15 | ### Fix
16 |
17 | - **openai**: support custom base URL (#40)
18 | - **azure**: add support for custom base URL in AzureProvider endpoint (#41)
19 |
20 | ## v0.4.0 (2025-05-16)
21 |
22 | ### Feat
23 |
24 | - **provider**: add Google VertexAI support (#24)
25 | - support AWS bedrock base models (#25)
26 | - add max_completion_tokens to ChatCompletionRequest (#36)
27 | - support structured output (#33)
28 |
29 | ### Fix
30 |
31 | - replace eprintln with tracing info for API request errors in Azure and OpenAI providers (#37)
32 | - make optional json_schema field to ResponseFormat (#35)
33 |
34 | ## v0.3.0 (2025-03-04)
35 |
36 | ### Feat
37 |
38 | - add logprobs and top_logprobs options to ChatCompletionRequest (#27)
39 |
40 | ### Fix
41 |
42 | - **cd**: correct docker hub secrets (#31)
43 | - **azure**: embeddings structs improvement (#29)
44 | - add proper error logging for azure and openai calls (#18)
45 | - **anthropic**: separate system from messages (#17)
46 |
47 | ## v0.2.1 (2024-12-01)
48 |
49 | ### Fix
50 |
51 | - tool call support (#16)
52 | - restructure providers, separate request/response conversion (#15)
53 |
54 | ## v0.2.0 (2024-11-25)
55 |
56 | ### Feat
57 |
58 | - **openai**: support streaming (#10)
59 | - add prometheus metrics (#13)
 60 | - **cd**: deploy to traceloop on workflow dispatch (#11)
61 |
62 | ### Fix
63 |
64 | - config file path from env var instead of command argument (#12)
65 |
66 | ## v0.1.0 (2024-11-16)
67 |
68 | ### Feat
69 |
70 | - otel support (#7)
71 | - implement pipeline steering logic (#5)
72 | - dynamic pipeline routing (#4)
73 | - azure openai provider (#3)
74 | - initial completions and embeddings routes with openai and anthropic providers (#1)
75 |
76 | ### Fix
77 |
78 | - dockerfile and release pipeline (#2)
79 | - make anthropic work (#8)
80 | - cleanups; lint warnings fail CI (#9)
81 | - missing model name in response; 404 for model not found (#6)
82 |
--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "hub"
3 | version = "0.4.3"
4 | edition = "2021"
5 |
6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
7 |
8 | [dependencies]
9 | axum = "0.7"
10 | tokio = { version = "1.0", features = ["full"] }
11 | serde = { version = "1.0", features = ["derive"] }
12 | reqwest = { version = "0.12", features = ["json", "stream"] }
13 | serde_json = "1.0"
14 | tracing = "0.1"
15 | tracing-subscriber = "0.3"
16 | serde_yaml = "0.9"
17 | tower = { version = "0.5.1", features = ["full"] }
18 | anyhow = "1.0.92"
19 | chrono = "0.4.38"
20 | opentelemetry = { version = "0.27", default-features = false, features = [
21 | "trace",
22 | ] }
23 | opentelemetry_sdk = { version = "0.27", default-features = false, features = [
24 | "trace",
25 | "rt-tokio",
26 | ] }
27 | opentelemetry-semantic-conventions = { version = "0.27.0", features = [
28 | "semconv_experimental",
29 | ] }
30 | opentelemetry-otlp = { version = "0.27.0", features = [
31 | "http-proto",
32 | "reqwest-client",
33 | "reqwest-rustls",
34 | ] }
35 | axum-prometheus = "0.7.0"
36 | reqwest-streams = { version = "0.8.1", features = ["json"] }
37 | futures = "0.3.31"
38 | async-stream = "0.3.6"
39 | yup-oauth2 = "8.3.0"
40 | aws-sdk-bedrockruntime = "1.66.0"
41 | aws-config = "1.5.12"
42 | aws-credential-types = { version = "1.2.1", features = [
43 | "hardcoded-credentials",
44 | ] }
45 | http = "1.1.0"
46 | aws-smithy-runtime = { version = "1.7.6", features = ["test-util"] }
47 | aws-smithy-types = "1.2.11"
48 | aws-types = "1.3.3"
49 | uuid = { version = "1.16.0", features = ["v4"] }
50 |
51 | [dev-dependencies]
52 | surf = "2.3.2"
53 | surf-vcr = "0.2.0"
54 | wiremock = "0.5"
55 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM rust:1.82-bookworm AS builder
2 |
3 | WORKDIR /app
4 | COPY . .
5 | RUN cargo build --release --bin hub
6 |
7 | FROM debian:bookworm-slim AS runtime
8 | RUN apt-get update && apt-get install -y openssl ca-certificates
9 | WORKDIR /app
10 | COPY --from=builder /app/target/release/hub /usr/local/bin
11 | WORKDIR /etc
12 |
13 | ENV PORT=3000
14 | EXPOSE 3000
15 |
16 | ENTRYPOINT ["/usr/local/bin/hub"]
17 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | <!-- logo, badges, and header links omitted (HTML) -->
  2 |
  3 | Open-source, high-performance LLM gateway written in Rust. Connect to any LLM provider with a single API. Observability Included.
 44 |
 45 | Hub is a next-generation smart proxy for LLM applications. It centralizes control and tracing of all your LLM calls.
 46 | It's built in Rust, so it's fast and efficient, and it's completely open source and free to use.
47 |
48 | Built and maintained by Traceloop under the Apache 2.0 license.
49 |
50 | ## 🚀 Getting Started
51 |
 52 | Create a `config.yaml` file by copying `config-example.yaml` and set the correct values, following the [configuration](https://www.traceloop.com/docs/hub/configuration) instructions.
53 |
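For example, from the repository root:

```
cp config-example.yaml config.yaml
```
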
 54 | You can then run the hub using the Docker image:
55 |
56 | ```
57 | docker run --rm -p 3000:3000 -v $(pwd)/config.yaml:/etc/hub/config.yaml:ro -e CONFIG_FILE_PATH='/etc/hub/config.yaml' -t traceloop/hub
58 | ```
59 |
 60 | You can also run it locally. Make sure you have Rust v1.82 or above installed, then run:
61 |
62 | ```
63 | cargo run
64 | ```
65 |
 66 | Connect to the hub using the OpenAI SDK in any language by setting the base URL to:
67 |
68 | ```
69 | http://localhost:3000/api/v1
70 | ```
71 |
72 | For example, in Python:
73 |
 74 | ```python
75 | client = OpenAI(
76 | base_url="http://localhost:3000/api/v1",
77 | api_key=os.getenv("OPENAI_API_KEY"),
78 | # default_headers={"x-traceloop-pipeline": "azure-only"},
79 | )
80 | completion = client.chat.completions.create(
81 | model="claude-3-5-sonnet-20241022",
82 | messages=[{"role": "user", "content": "Tell me a joke about opentelemetry"}],
83 | max_tokens=1000,
84 | )
85 | ```
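
Streaming works through the same interface (the CHANGELOG notes streaming support for OpenAI and VertexAI). A minimal sketch, assuming a chat model such as `gpt-4` is configured in your `config.yaml`:

```python
stream = client.chat.completions.create(
    model="gpt-4",
    messages=[{"role": "user", "content": "Tell me a joke about opentelemetry"}],
    max_tokens=1000,
    stream=True,
)
for chunk in stream:
    # each chunk carries an incremental delta of the response
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
```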
86 |
87 | ## 🌱 Contributing
88 |
89 | Whether big or small, we love contributions ❤️ Check out our guide to see how to [get started](https://traceloop.com/docs/hub/contributing/overview).
90 |
91 | Not sure where to get started? You can:
92 |
93 | - [Book a free pairing session with one of our teammates](mailto:nir@traceloop.com?subject=Pairing%20session&body=I'd%20like%20to%20do%20a%20pairing%20session!)!
94 | - Join our Slack, and ask us any questions there.
95 |
96 | ## 💚 Community & Support
97 |
98 | - [Slack](https://traceloop.com/slack) (For live discussion with the community and the Traceloop team)
99 | - [GitHub Discussions](https://github.com/traceloop/hub/discussions) (For help with building and deeper conversations about features)
100 | - [GitHub Issues](https://github.com/traceloop/hub/issues) (For any bugs and errors you encounter using Hub)
101 | - [Twitter](https://twitter.com/traceloopdev) (Get news fast)
102 |
103 | # Hub
104 |
105 | A unified API interface for routing LLM requests to various providers.
106 |
107 | ## Supported Providers
108 |
109 | - OpenAI
110 | - Anthropic
111 | - Azure OpenAI
112 | - Google VertexAI (Gemini)
113 |
114 | ## Configuration
115 |
116 | See `config-example.yaml` for a complete configuration example.
117 |
118 | ### Provider Configuration
119 |
120 | #### OpenAI
121 |
122 | ```yaml
123 | providers:
124 | - key: openai
125 | type: openai
126 | api_key: ""
127 | ```
128 |
129 | #### Azure OpenAI
130 |
131 | ```yaml
132 | providers:
133 | - key: azure-openai
134 | type: azure
135 | api_key: ""
136 | resource_name: ""
137 | api_version: ""
138 | ```
139 |
140 | #### Google VertexAI (Gemini)
141 |
142 | ```yaml
143 | providers:
144 | - key: vertexai
145 | type: vertexai
146 | api_key: ""
147 | project_id: ""
148 | location: ""
149 | credentials_path: "/path/to/service-account.json"
150 | ```
151 |
152 | Authentication Methods:
153 | 1. API Key Authentication:
154 | - Set the `api_key` field with your GCP API key
155 | - Leave `credentials_path` empty
156 | 2. Service Account Authentication:
157 | - Set `credentials_path` to your service account JSON file path
158 | - Can also use `GOOGLE_APPLICATION_CREDENTIALS` environment variable
159 |    - Leave `api_key` empty when using service account auth (see the sketch below)
160 |
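A minimal sketch of both variants, based on the provider schema above (the provider keys are arbitrary and all values are placeholders):

```yaml
providers:
  # Variant 1: API key authentication
  - key: vertexai-api-key
    type: vertexai
    api_key: "<your-gcp-api-key>"
    project_id: "<your-project-id>"
    location: "us-central1"

  # Variant 2: service account authentication
  - key: vertexai-service-account
    type: vertexai
    api_key: "" # required field, unused with service account auth
    project_id: "<your-project-id>"
    location: "us-central1"
    credentials_path: "/path/to/service-account.json"
```
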
161 | Supported Features:
162 | - Chat Completions (with Gemini models)
163 | - Text Completions
164 | - Embeddings
165 | - Streaming Support
166 | - Function/Tool Calling
167 | - Multi-modal Inputs (images + text)
168 |
169 | Example Model Configuration:
170 | ```yaml
171 | models:
172 | # Chat and Completion model
173 | - key: gemini-1.5-flash
174 | type: gemini-1.5-flash
175 | provider: vertexai
176 |
177 | # Embeddings model
178 | - key: textembedding-gecko
179 | type: textembedding-gecko
180 | provider: vertexai
181 | ```
182 |
183 | Example Usage with OpenAI SDK:
184 | ```python
185 | from openai import OpenAI
186 |
187 | client = OpenAI(
188 | base_url="http://localhost:3000/api/v1",
189 | api_key="not-needed-for-vertexai"
190 | )
191 |
192 | # Chat completion
193 | response = client.chat.completions.create(
194 | model="gemini-1.5-flash",
195 | messages=[{"role": "user", "content": "Tell me a joke"}]
196 | )
197 |
198 | # Embeddings
199 | response = client.embeddings.create(
200 | model="textembedding-gecko",
201 | input="Sample text for embedding"
202 | )
203 | ```
204 |
205 | ### Pipeline Configuration
206 |
207 | ```yaml
208 | pipelines:
209 | - name: default
210 | type: chat
211 | plugins:
212 | - model-router:
213 | models:
214 |             - gemini-1.5-flash
215 | ```
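
To steer a request to a specific pipeline, pass the `x-traceloop-pipeline` header shown (commented out) in the Python example above. A minimal sketch, assuming a pipeline named `default` exists in your configuration:

```python
import os

from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:3000/api/v1",
    api_key=os.getenv("OPENAI_API_KEY"),
    # route every request from this client through the "default" pipeline
    default_headers={"x-traceloop-pipeline": "default"},
)
```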
216 |
217 | ## Development
218 |
219 | ### Running Tests
220 |
221 | The test suite uses recorded HTTP interactions (cassettes) to make tests reproducible without requiring actual API credentials.
222 |
223 | To run tests:
224 | ```bash
225 | cargo test
226 | ```
227 |
228 | To record new test cassettes:
229 | 1. Set up your API credentials:
230 | - For service account auth: Set `VERTEXAI_CREDENTIALS_PATH` to your service account key file path
231 | - For API key auth: Use the test with API key (currently marked as ignored)
232 | 2. Delete the existing cassette files in `tests/cassettes/vertexai/`
233 | 3. Run the tests with recording enabled:
234 | ```bash
235 | RECORD_MODE=1 cargo test
236 | ```
237 |
238 | Additional test configurations:
239 | - `RETRY_DELAY`: Set the delay in seconds between retries when hitting quota limits (default: 60)
240 | - Tests automatically retry up to 3 times when hitting quota limits
241 |
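For example, a hypothetical re-recording run that combines these variables (paths and values are placeholders):

```bash
VERTEXAI_CREDENTIALS_PATH=/path/to/service-account.json \
RETRY_DELAY=30 \
RECORD_MODE=1 \
cargo test
```
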
242 | Note: Some tests may be marked as `#[ignore]` if they require specific credentials or are not ready for general use.
243 |
244 | ## License
245 |
246 | This project is licensed under the Apache 2.0 license. See the LICENSE file.
247 |
--------------------------------------------------------------------------------
/config-example.yaml:
--------------------------------------------------------------------------------
1 | general:
2 | trace_content_enabled: true # Optional, defaults to true, set to false to disable tracing of request and response content
3 | providers:
4 | # Azure OpenAI configuration
5 | - key: azure-openai
6 | type: azure
7 | api_key: ""
8 | resource_name: ""
9 | api_version: ""
10 |
11 | # OpenAI configuration
12 | - key: openai
13 | type: openai
14 | api_key: ""
 15 |     base_url: "" # Optional; defaults to https://api.openai.com/v1 if not provided
16 | - key: bedrock
17 | type: bedrock
18 | api_key: ""# Not used for AWS Bedrock
19 | region: "" # like "us-east-1"
20 | inference_profile_id: "" # like "us"
21 | AWS_ACCESS_KEY_ID: ""
22 | AWS_SECRET_ACCESS_KEY: ""
23 | AWS_SESSION_TOKEN: "" # Optional
24 |
25 | # Vertex AI configuration
26 | # Uses service account authentication
27 | - key: vertexai
28 | type: vertexai
29 | api_key: "" # Required field but not used with service account auth
30 | project_id: ""
31 | location: "" # e.g., us-central1
32 | credentials_path: "/path/to/service-account.json" # Path to your service account key file
33 |
34 | models:
35 | # OpenAI Models
36 | - key: gpt-4
37 | type: gpt-4
38 | provider: openai
39 | - key: gpt-3.5-turbo
40 | type: gpt-3.5-turbo
41 | provider: openai
42 |
43 | # Azure OpenAI Models
44 | - key: gpt-4-azure
45 | type: gpt-4
46 | provider: azure-openai
47 | deployment: ""
48 | - key: gpt-35-turbo-azure
49 | type: gpt-35-turbo
50 | provider: azure-openai
51 | deployment: ""
52 |
53 | # Bedrock Models
54 | - key: bedrock-model
 55 |     # Some models are region-specific; it's a good idea to get the ARN from the cross-region inference tab
 56 |     type: "<model-id or inference profile ARN or inference profile ID>"
57 | provider: bedrock
58 | model_provider: "anthropic" # can be: ai21, titan, anthropic
59 | model_version: "v2:0" # optional, defaults to "v1:0"
60 |
61 | # Vertex AI Models
62 | # Chat and Completion model
63 | - key: gemini-1.5-flash
64 | type: gemini-1.5-flash # Supports both chat and completion endpoints
65 | provider: vertexai
66 | # Embeddings model
67 | - key: textembedding-gecko
68 | type: textembedding-gecko # Supports embeddings endpoint
69 | provider: vertexai
70 | deployment: ""
71 |
72 | pipelines:
73 | # Default pipeline for chat completions
74 | - name: default
75 | type: chat
76 | plugins:
77 | - logging:
78 | level: info # Supported levels: debug, info, warning, error
79 | - tracing: # Optional tracing configuration
80 | endpoint: "https://api.traceloop.com/v1/traces"
81 | api_key: ""
82 | - model-router:
83 | models: # List the models you want to use for chat
84 | - gpt-4
85 | - gpt-4-azure
86 | - gemini-1.5-flash
87 |
88 | # Pipeline for text completions
89 | - name: completions
90 | type: completion
91 | plugins:
92 | - model-router:
93 | models: # List the models you want to use for completions
94 | - gpt-3.5-turbo
95 | - gpt-35-turbo-azure
96 | - gemini-1.5-flash
97 |
98 | # Pipeline for embeddings
99 | - name: embeddings
100 | type: embeddings
101 | plugins:
102 | - model-router:
103 | models: # List the models you want to use for embeddings
104 | - textembedding-gecko
105 |
--------------------------------------------------------------------------------
/helm/Chart.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v2
2 | name: helm-hub
3 | description: A Helm chart for the Hub application
4 | type: application
5 | version: 0.3.0
6 |
--------------------------------------------------------------------------------
/helm/templates/_helpers.tpl:
--------------------------------------------------------------------------------
1 | {{/*
2 | Expand the name of the chart.
3 | */}}
4 | {{- define "app.name" -}}
5 | {{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }}
6 | {{- end }}
7 |
8 | {{/*
9 | Create a default fully qualified app name.
10 | We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec).
11 | If release name contains chart name it will be used as a full name.
12 | */}}
13 | {{- define "app.fullname" -}}
14 | {{- if .Values.fullnameOverride }}
15 | {{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }}
16 | {{- else }}
17 | {{- $name := default .Chart.Name .Values.nameOverride }}
18 | {{- if contains $name .Release.Name }}
19 | {{- .Release.Name | trunc 63 | trimSuffix "-" }}
20 | {{- else }}
21 | {{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }}
22 | {{- end }}
23 | {{- end }}
24 | {{- end }}
25 |
26 | {{/*
27 | Create chart name and version as used by the chart label.
28 | */}}
29 | {{- define "app.chart" -}}
30 | {{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }}
31 | {{- end }}
32 |
33 | {{/*
34 | Common labels
35 | */}}
36 | {{- define "app.labels" -}}
37 | helm.sh/chart: {{ include "app.chart" . }}
38 | {{ include "app.selectorLabels" . }}
39 | {{- if .Chart.AppVersion }}
40 | app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
41 | {{- end }}
42 | app.kubernetes.io/managed-by: {{ .Release.Service }}
43 | {{- end }}
44 |
45 | {{/*
46 | Selector labels
47 | */}}
48 | {{- define "app.selectorLabels" -}}
49 | app.kubernetes.io/name: {{ include "app.name" . }}
50 | app.kubernetes.io/instance: {{ .Release.Name }}
51 | {{- end }}
52 |
--------------------------------------------------------------------------------
/helm/templates/configmap.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: ConfigMap
3 | metadata:
4 | name: {{ include "app.fullname" . }}-config
5 | labels:
6 | {{- include "app.labels" . | nindent 4 }}
7 | data:
8 | config.yaml: |
9 | providers:
10 | {{- range .Values.config.providers }}
11 | - {{ toYaml . | nindent 8 | trim }}
12 | {{- end }}
13 | models:
14 | {{- range .Values.config.models }}
15 | - key: {{ .key }}
16 | type: {{ .type }}
17 | provider: {{ .provider }}
18 | {{- if .deployment }}
19 | deployment: "{{ .deployment }}"
20 | {{- end }}
21 | {{- end }}
22 |
23 | pipelines:
24 | {{- range .Values.config.pipelines }}
25 | - name: {{ .name }}
26 | type: {{ .type }}
27 | plugins:
28 | {{- range .plugins }}
29 | - {{ toYaml . | nindent 12 | trim }}
30 | {{- end }}
31 | {{- end }}
32 |
--------------------------------------------------------------------------------
/helm/templates/deployment.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: apps/v1
2 | kind: Deployment
3 | metadata:
4 | name: {{ include "app.fullname" . }}
5 | labels:
6 | {{- include "app.labels" . | nindent 4 }}
7 | spec:
8 | replicas: {{ .Values.replicaCount }}
9 | selector:
10 | matchLabels:
11 | {{- include "app.selectorLabels" . | nindent 6 }}
12 | template:
13 | metadata:
14 | labels:
15 | {{- include "app.selectorLabels" . | nindent 8 }}
16 | {{- with .Values.podAnnotations }}
17 | annotations:
18 | {{- toYaml . | nindent 8 }}
19 | {{- end }}
20 | spec:
21 | securityContext:
22 | {{- toYaml .Values.podSecurityContext | nindent 8 }}
23 | initContainers:
24 | - name: config-init
25 | image: busybox:latest
26 | command: ["sh", "-c"]
27 | args:
28 | - |
29 | cp /config-template/config.yaml /config/config.yaml
30 | # Dynamic substitution for all API keys
31 | {{- range $key, $value := .Values.secrets.keyMapping }}
32 | {{- $upperKey := upper (regexReplaceAll "ApiKey$" $key "") }}
33 | sed -i "s/{{ $upperKey }}_API_KEY_PLACEHOLDER/${{ $upperKey }}_API_KEY/g" /config/config.yaml
34 | {{- end }}
35 | env:
36 | {{- range $key, $value := .Values.secrets.keyMapping }}
37 | {{- $upperKey := upper (regexReplaceAll "ApiKey$" $key "") }}
38 | - name: {{ $upperKey }}_API_KEY
39 | valueFrom:
40 | secretKeyRef:
41 | name: {{ $.Values.secrets.existingSecretName }}
42 | key: {{ $value }}
43 | optional: {{ ne $key "openaiApiKey" }}
44 | {{- end }}
45 | volumeMounts:
46 | - name: config-template
47 | mountPath: /config-template
48 | - name: config-volume
49 | mountPath: /config
50 | containers:
51 | - name: {{ .Chart.Name }}
52 | securityContext:
53 | {{- toYaml .Values.securityContext | nindent 12 }}
54 | image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.Version }}"
55 | imagePullPolicy: {{ .Values.image.pullPolicy }}
56 | workingDir: /app
57 | env:
58 | - name: PORT
59 | value: {{ .Values.service.port | quote }}
60 | ports:
61 | - name: http
62 | containerPort: {{ .Values.service.port }}
63 | protocol: TCP
64 | livenessProbe:
65 | httpGet:
66 | path: /health
67 | port: http
68 | initialDelaySeconds: {{ .Values.probes.liveness.initialDelaySeconds | default 30 }}
69 | periodSeconds: {{ .Values.probes.liveness.periodSeconds | default 10 }}
70 | timeoutSeconds: {{ .Values.probes.liveness.timeoutSeconds | default 5 }}
71 | successThreshold: {{ .Values.probes.liveness.successThreshold | default 1 }}
72 | failureThreshold: {{ .Values.probes.liveness.failureThreshold | default 3 }}
73 | readinessProbe:
74 | httpGet:
75 | path: /health
76 | port: http
77 | initialDelaySeconds: {{ .Values.probes.readiness.initialDelaySeconds | default 10 }}
78 | periodSeconds: {{ .Values.probes.readiness.periodSeconds | default 10 }}
79 | timeoutSeconds: {{ .Values.probes.readiness.timeoutSeconds | default 5 }}
80 | successThreshold: {{ .Values.probes.readiness.successThreshold | default 1 }}
81 | failureThreshold: {{ .Values.probes.readiness.failureThreshold | default 3 }}
82 | volumeMounts:
83 | - name: config-volume
84 | mountPath: /app/config.yaml
85 | subPath: config.yaml
86 | resources:
87 | {{- toYaml .Values.resources | nindent 12 }}
88 | volumes:
89 | - name: config-template
90 | configMap:
91 | name: {{ include "app.fullname" . }}-config
92 | - name: config-volume
93 | emptyDir: {}
94 | {{- with .Values.nodeSelector }}
95 | nodeSelector:
96 | {{- toYaml . | nindent 8 }}
97 | {{- end }}
98 | {{- with .Values.affinity }}
99 | affinity:
100 | {{- toYaml . | nindent 8 }}
101 | {{- end }}
102 | {{- with .Values.tolerations }}
103 | tolerations:
104 | {{- toYaml . | nindent 8 }}
105 | {{- end }}
106 |
--------------------------------------------------------------------------------
/helm/templates/service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | name: {{ include "app.fullname" . }}
5 | labels:
6 | {{- include "app.labels" . | nindent 4 }}
7 | spec:
8 | type: {{ .Values.service.type }}
9 | ports:
10 | - port: {{ .Values.service.port }}
11 | targetPort: http
12 | protocol: TCP
13 | name: http
14 | selector:
15 | {{- include "app.selectorLabels" . | nindent 4 }}
16 |
--------------------------------------------------------------------------------
/helm/values.yaml:
--------------------------------------------------------------------------------
1 | # Default values for hub.
2 | # This is a YAML-formatted file.
3 | # Declare variables to be passed into your templates.
4 |
5 | replicaCount: 1
6 |
7 | image:
8 | repository: docker.io/traceloop/hub
9 | pullPolicy: IfNotPresent
 10 |   # Overrides the image tag whose default is the chart version.
11 | # tag: "latest"
12 | # Optionally specify an array of imagePullSecrets.
13 | # imagePullSecrets:
14 | # - name: myRegistryKeySecretName
15 |
16 | nameOverride: ""
17 | fullnameOverride: ""
18 |
19 | podAnnotations: {}
20 |
21 | podSecurityContext: {}
22 | # fsGroup: 2000
23 |
24 | securityContext: {}
25 | # capabilities:
26 | # drop:
27 | # - ALL
28 | # readOnlyRootFilesystem: true
29 | # runAsNonRoot: true
30 | # runAsUser: 1000
31 |
32 | service:
33 | type: ClusterIP
34 | port: 3100
35 |
36 | # Kubernetes probe configuration
37 | probes:
38 | liveness:
39 | initialDelaySeconds: 30
40 | periodSeconds: 10
41 | timeoutSeconds: 5
42 | successThreshold: 1
43 | failureThreshold: 3
44 | readiness:
45 | initialDelaySeconds: 10
46 | periodSeconds: 10
47 | timeoutSeconds: 5
48 | successThreshold: 1
49 | failureThreshold: 3
50 |
51 | resources:
52 | limits:
53 | cpu: 300m
54 | memory: 400Mi
55 | requests:
56 | cpu: 100m
57 | memory: 200Mi
58 |
59 | nodeSelector: {}
60 |
61 | tolerations: []
62 |
63 | affinity: {}
64 |
65 | # Environment variables to pass to the hub container
66 | env: []
67 | # - name: DEBUG
68 | # value: "true"
69 |
70 | # Configuration for the hub
71 | config:
72 | # Define providers that will be available in the hub
73 | providers:
74 | - key: openai
75 | # Configuration for the OpenAI provider
76 | # Secret management configuration
77 | secrets:
78 | # Name of the existing Kubernetes secret containing API keys
79 | existingSecretName: "llm-api-keys"
80 | # Mapping of API keys in the secret
81 | keyMapping:
82 | # The key in the secret for OpenAI API key
83 | openaiApiKey: "OPENAI_API_KEY"
84 | # Uncomment to use Azure OpenAI API key
85 | # azureApiKey: "AZURE_API_KEY"
86 | # Uncomment to use Anthropic API key
87 | # anthropicApiKey: "ANTHROPIC_API_KEY"
88 | # Uncomment to use Traceloop API key for tracing
89 | # traceloopApiKey: "TRACELOOP_API_KEY"
90 | # Add mappings for any other provider API keys as needed
91 |
--------------------------------------------------------------------------------
/img/logo-dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/traceloop/hub/0beb6ad838abe80d32c6f924598f9be2b95747cf/img/logo-dark.png
--------------------------------------------------------------------------------
/img/logo-light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/traceloop/hub/0beb6ad838abe80d32c6f924598f9be2b95747cf/img/logo-light.png
--------------------------------------------------------------------------------
/src/ai_models/instance.rs:
--------------------------------------------------------------------------------
1 | use crate::config::models::ModelConfig;
2 | use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse};
3 | use crate::models::completion::{CompletionRequest, CompletionResponse};
4 | use crate::models::embeddings::{EmbeddingsRequest, EmbeddingsResponse};
5 | use crate::providers::provider::Provider;
6 | use axum::http::StatusCode;
7 | use std::sync::Arc;
8 |
9 | pub struct ModelInstance {
10 | pub name: String,
11 | pub model_type: String,
12 |     pub provider: Arc<dyn Provider>,
13 | pub config: ModelConfig,
14 | }
15 |
16 | impl ModelInstance {
17 | pub async fn chat_completions(
18 | &self,
19 | mut payload: ChatCompletionRequest,
20 |     ) -> Result<ChatCompletionResponse, StatusCode> {
21 | payload.model = self.model_type.clone();
22 | self.provider.chat_completions(payload, &self.config).await
23 | }
24 |
25 | pub async fn completions(
26 | &self,
27 | mut payload: CompletionRequest,
28 |     ) -> Result<CompletionResponse, StatusCode> {
29 | payload.model = self.model_type.clone();
30 |
31 | self.provider.completions(payload, &self.config).await
32 | }
33 |
34 | pub async fn embeddings(
35 | &self,
36 | mut payload: EmbeddingsRequest,
37 |     ) -> Result<EmbeddingsResponse, StatusCode> {
38 | payload.model = self.model_type.clone();
39 | self.provider.embeddings(payload, &self.config).await
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/src/ai_models/mod.rs:
--------------------------------------------------------------------------------
1 | pub mod instance;
2 | pub mod registry;
3 |
--------------------------------------------------------------------------------
/src/ai_models/registry.rs:
--------------------------------------------------------------------------------
1 | use anyhow::Result;
2 | use std::collections::HashMap;
3 | use std::sync::Arc;
4 |
5 | use super::instance::ModelInstance;
6 | use crate::config::models::ModelConfig;
7 | use crate::providers::registry::ProviderRegistry;
8 |
9 | #[derive(Clone)]
10 | pub struct ModelRegistry {
11 |     models: HashMap<String, Arc<ModelInstance>>,
12 | }
13 |
14 | impl ModelRegistry {
15 | pub fn new(
16 | model_configs: &[ModelConfig],
17 |         provider_registry: Arc<ProviderRegistry>,
18 |     ) -> Result<Self> {
19 | let mut models = HashMap::new();
20 |
21 | for config in model_configs {
22 | if let Some(provider) = provider_registry.get(&config.provider) {
23 | let model = Arc::new(ModelInstance {
24 | name: config.key.clone(),
25 | model_type: config.r#type.clone(),
26 | provider,
27 | config: config.clone(),
28 | });
29 |
30 | models.insert(config.key.clone(), model);
31 | }
32 | }
33 |
34 | Ok(Self { models })
35 | }
36 |
37 |     pub fn get(&self, name: &str) -> Option<Arc<ModelInstance>> {
38 | self.models.get(name).cloned()
39 | }
40 | }
41 |
--------------------------------------------------------------------------------
/src/config/constants.rs:
--------------------------------------------------------------------------------
1 | use std::env;
2 |
3 | pub fn stream_buffer_size_bytes() -> usize {
4 | env::var("STREAM_BUFFER_SIZE_BYTES")
5 | .unwrap_or_else(|_| "1000".to_string())
6 | .parse()
7 | .unwrap_or(1000)
8 | }
9 |
10 | pub fn default_max_tokens() -> u32 {
11 | env::var("DEFAULT_MAX_TOKENS")
12 | .unwrap_or_else(|_| "4096".to_string())
13 | .parse()
14 | .unwrap_or(4096)
15 | }
16 |
17 | // Required field for the TitanEmbeddingRequest
18 | pub fn default_embedding_dimension() -> u32 {
19 | env::var("DEFAULT_EMBEDDING_DIMENSION")
20 | .unwrap_or_else(|_| "512".to_string())
21 | .parse()
22 | .unwrap_or(512)
23 | }
24 | // Required field for the TitanEmbeddingRequest
25 | pub fn default_embedding_normalize() -> bool {
26 | env::var("DEFAULT_EMBEDDING_NORMALIZE")
27 | .unwrap_or_else(|_| "true".to_string())
28 | .parse()
29 | .unwrap_or(true)
30 | }
31 |
--------------------------------------------------------------------------------
/src/config/lib.rs:
--------------------------------------------------------------------------------
1 | use std::sync::OnceLock;
2 |
3 | use crate::config::models::Config;
4 |
 5 | pub static TRACE_CONTENT_ENABLED: OnceLock<bool> = OnceLock::new();
6 |
 7 | pub fn load_config(path: &str) -> Result<Config, Box<dyn std::error::Error>> {
8 | let contents = std::fs::read_to_string(path)?;
9 | let config: Config = serde_yaml::from_str(&contents)?;
10 | TRACE_CONTENT_ENABLED
11 | .set(
12 | config
13 | .general
14 | .as_ref()
15 | .is_none_or(|g| g.trace_content_enabled),
16 | )
17 | .expect("Failed to set trace content enabled flag");
18 | Ok(config)
19 | }
20 |
21 | pub fn get_trace_content_enabled() -> bool {
22 | *TRACE_CONTENT_ENABLED.get_or_init(|| true)
23 | }
24 |
--------------------------------------------------------------------------------
/src/config/mod.rs:
--------------------------------------------------------------------------------
1 | pub mod constants;
2 | pub mod lib;
3 | pub mod models;
4 |
--------------------------------------------------------------------------------
/src/config/models.rs:
--------------------------------------------------------------------------------
1 | use std::collections::HashMap;
2 |
3 | use serde::{Deserialize, Serialize};
4 |
5 | #[derive(Debug, Deserialize, Serialize, Clone)]
6 | pub struct Config {
 7 |     pub general: Option<General>,
 8 |     pub providers: Vec<Provider>,
 9 |     pub models: Vec<ModelConfig>,
10 |     pub pipelines: Vec<Pipeline>,
11 | }
12 |
13 | #[derive(Debug, Deserialize, Serialize, Clone, Default)]
14 | pub struct General {
15 | #[serde(default = "default_trace_content_enabled")]
16 | pub trace_content_enabled: bool,
17 | }
18 |
19 | #[derive(Debug, Deserialize, Serialize, Clone, Default)]
20 | pub struct Provider {
21 | pub key: String,
22 | pub r#type: String,
23 | #[serde(default = "no_api_key")]
24 | pub api_key: String,
25 | #[serde(flatten)]
26 |     pub params: HashMap<String, String>,
27 | }
28 |
29 | #[derive(Debug, Deserialize, Serialize, Clone, Default)]
30 | pub struct ModelConfig {
31 | pub key: String,
32 | pub r#type: String,
33 | pub provider: String,
34 | #[serde(flatten)]
35 |     pub params: HashMap<String, String>,
36 | }
37 |
38 | #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
39 | #[serde(rename_all = "lowercase")]
40 | pub enum PipelineType {
41 | Chat,
42 | Completion,
43 | Embeddings,
44 | }
45 |
46 | #[derive(Debug, Deserialize, Serialize, Clone)]
47 | pub struct Pipeline {
48 | pub name: String,
49 | pub r#type: PipelineType,
50 | #[serde(with = "serde_yaml::with::singleton_map_recursive")]
51 |     pub plugins: Vec<PluginConfig>,
52 | }
53 |
54 | #[derive(Debug, Deserialize, Serialize, Clone)]
55 | #[serde(rename_all = "kebab-case")]
56 | pub enum PluginConfig {
57 | Logging {
58 | #[serde(default = "default_log_level")]
59 | level: String,
60 | },
61 | Tracing {
62 | endpoint: String,
63 | api_key: String,
64 | },
65 | ModelRouter {
66 |         models: Vec<String>,
67 | },
68 | }
69 |
70 | fn default_trace_content_enabled() -> bool {
71 | true
72 | }
73 |
74 | fn default_log_level() -> String {
75 | "warning".to_string()
76 | }
77 |
78 | fn no_api_key() -> String {
79 | "".to_string()
80 | }
81 |
--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
1 | pub mod ai_models;
2 | pub mod config;
3 | pub mod models;
4 | pub mod pipelines;
5 | pub mod providers;
6 | pub mod routes;
7 | pub mod state;
8 |
9 | pub use axum;
10 | pub use reqwest;
11 | pub use serde;
12 | pub use serde_json;
13 |
--------------------------------------------------------------------------------
/src/main.rs:
--------------------------------------------------------------------------------
1 | use hub::{config::lib::load_config, routes, state::AppState};
2 | use std::sync::Arc;
3 | use tracing::info;
4 |
5 | #[tokio::main]
6 | async fn main() -> Result<(), anyhow::Error> {
7 | tracing_subscriber::fmt::init();
8 |
9 | info!("Starting Traceloop Hub...");
10 |
11 | let config_path = std::env::var("CONFIG_FILE_PATH").unwrap_or("config.yaml".to_string());
12 |
13 | info!("Loading configuration from {}", config_path);
14 | let config = load_config(&config_path)
15 | .map_err(|e| anyhow::anyhow!("Failed to load configuration: {}", e))?;
16 | let state = Arc::new(
17 | AppState::new(config).map_err(|e| anyhow::anyhow!("Failed to create app state: {}", e))?,
18 | );
19 | let app = routes::create_router(state);
20 | let port: String = std::env::var("PORT").unwrap_or("3000".to_string());
21 | let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", port))
22 | .await
23 | .unwrap();
24 |
25 | info!("Server is running on port {}", port);
26 | axum::serve(listener, app).await.unwrap();
27 |
28 | Ok(())
29 | }
30 |
--------------------------------------------------------------------------------
/src/models/chat.rs:
--------------------------------------------------------------------------------
1 | use futures::stream::BoxStream;
2 | use reqwest_streams::error::StreamBodyError;
3 | use serde::{Deserialize, Serialize};
4 | use std::collections::HashMap;
5 |
6 | use super::content::ChatCompletionMessage;
7 | use super::logprob::LogProbs;
8 | use super::response_format::ResponseFormat;
9 | use super::streaming::ChatCompletionChunk;
10 | use super::tool_choice::ToolChoice;
11 | use super::tool_definition::ToolDefinition;
12 | use super::usage::Usage;
13 |
14 | #[derive(Deserialize, Serialize, Clone)]
15 | pub struct ChatCompletionRequest {
16 |     pub model: String,
17 |     pub messages: Vec<ChatCompletionMessage>,
18 |     #[serde(skip_serializing_if = "Option::is_none")]
19 |     pub temperature: Option<f32>,
20 |     #[serde(skip_serializing_if = "Option::is_none")]
21 |     pub top_p: Option<f32>,
22 |     #[serde(skip_serializing_if = "Option::is_none")]
23 |     pub n: Option<u32>,
24 |     #[serde(skip_serializing_if = "Option::is_none")]
25 |     pub stream: Option<bool>,
26 |     #[serde(skip_serializing_if = "Option::is_none")]
27 |     pub stop: Option<Vec<String>>,
28 |     #[serde(skip_serializing_if = "Option::is_none")]
29 |     pub max_tokens: Option<u32>,
30 |     #[serde(skip_serializing_if = "Option::is_none")]
31 |     pub max_completion_tokens: Option<u32>,
32 |     #[serde(skip_serializing_if = "Option::is_none")]
33 |     pub parallel_tool_calls: Option<bool>,
34 |     #[serde(skip_serializing_if = "Option::is_none")]
35 |     pub presence_penalty: Option<f32>,
36 |     #[serde(skip_serializing_if = "Option::is_none")]
37 |     pub frequency_penalty: Option<f32>,
38 |     #[serde(skip_serializing_if = "Option::is_none")]
39 |     pub logit_bias: Option<HashMap<String, f32>>,
40 |     #[serde(skip_serializing_if = "Option::is_none")]
41 |     pub tool_choice: Option<ToolChoice>,
42 |     #[serde(skip_serializing_if = "Option::is_none")]
43 |     pub tools: Option<Vec<ToolDefinition>>,
44 |     #[serde(skip_serializing_if = "Option::is_none")]
45 |     pub user: Option<String>,
46 |     #[serde(skip_serializing_if = "Option::is_none")]
47 |     pub logprobs: Option<bool>,
48 |     #[serde(skip_serializing_if = "Option::is_none")]
49 |     pub top_logprobs: Option<u32>,
50 |     #[serde(skip_serializing_if = "Option::is_none")]
51 |     pub response_format: Option<ResponseFormat>,
52 | }
53 |
54 | pub enum ChatCompletionResponse {
55 |     Stream(BoxStream<'static, Result<ChatCompletionChunk, StreamBodyError>>),
56 | NonStream(ChatCompletion),
57 | }
58 |
59 | #[derive(Deserialize, Serialize, Clone)]
60 | pub struct ChatCompletion {
61 |     pub id: String,
62 |     #[serde(skip_serializing_if = "Option::is_none")]
63 |     pub object: Option<String>,
64 |     #[serde(skip_serializing_if = "Option::is_none")]
65 |     pub created: Option<u64>,
66 |     pub model: String,
67 |     pub choices: Vec<ChatCompletionChoice>,
68 |     pub usage: Usage,
69 |     pub system_fingerprint: Option<String>,
70 | }
71 |
72 | #[derive(Deserialize, Serialize, Clone)]
73 | pub struct ChatCompletionChoice {
74 |     pub index: u32,
75 |     pub message: ChatCompletionMessage,
76 |     #[serde(skip_serializing_if = "Option::is_none")]
77 |     pub finish_reason: Option<String>,
78 |     #[serde(skip_serializing_if = "Option::is_none")]
79 |     pub logprobs: Option<LogProbs>,
80 | }
81 |
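A note on the request model above: every optional field is tagged `#[serde(skip_serializing_if = "Option::is_none")]`, so when the gateway re-serializes a client request for an upstream provider, unset fields are dropped entirely rather than sent as explicit `null`s. A minimal standalone sketch of that effect (the struct and model name are illustrative, not part of the crate; serde and serde_json are assumed available, as they are in this workspace):

use serde::Serialize;

#[derive(Serialize)]
struct SparseRequest {
    model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
}

fn main() {
    let req = SparseRequest { model: "gpt-4o".to_string(), temperature: None };
    // The unset option disappears from the payload instead of becoming null.
    assert_eq!(serde_json::to_string(&req).unwrap(), r#"{"model":"gpt-4o"}"#);
}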
--------------------------------------------------------------------------------
/src/models/completion.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 | use std::collections::HashMap;
3 |
4 | use super::usage::Usage;
5 |
6 | #[derive(Deserialize, Serialize, Clone)]
7 | pub struct CompletionRequest {
8 | pub model: String,
9 | pub prompt: String,
10 |     #[serde(skip_serializing_if = "Option::is_none")]
11 |     pub suffix: Option<String>,
12 |     #[serde(skip_serializing_if = "Option::is_none")]
13 |     pub max_tokens: Option<u32>,
14 |     #[serde(skip_serializing_if = "Option::is_none")]
15 |     pub temperature: Option<f32>,
16 |     #[serde(skip_serializing_if = "Option::is_none")]
17 |     pub top_p: Option<f32>,
18 |     #[serde(skip_serializing_if = "Option::is_none")]
19 |     pub n: Option<u32>,
20 |     #[serde(skip_serializing_if = "Option::is_none")]
21 |     pub stream: Option<bool>,
22 |     #[serde(skip_serializing_if = "Option::is_none")]
23 |     pub logprobs: Option<u32>,
24 |     #[serde(skip_serializing_if = "Option::is_none")]
25 |     pub echo: Option<bool>,
26 |     #[serde(skip_serializing_if = "Option::is_none")]
27 |     pub stop: Option<Vec<String>>,
28 |     #[serde(skip_serializing_if = "Option::is_none")]
29 |     pub presence_penalty: Option<f32>,
30 |     #[serde(skip_serializing_if = "Option::is_none")]
31 |     pub frequency_penalty: Option<f32>,
32 |     #[serde(skip_serializing_if = "Option::is_none")]
33 |     pub best_of: Option<u32>,
34 |     #[serde(skip_serializing_if = "Option::is_none")]
35 |     pub logit_bias: Option<HashMap<String, i32>>,
36 |     #[serde(skip_serializing_if = "Option::is_none")]
37 |     pub user: Option<String>,
38 | }
39 |
40 | #[derive(Deserialize, Serialize, Clone)]
41 | pub struct CompletionResponse {
42 | pub id: String,
43 | pub object: String,
44 | pub created: u64,
45 | pub model: String,
46 |     pub choices: Vec<CompletionChoice>,
47 | pub usage: Usage,
48 | }
49 |
50 | #[derive(Deserialize, Serialize, Clone)]
51 | pub struct CompletionChoice {
52 | pub text: String,
53 | pub index: u32,
54 |     #[serde(skip_serializing_if = "Option::is_none")]
55 |     pub logprobs: Option<LogProbs>,
56 |     #[serde(skip_serializing_if = "Option::is_none")]
57 |     pub finish_reason: Option<String>,
58 | }
59 |
60 | #[derive(Deserialize, Serialize, Clone)]
61 | pub struct LogProbs {
62 |     pub tokens: Vec<String>,
63 |     pub token_logprobs: Vec<f32>,
64 |     pub top_logprobs: Vec<HashMap<String, f32>>,
65 |     pub text_offset: Vec<u32>,
66 | }
67 |
--------------------------------------------------------------------------------
/src/models/content.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 |
3 | use super::tool_calls::ChatMessageToolCall;
4 |
5 | #[derive(Deserialize, Serialize, Clone)]
6 | #[serde(untagged)]
7 | pub enum ChatMessageContent {
8 | String(String),
9 |     Array(Vec<ChatMessageContentPart>),
10 | }
11 |
12 | #[derive(Deserialize, Serialize, Clone)]
13 | pub struct ChatMessageContentPart {
14 | #[serde(rename = "type")]
15 | pub r#type: String,
16 | pub text: String,
17 | }
18 |
19 | #[derive(Deserialize, Serialize, Clone)]
20 | pub struct ChatCompletionMessage {
21 | pub role: String,
22 |     #[serde(skip_serializing_if = "Option::is_none")]
23 |     pub content: Option<ChatMessageContent>,
24 |     #[serde(skip_serializing_if = "Option::is_none")]
25 |     pub name: Option<String>,
26 |     #[serde(skip_serializing_if = "Option::is_none")]
27 |     pub tool_calls: Option<Vec<ChatMessageToolCall>>,
28 |     #[serde(skip_serializing_if = "Option::is_none")]
29 |     pub refusal: Option<String>,
30 | }
31 |
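`ChatMessageContent` is `#[serde(untagged)]`, so one field accepts both the shorthand string form and the array-of-parts form of OpenAI-style message content. A small self-contained sketch of how serde resolves the two shapes (the types here are local stand-ins, not the crate's):

use serde::Deserialize;

#[derive(Deserialize)]
#[serde(untagged)]
enum Content {
    Text(String),
    Parts(Vec<Part>),
}

#[allow(dead_code)]
#[derive(Deserialize)]
struct Part {
    #[serde(rename = "type")]
    kind: String,
    text: String,
}

fn main() {
    // The shorthand form: content as a bare string.
    let short: Content = serde_json::from_str(r#""hello""#).unwrap();
    // The structured form: content as an array of typed parts.
    let parts: Content = serde_json::from_str(r#"[{"type":"text","text":"hello"}]"#).unwrap();
    assert!(matches!(short, Content::Text(_)));
    assert!(matches!(parts, Content::Parts(_)));
}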
--------------------------------------------------------------------------------
/src/models/embeddings.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 | use serde_json::Value;
3 |
4 | use super::usage::EmbeddingUsage;
5 |
6 | #[derive(Deserialize, Serialize, Clone)]
7 | pub struct EmbeddingsRequest {
8 | pub model: String,
9 | pub input: EmbeddingsInput,
10 |     #[serde(skip_serializing_if = "Option::is_none")]
11 |     pub user: Option<String>,
12 |     #[serde(skip_serializing_if = "Option::is_none")]
13 |     pub encoding_format: Option<String>,
14 | }
15 |
16 | #[derive(Deserialize, Serialize, Clone)]
17 | #[serde(untagged)]
18 | pub enum EmbeddingsInput {
19 | Single(String),
20 |     Multiple(Vec<String>),
21 |     SingleTokenIds(Vec<i32>),
22 |     MultipleTokenIds(Vec<Vec<i32>>),
23 | }
24 |
25 | #[derive(Deserialize, Serialize, Clone)]
26 | pub struct EmbeddingsResponse {
27 | pub object: String,
28 |     pub data: Vec<Embeddings>,
29 | pub model: String,
30 | pub usage: EmbeddingUsage,
31 | }
32 |
33 | #[derive(Deserialize, Serialize, Clone)]
34 | pub struct Embeddings {
35 | pub object: String,
36 | pub embedding: Embedding,
37 | pub index: usize,
38 | }
39 |
40 | #[derive(Deserialize, Serialize, Clone)]
41 | #[serde(untagged)]
42 | pub enum Embedding {
43 | String(String),
44 |     Float(Vec<f32>),
45 | Json(Value),
46 | }
47 |
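With `#[serde(untagged)]`, serde tries the `EmbeddingsInput` variants in declaration order and takes the first that fits, which is what lets a bare string, a string array, and token-id arrays all share one `input` field. A standalone sketch of that ordering (mirror types only; the token-id integer width matches the reconstruction above, which is an assumption):

use serde::Deserialize;

#[allow(dead_code)]
#[derive(Deserialize, Debug)]
#[serde(untagged)]
enum Input {
    Single(String),
    Multiple(Vec<String>),
    SingleTokenIds(Vec<i32>),
    MultipleTokenIds(Vec<Vec<i32>>),
}

fn main() {
    // Untagged variants are tried top to bottom; the first that parses wins.
    assert!(matches!(serde_json::from_str::<Input>(r#""hi""#).unwrap(), Input::Single(_)));
    assert!(matches!(serde_json::from_str::<Input>(r#"["a","b"]"#).unwrap(), Input::Multiple(_)));
    assert!(matches!(serde_json::from_str::<Input>("[1,2]").unwrap(), Input::SingleTokenIds(_)));
    assert!(matches!(serde_json::from_str::<Input>("[[1],[2]]").unwrap(), Input::MultipleTokenIds(_)));
}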
--------------------------------------------------------------------------------
/src/models/logprob.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 |
3 | #[derive(Deserialize, Serialize, Clone, Debug)]
4 | pub struct LogProbs {
5 |     pub content: Vec<LogProbContent>,
6 | }
7 |
8 | #[derive(Deserialize, Serialize, Clone, Debug)]
9 | pub struct LogProbContent {
10 | pub token: String,
11 | pub logprob: f32,
12 |     pub bytes: Vec<u8>,
13 |     pub top_logprobs: Vec<TopLogprob>,
14 | }
15 |
16 | #[derive(Deserialize, Serialize, Clone, Debug)]
17 | pub struct TopLogprob {
18 | pub token: String,
19 | #[serde(skip_serializing_if = "Option::is_none")]
20 |     pub bytes: Option<Vec<u8>>,
21 | pub logprob: f64,
22 | }
23 |
24 | #[derive(Deserialize, Serialize, Clone, Debug)]
25 | pub struct ChatCompletionTokenLogprob {
26 | pub token: String,
27 | #[serde(skip_serializing_if = "Option::is_none")]
28 |     pub bytes: Option<Vec<u8>>,
29 | pub logprob: f64,
30 |     pub top_logprobs: Vec<TopLogprob>,
31 | }
32 |
33 | #[derive(Deserialize, Serialize, Clone, Debug)]
34 | pub struct ChoiceLogprobs {
35 | #[serde(skip_serializing_if = "Option::is_none")]
36 |     pub content: Option<Vec<ChatCompletionTokenLogprob>>,
37 |     #[serde(skip_serializing_if = "Option::is_none")]
38 |     pub refusal: Option<Vec<ChatCompletionTokenLogprob>>,
39 | }
40 |
--------------------------------------------------------------------------------
/src/models/mod.rs:
--------------------------------------------------------------------------------
1 | pub mod chat;
2 | pub mod completion;
3 | pub mod content;
4 | pub mod embeddings;
5 | pub mod logprob;
6 | pub mod response_format;
7 | pub mod streaming;
8 | pub mod tool_calls;
9 | pub mod tool_choice;
10 | pub mod tool_definition;
11 | pub mod usage;
12 |
--------------------------------------------------------------------------------
/src/models/response_format.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 |
3 | #[derive(Deserialize, Serialize, Clone)]
4 | pub struct ResponseFormat {
5 | #[serde(rename = "type")]
6 | pub r#type: String,
7 | #[serde(skip_serializing_if = "Option::is_none")]
8 |     pub json_schema: Option<JsonSchema>,
9 | }
10 |
11 | #[derive(Deserialize, Serialize, Clone)]
12 | pub struct JsonSchema {
13 | pub name: String,
14 |     #[serde(skip_serializing_if = "Option::is_none")]
15 |     pub description: Option<String>,
16 |     #[serde(skip_serializing_if = "Option::is_none")]
17 |     pub schema: Option<serde_json::Value>,
18 |     #[serde(skip_serializing_if = "Option::is_none")]
19 |     pub strict: Option<bool>,
20 | }
21 |
--------------------------------------------------------------------------------
/src/models/streaming.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 |
3 | use super::logprob::ChoiceLogprobs;
4 | use super::tool_calls::ChatMessageToolCall;
5 | use super::usage::Usage;
6 |
7 | #[derive(Deserialize, Serialize, Clone, Debug, Default)]
8 | pub struct Delta {
9 |     pub role: Option<String>,
10 |     pub content: Option<String>,
11 |     pub function_call: Option<serde_json::Value>,
12 |     pub tool_calls: Option<Vec<ChatMessageToolCall>>,
13 | }
14 |
15 | #[derive(Deserialize, Serialize, Clone, Debug)]
16 | pub struct ChoiceDelta {
17 | #[serde(skip_serializing_if = "Option::is_none")]
18 |     pub content: Option<String>,
19 |     #[serde(skip_serializing_if = "Option::is_none")]
20 |     pub role: Option<String>,
21 |     #[serde(skip_serializing_if = "Option::is_none")]
22 |     pub tool_calls: Option<Vec<ChatMessageToolCall>>,
23 | }
24 |
25 | #[derive(Deserialize, Serialize, Clone, Debug)]
26 | pub struct Choice {
27 | pub delta: ChoiceDelta,
28 | #[serde(skip_serializing_if = "Option::is_none")]
29 |     pub finish_reason: Option<String>,
30 | pub index: u32,
31 | #[serde(skip_serializing_if = "Option::is_none")]
32 |     pub logprobs: Option<ChoiceLogprobs>,
33 | }
34 |
35 | #[derive(Deserialize, Serialize, Clone, Debug)]
36 | pub struct ChatCompletionChunk {
37 | pub id: String,
38 |     pub choices: Vec<Choice>,
39 | pub created: i64,
40 | pub model: String,
41 | #[serde(skip_serializing_if = "Option::is_none")]
42 |     pub service_tier: Option<String>,
43 |     #[serde(skip_serializing_if = "Option::is_none")]
44 |     pub system_fingerprint: Option<String>,
45 |     #[serde(skip_serializing_if = "Option::is_none")]
46 |     pub usage: Option<Usage>,
47 | }
48 |
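For reference, the wire shape these types model: each SSE data event carries one chunk whose per-choice deltas hold a fragment of the eventual message. A pared-down, self-contained mirror of that shape (field subset chosen for illustration only):

use serde::Deserialize;

#[allow(dead_code)]
#[derive(Deserialize)]
struct Chunk {
    id: String,
    model: String,
    created: i64,
    choices: Vec<Choice>,
}

#[allow(dead_code)]
#[derive(Deserialize)]
struct Choice {
    index: u32,
    delta: Delta,
    finish_reason: Option<String>,
}

#[derive(Deserialize)]
struct Delta {
    content: Option<String>,
}

fn main() {
    let raw = r#"{"id":"c1","model":"m","created":1,"choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}"#;
    let chunk: Chunk = serde_json::from_str(raw).unwrap();
    assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("Hi"));
}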
--------------------------------------------------------------------------------
/src/models/tool_calls.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 |
3 | #[derive(Deserialize, Serialize, Clone, Debug)]
4 | pub struct FunctionCall {
5 | pub arguments: String,
6 | pub name: String,
7 | }
8 |
9 | #[derive(Deserialize, Serialize, Clone, Debug)]
10 | pub struct ChatMessageToolCall {
11 | pub id: String,
12 | pub function: FunctionCall,
13 | #[serde(rename = "type")]
14 | pub r#type: String, // Using `function` as the only valid value
15 | }
16 |
--------------------------------------------------------------------------------
/src/models/tool_choice.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 |
3 | #[derive(Debug, Clone, Serialize, Deserialize)]
4 | #[serde(untagged)]
5 | pub enum ToolChoice {
6 | Simple(SimpleToolChoice),
7 | Named(ChatCompletionNamedToolChoice),
8 | }
9 |
10 | #[derive(Debug, Clone, Serialize, Deserialize)]
11 | #[serde(rename_all = "lowercase")]
12 | pub enum SimpleToolChoice {
13 | None,
14 | Auto,
15 | Required,
16 | }
17 |
18 | #[derive(Debug, Clone, Serialize, Deserialize)]
19 | pub struct ChatCompletionNamedToolChoice {
20 | #[serde(rename = "type")]
21 | pub tool_type: ToolType,
22 | pub function: Function,
23 | }
24 |
25 | #[derive(Debug, Clone, Serialize, Deserialize)]
26 | #[serde(rename_all = "lowercase")]
27 | pub enum ToolType {
28 | Function,
29 | }
30 |
31 | #[derive(Debug, Clone, Serialize, Deserialize)]
32 | pub struct Function {
33 | pub name: String,
34 | }
35 |
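`ToolChoice` leans on the same untagged trick as the content and embeddings enums: strings like "auto" hit `SimpleToolChoice` via `rename_all = "lowercase"`, while objects fall through to the named form. A standalone sketch with local stand-in types:

use serde::Deserialize;

#[derive(Deserialize)]
#[serde(untagged)]
enum ToolChoice {
    Simple(SimpleToolChoice),
    Named(NamedToolChoice),
}

#[allow(dead_code)]
#[derive(Deserialize)]
#[serde(rename_all = "lowercase")]
enum SimpleToolChoice {
    None,
    Auto,
    Required,
}

#[allow(dead_code)]
#[derive(Deserialize)]
struct NamedToolChoice {
    #[serde(rename = "type")]
    tool_type: String,
    function: Function,
}

#[allow(dead_code)]
#[derive(Deserialize)]
struct Function {
    name: String,
}

fn main() {
    // A bare lowercase string parses into the simple variant...
    let simple: ToolChoice = serde_json::from_str(r#""auto""#).unwrap();
    assert!(matches!(simple, ToolChoice::Simple(SimpleToolChoice::Auto)));
    // ...while an object parses into the named variant.
    let named: ToolChoice =
        serde_json::from_str(r#"{"type":"function","function":{"name":"get_weather"}}"#).unwrap();
    assert!(matches!(named, ToolChoice::Named(_)));
}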
--------------------------------------------------------------------------------
/src/models/tool_definition.rs:
--------------------------------------------------------------------------------
1 | use std::collections::HashMap;
2 |
3 | use serde::{Deserialize, Serialize};
4 |
5 | #[derive(Debug, Serialize, Deserialize, Clone)]
6 | pub struct ToolDefinition {
7 | pub function: FunctionDefinition,
8 |
9 | #[serde(rename = "type")]
10 | pub tool_type: String, // Will only accept "function" value
11 | }
12 |
13 | /// A definition of a function that can be called.
14 | #[derive(Debug, Serialize, Deserialize, Clone)]
15 | pub struct FunctionDefinition {
16 | pub name: String,
17 | #[serde(skip_serializing_if = "Option::is_none")]
18 |     pub description: Option<String>,
19 |     #[serde(skip_serializing_if = "Option::is_none")]
20 |     pub parameters: Option<HashMap<String, serde_json::Value>>,
21 |     #[serde(skip_serializing_if = "Option::is_none")]
22 |     pub strict: Option<bool>,
23 | }
24 |
--------------------------------------------------------------------------------
/src/models/usage.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 |
3 | #[derive(Deserialize, Serialize, Clone, Debug)]
4 | pub struct CompletionTokensDetails {
5 | #[serde(skip_serializing_if = "Option::is_none")]
6 |     pub accepted_prediction_tokens: Option<u32>,
7 |     #[serde(skip_serializing_if = "Option::is_none")]
8 |     pub audio_tokens: Option<u32>,
9 |     #[serde(skip_serializing_if = "Option::is_none")]
10 |     pub reasoning_tokens: Option<u32>,
11 |     #[serde(skip_serializing_if = "Option::is_none")]
12 |     pub rejected_prediction_tokens: Option<u32>,
13 | }
14 |
15 | #[derive(Deserialize, Serialize, Clone, Debug)]
16 | pub struct PromptTokensDetails {
17 | #[serde(skip_serializing_if = "Option::is_none")]
18 |     pub audio_tokens: Option<u32>,
19 |     #[serde(skip_serializing_if = "Option::is_none")]
20 |     pub cached_tokens: Option<u32>,
21 | }
22 |
23 | #[derive(Deserialize, Serialize, Clone, Debug, Default)]
24 | pub struct Usage {
25 | pub prompt_tokens: u32,
26 | pub completion_tokens: u32,
27 | pub total_tokens: u32,
28 | #[serde(skip_serializing_if = "Option::is_none")]
29 |     pub completion_tokens_details: Option<CompletionTokensDetails>,
30 |     #[serde(skip_serializing_if = "Option::is_none")]
31 |     pub prompt_tokens_details: Option<PromptTokensDetails>,
32 | }
33 |
34 | #[derive(Deserialize, Serialize, Clone, Debug, Default)]
35 | pub struct EmbeddingUsage {
36 | #[serde(skip_serializing_if = "Option::is_none")]
37 |     pub prompt_tokens: Option<u32>,
38 |     pub total_tokens: Option<u32>,
39 | }
40 |
--------------------------------------------------------------------------------
/src/pipelines/mod.rs:
--------------------------------------------------------------------------------
1 | mod otel;
2 | pub mod pipeline;
3 |
--------------------------------------------------------------------------------
/src/pipelines/otel.rs:
--------------------------------------------------------------------------------
1 | use crate::config::lib::get_trace_content_enabled;
2 | use crate::models::chat::{ChatCompletion, ChatCompletionChoice, ChatCompletionRequest};
3 | use crate::models::completion::{CompletionRequest, CompletionResponse};
4 | use crate::models::content::{ChatCompletionMessage, ChatMessageContent};
5 | use crate::models::embeddings::{EmbeddingsInput, EmbeddingsRequest, EmbeddingsResponse};
6 | use crate::models::streaming::ChatCompletionChunk;
7 | use crate::models::usage::{EmbeddingUsage, Usage};
8 | use opentelemetry::global::{BoxedSpan, ObjectSafeSpan};
9 | use opentelemetry::trace::{SpanKind, Status, Tracer};
10 | use opentelemetry::{global, KeyValue};
11 | use opentelemetry_otlp::{SpanExporter, WithExportConfig, WithHttpConfig};
12 | use opentelemetry_sdk::propagation::TraceContextPropagator;
13 | use opentelemetry_sdk::trace::TracerProvider;
14 | use opentelemetry_semantic_conventions::attribute::GEN_AI_REQUEST_MODEL;
15 | use opentelemetry_semantic_conventions::trace::*;
16 | use std::collections::HashMap;
17 |
18 | pub trait RecordSpan {
19 | fn record_span(&self, span: &mut BoxedSpan);
20 | }
21 |
22 | pub struct OtelTracer {
23 | span: BoxedSpan,
24 | accumulated_completion: Option,
25 | }
26 |
27 | impl OtelTracer {
28 | pub fn init(endpoint: String, api_key: String) {
29 | global::set_text_map_propagator(TraceContextPropagator::new());
30 | let mut headers = HashMap::new();
31 | headers.insert("Authorization".to_string(), format!("Bearer {}", api_key));
32 |
33 | let exporter: SpanExporter = SpanExporter::builder()
34 | .with_http()
35 | .with_endpoint(endpoint)
36 | .with_headers(headers)
37 | .build()
38 | .expect("Failed to initialize OpenTelemetry");
39 |
40 | let provider = TracerProvider::builder()
41 | .with_batch_exporter(exporter, opentelemetry_sdk::runtime::Tokio)
42 | .build();
43 |
44 | global::set_tracer_provider(provider);
45 | }
46 |
47 |     pub fn start<T: RecordSpan>(operation: &str, request: &T) -> Self {
48 | let tracer = global::tracer("traceloop_hub");
49 | let mut span = tracer
50 | .span_builder(format!("traceloop_hub.{}", operation))
51 | .with_kind(SpanKind::Client)
52 | .start(&tracer);
53 |
54 | request.record_span(&mut span);
55 |
56 | Self {
57 | span,
58 | accumulated_completion: None,
59 | }
60 | }
61 |
62 | pub fn log_chunk(&mut self, chunk: &ChatCompletionChunk) {
63 | if self.accumulated_completion.is_none() {
64 | self.accumulated_completion = Some(ChatCompletion {
65 | id: chunk.id.clone(),
66 | object: None,
67 | created: None,
68 | model: chunk.model.clone(),
69 | choices: vec![],
70 | usage: Usage::default(),
71 | system_fingerprint: chunk.system_fingerprint.clone(),
72 | });
73 | }
74 |
75 | if let Some(completion) = &mut self.accumulated_completion {
76 | for chunk_choice in &chunk.choices {
77 | if let Some(existing_choice) =
78 | completion.choices.get_mut(chunk_choice.index as usize)
79 | {
80 | if let Some(content) = &chunk_choice.delta.content {
81 | if let Some(ChatMessageContent::String(existing_content)) =
82 | &mut existing_choice.message.content
83 | {
84 | existing_content.push_str(content);
85 | }
86 | }
87 | if chunk_choice.finish_reason.is_some() {
88 | existing_choice.finish_reason = chunk_choice.finish_reason.clone();
89 | }
90 | if let Some(tool_calls) = &chunk_choice.delta.tool_calls {
91 | existing_choice.message.tool_calls = Some(tool_calls.clone());
92 | }
93 | } else {
94 | completion.choices.push(ChatCompletionChoice {
95 | index: chunk_choice.index,
96 | message: ChatCompletionMessage {
97 | name: None,
98 | role: chunk_choice
99 | .delta
100 | .role
101 | .clone()
102 | .unwrap_or_else(|| "assistant".to_string()),
103 | content: Some(ChatMessageContent::String(
104 | chunk_choice.delta.content.clone().unwrap_or_default(),
105 | )),
106 | tool_calls: chunk_choice.delta.tool_calls.clone(),
107 | refusal: None,
108 | },
109 | finish_reason: chunk_choice.finish_reason.clone(),
110 | logprobs: None,
111 | });
112 | }
113 | }
114 | }
115 | }
116 |
117 | pub fn streaming_end(&mut self) {
118 | if let Some(completion) = self.accumulated_completion.take() {
119 | completion.record_span(&mut self.span);
120 | self.span.set_status(Status::Ok);
121 | }
122 | }
123 |
124 |     pub fn log_success<T: RecordSpan>(&mut self, response: &T) {
125 | response.record_span(&mut self.span);
126 | self.span.set_status(Status::Ok);
127 | }
128 |
129 | pub fn log_error(&mut self, description: String) {
130 | self.span.set_status(Status::error(description));
131 | }
132 | }
133 |
134 | impl RecordSpan for ChatCompletionRequest {
135 | fn record_span(&self, span: &mut BoxedSpan) {
136 | span.set_attribute(KeyValue::new("llm.request.type", "chat"));
137 | span.set_attribute(KeyValue::new(GEN_AI_REQUEST_MODEL, self.model.clone()));
138 |
139 | if let Some(freq_penalty) = self.frequency_penalty {
140 | span.set_attribute(KeyValue::new(
141 | GEN_AI_REQUEST_FREQUENCY_PENALTY,
142 | freq_penalty as f64,
143 | ));
144 | }
145 | if let Some(pres_penalty) = self.presence_penalty {
146 | span.set_attribute(KeyValue::new(
147 | GEN_AI_REQUEST_PRESENCE_PENALTY,
148 | pres_penalty as f64,
149 | ));
150 | }
151 | if let Some(top_p) = self.top_p {
152 | span.set_attribute(KeyValue::new(GEN_AI_REQUEST_TOP_P, top_p as f64));
153 | }
154 | if let Some(temp) = self.temperature {
155 | span.set_attribute(KeyValue::new(GEN_AI_REQUEST_TEMPERATURE, temp as f64));
156 | }
157 |
158 | if get_trace_content_enabled() {
159 | for (i, message) in self.messages.iter().enumerate() {
160 | if let Some(content) = &message.content {
161 | span.set_attribute(KeyValue::new(
162 | format!("gen_ai.prompt.{}.role", i),
163 | message.role.clone(),
164 | ));
165 | span.set_attribute(KeyValue::new(
166 | format!("gen_ai.prompt.{}.content", i),
167 | match &content {
168 | ChatMessageContent::String(content) => content.clone(),
169 | ChatMessageContent::Array(content) => {
170 | serde_json::to_string(content).unwrap_or_default()
171 | }
172 | },
173 | ));
174 | }
175 | }
176 | }
177 | }
178 | }
179 |
180 | impl RecordSpan for ChatCompletion {
181 | fn record_span(&self, span: &mut BoxedSpan) {
182 | span.set_attribute(KeyValue::new(GEN_AI_RESPONSE_MODEL, self.model.clone()));
183 | span.set_attribute(KeyValue::new(GEN_AI_RESPONSE_ID, self.id.clone()));
184 |
185 | self.usage.record_span(span);
186 |
187 | if get_trace_content_enabled() {
188 | for choice in &self.choices {
189 | if let Some(content) = &choice.message.content {
190 | span.set_attribute(KeyValue::new(
191 | format!("gen_ai.completion.{}.role", choice.index),
192 | choice.message.role.clone(),
193 | ));
194 | span.set_attribute(KeyValue::new(
195 | format!("gen_ai.completion.{}.content", choice.index),
196 | match &content {
197 | ChatMessageContent::String(content) => content.clone(),
198 | ChatMessageContent::Array(content) => {
199 | serde_json::to_string(content).unwrap_or_default()
200 | }
201 | },
202 | ));
203 | }
204 | span.set_attribute(KeyValue::new(
205 | format!("gen_ai.completion.{}.finish_reason", choice.index),
206 | choice.finish_reason.clone().unwrap_or_default(),
207 | ));
208 | }
209 | }
210 | }
211 | }
212 |
213 | impl RecordSpan for CompletionRequest {
214 | fn record_span(&self, span: &mut BoxedSpan) {
215 | span.set_attribute(KeyValue::new("llm.request.type", "completion"));
216 | span.set_attribute(KeyValue::new(GEN_AI_REQUEST_MODEL, self.model.clone()));
217 | span.set_attribute(KeyValue::new("gen_ai.prompt", self.prompt.clone()));
218 |
219 | if let Some(freq_penalty) = self.frequency_penalty {
220 | span.set_attribute(KeyValue::new(
221 | GEN_AI_REQUEST_FREQUENCY_PENALTY,
222 | freq_penalty as f64,
223 | ));
224 | }
225 | if let Some(pres_penalty) = self.presence_penalty {
226 | span.set_attribute(KeyValue::new(
227 | GEN_AI_REQUEST_PRESENCE_PENALTY,
228 | pres_penalty as f64,
229 | ));
230 | }
231 | if let Some(top_p) = self.top_p {
232 | span.set_attribute(KeyValue::new(GEN_AI_REQUEST_TOP_P, top_p as f64));
233 | }
234 | if let Some(temp) = self.temperature {
235 | span.set_attribute(KeyValue::new(GEN_AI_REQUEST_TEMPERATURE, temp as f64));
236 | }
237 | }
238 | }
239 |
240 | impl RecordSpan for CompletionResponse {
241 | fn record_span(&self, span: &mut BoxedSpan) {
242 | span.set_attribute(KeyValue::new(GEN_AI_RESPONSE_MODEL, self.model.clone()));
243 | span.set_attribute(KeyValue::new(GEN_AI_RESPONSE_ID, self.id.clone()));
244 |
245 | self.usage.record_span(span);
246 |
247 | for choice in &self.choices {
248 | span.set_attribute(KeyValue::new(
249 | format!("gen_ai.completion.{}.role", choice.index),
250 | "assistant".to_string(),
251 | ));
252 | span.set_attribute(KeyValue::new(
253 | format!("gen_ai.completion.{}.content", choice.index),
254 | choice.text.clone(),
255 | ));
256 | span.set_attribute(KeyValue::new(
257 | format!("gen_ai.completion.{}.finish_reason", choice.index),
258 | choice.finish_reason.clone().unwrap_or_default(),
259 | ));
260 | }
261 | }
262 | }
263 |
264 | impl RecordSpan for EmbeddingsRequest {
265 | fn record_span(&self, span: &mut BoxedSpan) {
266 | span.set_attribute(KeyValue::new("llm.request.type", "embeddings"));
267 | span.set_attribute(KeyValue::new(GEN_AI_REQUEST_MODEL, self.model.clone()));
268 |
269 | if get_trace_content_enabled() {
270 | match &self.input {
271 | EmbeddingsInput::Single(text) => {
272 | span.set_attribute(KeyValue::new("llm.prompt.0.content", text.clone()));
273 | }
274 | EmbeddingsInput::Multiple(texts) => {
275 | for (i, text) in texts.iter().enumerate() {
276 | span.set_attribute(KeyValue::new(
277 | format!("llm.prompt.{}.role", i),
278 | "user".to_string(),
279 | ));
280 | span.set_attribute(KeyValue::new(
281 | format!("llm.prompt.{}.content", i),
282 | text.clone(),
283 | ));
284 | }
285 | }
286 | EmbeddingsInput::SingleTokenIds(token_ids) => {
287 | span.set_attribute(KeyValue::new(
288 | "llm.prompt.0.content",
289 | format!("{:?}", token_ids),
290 | ));
291 | }
292 | EmbeddingsInput::MultipleTokenIds(token_ids) => {
293 | for (i, token_ids) in token_ids.iter().enumerate() {
294 | span.set_attribute(KeyValue::new(
295 | format!("llm.prompt.{}.role", i),
296 | "user".to_string(),
297 | ));
298 | span.set_attribute(KeyValue::new(
299 | format!("llm.prompt.{}.content", i),
300 | format!("{:?}", token_ids),
301 | ));
302 | }
303 | }
304 | }
305 | }
306 | }
307 | }
308 | impl RecordSpan for EmbeddingsResponse {
309 | fn record_span(&self, span: &mut BoxedSpan) {
310 | span.set_attribute(KeyValue::new(GEN_AI_RESPONSE_MODEL, self.model.clone()));
311 |
312 | self.usage.record_span(span);
313 | }
314 | }
315 |
316 | impl RecordSpan for Usage {
317 | fn record_span(&self, span: &mut BoxedSpan) {
318 | span.set_attribute(KeyValue::new(
319 | "gen_ai.usage.prompt_tokens",
320 | self.prompt_tokens as i64,
321 | ));
322 | span.set_attribute(KeyValue::new(
323 | "gen_ai.usage.completion_tokens",
324 | self.completion_tokens as i64,
325 | ));
326 | span.set_attribute(KeyValue::new(
327 | "gen_ai.usage.total_tokens",
328 | self.total_tokens as i64,
329 | ));
330 | }
331 | }
332 |
333 | impl RecordSpan for EmbeddingUsage {
334 | fn record_span(&self, span: &mut BoxedSpan) {
335 | span.set_attribute(KeyValue::new(
336 | "gen_ai.usage.prompt_tokens",
337 | self.prompt_tokens.unwrap_or(0) as i64,
338 | ));
339 | span.set_attribute(KeyValue::new(
340 | "gen_ai.usage.total_tokens",
341 | self.total_tokens.unwrap_or(0) as i64,
342 | ));
343 | }
344 | }
345 |
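The core of `log_chunk` above is delta accumulation: streamed fragments are appended per choice index so that `streaming_end` can record one complete completion on the span. A minimal sketch of that idea in isolation (function and names are illustrative, not the crate's API):

use std::collections::HashMap;

fn accumulate(deltas: &[(u32, &str)]) -> HashMap<u32, String> {
    let mut acc: HashMap<u32, String> = HashMap::new();
    for (index, fragment) in deltas {
        // Append each streamed fragment to the running text for its choice index.
        acc.entry(*index).or_insert_with(String::new).push_str(fragment);
    }
    acc
}

fn main() {
    let acc = accumulate(&[(0, "Hel"), (0, "lo"), (1, "Hi")]);
    assert_eq!(acc[&0], "Hello");
    assert_eq!(acc[&1], "Hi");
}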
--------------------------------------------------------------------------------
/src/pipelines/pipeline.rs:
--------------------------------------------------------------------------------
1 | use crate::config::models::PipelineType;
2 | use crate::models::chat::ChatCompletionResponse;
3 | use crate::models::completion::CompletionRequest;
4 | use crate::models::embeddings::EmbeddingsRequest;
5 | use crate::models::streaming::ChatCompletionChunk;
6 | use crate::pipelines::otel::OtelTracer;
7 | use crate::{
8 | ai_models::registry::ModelRegistry,
9 | config::models::{Pipeline, PluginConfig},
10 | models::chat::ChatCompletionRequest,
11 | };
12 | use async_stream::stream;
13 | use axum::response::sse::{Event, KeepAlive};
14 | use axum::response::{IntoResponse, Sse};
15 | use axum::{extract::State, http::StatusCode, routing::post, Json, Router};
16 | use futures::stream::BoxStream;
17 | use futures::{Stream, StreamExt};
18 | use reqwest_streams::error::StreamBodyError;
19 | use std::sync::Arc;
20 |
21 | pub fn create_pipeline(pipeline: &Pipeline, model_registry: &ModelRegistry) -> Router {
22 | let mut router = Router::new();
23 |
24 | for plugin in pipeline.plugins.clone() {
25 | router = match plugin {
26 | PluginConfig::Tracing { endpoint, api_key } => {
27 | OtelTracer::init(endpoint, api_key);
28 | router
29 | }
30 | PluginConfig::ModelRouter { models } => match pipeline.r#type {
31 | PipelineType::Chat => router.route(
32 | "/chat/completions",
33 | post(move |state, payload| chat_completions(state, payload, models)),
34 | ),
35 | PipelineType::Completion => router.route(
36 | "/completions",
37 | post(move |state, payload| completions(state, payload, models)),
38 | ),
39 | PipelineType::Embeddings => router.route(
40 | "/embeddings",
41 | post(move |state, payload| embeddings(state, payload, models)),
42 | ),
43 | },
44 | _ => router,
45 | };
46 | }
47 |
48 | router.with_state(Arc::new(model_registry.clone()))
49 | }
50 |
51 | fn trace_and_stream(
52 | mut tracer: OtelTracer,
53 |     stream: BoxStream<'static, Result<ChatCompletionChunk, StreamBodyError>>,
54 | ) -> impl Stream<Item = Result<Event, axum::Error>> {
55 | stream! {
56 | let mut stream = stream;
57 | while let Some(result) = stream.next().await {
58 | yield match result {
59 | Ok(chunk) => {
60 | tracer.log_chunk(&chunk);
61 | Event::default().json_data(chunk)
62 | }
63 | Err(e) => {
64 | eprintln!("Error in stream: {:?}", e);
65 | tracer.log_error(e.to_string());
66 | Err(axum::Error::new(e))
67 | }
68 | };
69 | }
70 | tracer.streaming_end();
71 | }
72 | }
73 |
74 | pub async fn chat_completions(
75 |     State(model_registry): State<Arc<ModelRegistry>>,
76 |     Json(payload): Json<ChatCompletionRequest>,
77 |     model_keys: Vec<String>,
78 | ) -> Result<impl IntoResponse, StatusCode> {
79 | let mut tracer = OtelTracer::start("chat", &payload);
80 |
81 | for model_key in model_keys {
82 | let model = model_registry.get(&model_key).unwrap();
83 |
84 | if payload.model == model.model_type {
85 | let response = model
86 | .chat_completions(payload.clone())
87 | .await
88 | .inspect_err(|e| {
89 | eprintln!("Chat completion error for model {}: {:?}", model_key, e);
90 | })?;
91 |
92 | if let ChatCompletionResponse::NonStream(completion) = response {
93 | tracer.log_success(&completion);
94 | return Ok(Json(completion).into_response());
95 | }
96 |
97 | if let ChatCompletionResponse::Stream(stream) = response {
98 | return Ok(Sse::new(trace_and_stream(tracer, stream))
99 | .keep_alive(KeepAlive::default())
100 | .into_response());
101 | }
102 | }
103 | }
104 |
105 | tracer.log_error("No matching model found".to_string());
106 | eprintln!("No matching model found for: {}", payload.model);
107 | Err(StatusCode::NOT_FOUND)
108 | }
109 |
110 | pub async fn completions(
111 |     State(model_registry): State<Arc<ModelRegistry>>,
112 |     Json(payload): Json<CompletionRequest>,
113 |     model_keys: Vec<String>,
114 | ) -> Result<impl IntoResponse, StatusCode> {
115 | let mut tracer = OtelTracer::start("completion", &payload);
116 |
117 | for model_key in model_keys {
118 | let model = model_registry.get(&model_key).unwrap();
119 |
120 | if payload.model == model.model_type {
121 | let response = model.completions(payload.clone()).await.inspect_err(|e| {
122 | eprintln!("Completion error for model {}: {:?}", model_key, e);
123 | })?;
124 | tracer.log_success(&response);
125 | return Ok(Json(response));
126 | }
127 | }
128 |
129 | tracer.log_error("No matching model found".to_string());
130 | eprintln!("No matching model found for: {}", payload.model);
131 | Err(StatusCode::NOT_FOUND)
132 | }
133 |
134 | pub async fn embeddings(
135 |     State(model_registry): State<Arc<ModelRegistry>>,
136 |     Json(payload): Json<EmbeddingsRequest>,
137 |     model_keys: Vec<String>,
138 | ) -> Result<impl IntoResponse, StatusCode> {
139 | let mut tracer = OtelTracer::start("embeddings", &payload);
140 |
141 | for model_key in model_keys {
142 | let model = model_registry.get(&model_key).unwrap();
143 |
144 | if payload.model == model.model_type {
145 | let response = model.embeddings(payload.clone()).await.inspect_err(|e| {
146 | eprintln!("Embeddings error for model {}: {:?}", model_key, e);
147 | })?;
148 | tracer.log_success(&response);
149 | return Ok(Json(response));
150 | }
151 | }
152 |
153 | tracer.log_error("No matching model found".to_string());
154 | eprintln!("No matching model found for: {}", payload.model);
155 | Err(StatusCode::NOT_FOUND)
156 | }
157 |
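Routing in the three handlers follows one rule: walk the pipeline's configured model keys in order and take the first registry entry whose `model_type` equals the request's `model`. A minimal sketch of just that selection rule (tuples stand in for registry entries; the model names are made up):

// First entry whose model_type matches the requested model wins; None means 404.
fn pick<'a>(requested: &str, models: &'a [(&'a str, &'a str)]) -> Option<&'a str> {
    models
        .iter()
        .find(|(_, model_type)| *model_type == requested)
        .map(|(key, _)| *key)
}

fn main() {
    let models = [("azure-gpt4", "gpt-4"), ("openai-gpt4o", "gpt-4o")];
    assert_eq!(pick("gpt-4o", &models), Some("openai-gpt4o"));
    assert_eq!(pick("unknown", &models), None);
}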
--------------------------------------------------------------------------------
/src/providers/anthropic/mod.rs:
--------------------------------------------------------------------------------
1 | pub(crate) mod models;
2 | mod provider;
3 |
4 | pub use models::{AnthropicChatCompletionRequest, AnthropicChatCompletionResponse};
5 | pub use provider::AnthropicProvider;
6 |
--------------------------------------------------------------------------------
/src/providers/anthropic/models.rs:
--------------------------------------------------------------------------------
1 | use crate::config::constants::default_max_tokens;
2 | use crate::models::chat::{ChatCompletion, ChatCompletionChoice, ChatCompletionRequest};
3 | use crate::models::content::{ChatCompletionMessage, ChatMessageContent, ChatMessageContentPart};
4 | use crate::models::tool_calls::{ChatMessageToolCall, FunctionCall};
5 | use serde::{Deserialize, Serialize};
6 |
7 | #[derive(Deserialize, Serialize, Clone)]
8 | pub struct AnthropicChatCompletionRequest {
9 | pub max_tokens: u32,
10 | pub model: String,
11 |     pub messages: Vec<ChatCompletionMessage>,
12 |     #[serde(skip_serializing_if = "Option::is_none")]
13 |     pub temperature: Option<f32>,
14 |     #[serde(skip_serializing_if = "Option::is_none")]
15 |     pub tool_choice: Option<ToolChoice>,
16 |     pub tools: Vec<ToolParam>,
17 |     #[serde(skip_serializing_if = "Option::is_none")]
18 |     pub top_p: Option<f32>,
19 |     #[serde(skip_serializing_if = "Option::is_none")]
20 |     pub stream: Option<bool>,
21 |     #[serde(skip_serializing_if = "Option::is_none")]
22 |     pub system: Option<String>,
23 | }
24 |
25 | #[derive(Deserialize, Serialize, Clone)]
26 | pub struct AnthropicChatCompletionResponse {
27 | pub id: String,
28 | pub model: String,
29 |     pub content: Vec<ContentBlock>,
30 | pub usage: Usage,
31 | }
32 |
33 | #[derive(Deserialize, Serialize, Clone)]
34 | #[serde(tag = "type")]
35 | pub enum ContentBlock {
36 | #[serde(rename = "text")]
37 | Text { text: String },
38 | #[serde(rename = "tool_use")]
39 | ToolUse {
40 | id: String,
41 | input: serde_json::Value,
42 | name: String,
43 | },
44 | }
45 |
46 | #[derive(Deserialize, Serialize, Clone)]
47 | pub struct Usage {
48 | pub input_tokens: u32,
49 | pub output_tokens: u32,
50 | }
51 |
52 | #[derive(Deserialize, Serialize, Clone)]
53 | pub(crate) struct InputSchemaTyped {
54 | #[serde(rename = "type")]
55 | pub r#type: String,
56 | #[serde(skip_serializing_if = "Option::is_none")]
57 |     pub properties: Option<serde_json::Value>,
58 | }
59 |
60 | pub(crate) type InputSchema = serde_json::Value;
61 |
62 | #[derive(Deserialize, Serialize, Clone)]
63 | pub struct ToolParam {
64 | pub input_schema: InputSchema,
65 | pub name: String,
66 | #[serde(skip_serializing_if = "Option::is_none")]
67 | pub description: Option,
68 | }
69 |
70 | #[derive(Deserialize, Serialize, Clone)]
71 | #[serde(tag = "type")]
72 | pub enum ToolChoice {
73 | #[serde(rename = "auto")]
74 | Auto { disable_parallel_tool_use: bool },
75 | #[serde(rename = "any")]
76 | Any { disable_parallel_tool_use: bool },
77 | #[serde(rename = "tool")]
78 | Tool {
79 | name: String,
80 | disable_parallel_tool_use: bool,
81 | },
82 | }
83 |
84 | impl From<ChatCompletionRequest> for AnthropicChatCompletionRequest {
85 | fn from(request: ChatCompletionRequest) -> Self {
86 | let should_include_tools = !matches!(
87 | request.tool_choice,
88 | Some(crate::models::tool_choice::ToolChoice::Simple(
89 | crate::models::tool_choice::SimpleToolChoice::None
90 | ))
91 | );
92 |
93 | let system = request
94 | .messages
95 | .iter()
96 | .find(|msg| msg.role == "system")
97 | .and_then(|message| match &message.content {
98 | Some(ChatMessageContent::String(text)) => Some(text.clone()),
99 | Some(ChatMessageContent::Array(parts)) => parts
100 | .iter()
101 | .find(|part| part.r#type == "text")
102 | .map(|part| part.text.clone()),
103 | _ => None,
104 | });
105 |
106 |         let messages: Vec<ChatCompletionMessage> = request
107 | .messages
108 | .into_iter()
109 | .filter(|msg| msg.role != "system")
110 | .collect();
111 |
112 | let max_tokens = match request.max_completion_tokens {
113 | Some(val) if val > 0 => val,
114 | _ => request.max_tokens.unwrap_or_else(default_max_tokens),
115 | };
116 |
117 | AnthropicChatCompletionRequest {
118 | max_tokens,
119 | model: request.model,
120 | messages,
121 | temperature: request.temperature,
122 | top_p: request.top_p,
123 | stream: request.stream,
124 | system,
125 | tool_choice: request.tool_choice.map(|choice| match choice {
126 | crate::models::tool_choice::ToolChoice::Simple(simple) => match simple {
127 | crate::models::tool_choice::SimpleToolChoice::None
128 | | crate::models::tool_choice::SimpleToolChoice::Auto => ToolChoice::Auto {
129 | disable_parallel_tool_use: request.parallel_tool_calls.unwrap_or(false),
130 | },
131 | crate::models::tool_choice::SimpleToolChoice::Required => ToolChoice::Any {
132 | disable_parallel_tool_use: request.parallel_tool_calls.unwrap_or(false),
133 | },
134 | },
135 | crate::models::tool_choice::ToolChoice::Named(named) => ToolChoice::Tool {
136 | name: named.function.name,
137 | disable_parallel_tool_use: request.parallel_tool_calls.unwrap_or(false),
138 | },
139 | }),
140 | tools: if should_include_tools {
141 | request
142 | .tools
143 | .unwrap_or_default()
144 | .into_iter()
145 | .map(|tool| ToolParam {
146 | name: tool.function.name,
147 | description: tool.function.description,
148 | input_schema: serde_json::to_value(tool.function.parameters)
149 | .unwrap_or_default(),
150 | })
151 | .collect()
152 | } else {
153 | Vec::new()
154 | },
155 | }
156 | }
157 | }
158 |
159 | impl From<Vec<ContentBlock>> for ChatCompletionMessage {
160 | fn from(blocks: Vec) -> Self {
161 |         let mut text_content = Vec::<ChatMessageContentPart>::new();
162 |         let mut tool_calls = Vec::<ChatMessageToolCall>::new();
163 |
164 | for block in blocks {
165 | match block {
166 | ContentBlock::Text { text } => {
167 | text_content.push(ChatMessageContentPart {
168 | r#type: "text".to_string(),
169 | text,
170 | });
171 | }
172 | ContentBlock::ToolUse { name, input, id } => {
173 | tool_calls.push(ChatMessageToolCall {
174 | id,
175 | function: FunctionCall {
176 | name,
177 | arguments: input.to_string(),
178 | },
179 | r#type: "function".to_string(),
180 | });
181 | }
182 | }
183 | }
184 |
185 | ChatCompletionMessage {
186 | role: "assistant".to_string(),
187 | content: Some(ChatMessageContent::Array(text_content)),
188 | name: None,
189 | refusal: None,
190 | tool_calls: if tool_calls.is_empty() {
191 | None
192 | } else {
193 | Some(tool_calls)
194 | },
195 | }
196 | }
197 | }
198 |
199 | impl From<AnthropicChatCompletionResponse> for ChatCompletion {
200 | fn from(response: AnthropicChatCompletionResponse) -> Self {
201 | ChatCompletion {
202 | id: response.id,
203 | object: None,
204 | created: None,
205 | model: response.model,
206 | choices: vec![ChatCompletionChoice {
207 | index: 0,
208 | message: response.content.into(),
209 | finish_reason: Some("stop".to_string()),
210 | logprobs: None,
211 | }],
212 | usage: crate::models::usage::Usage {
213 | prompt_tokens: response.usage.input_tokens,
214 | completion_tokens: response.usage.output_tokens,
215 | total_tokens: response.usage.input_tokens + response.usage.output_tokens,
216 | completion_tokens_details: None,
217 | prompt_tokens_details: None,
218 | },
219 | system_fingerprint: None,
220 | }
221 | }
222 | }
223 |
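The `From<ChatCompletionRequest>` impl above has one structural quirk worth calling out: Anthropic's Messages API takes the system prompt as a top-level `system` field, so the converter lifts the first `system` message out and filters system messages from the list. A minimal sketch of that lift, using plain tuples instead of the crate's message type:

// Lift the first system message out and drop system messages from the list,
// mirroring the shape change the From impl performs on real message types.
fn split_system(messages: Vec<(String, String)>) -> (Option<String>, Vec<(String, String)>) {
    let system = messages
        .iter()
        .find(|(role, _)| role == "system")
        .map(|(_, content)| content.clone());
    let rest: Vec<(String, String)> = messages
        .into_iter()
        .filter(|(role, _)| role != "system")
        .collect();
    (system, rest)
}

fn main() {
    let msgs = vec![
        ("system".to_string(), "Be terse.".to_string()),
        ("user".to_string(), "Hi".to_string()),
    ];
    let (system, rest) = split_system(msgs);
    assert_eq!(system.as_deref(), Some("Be terse."));
    assert_eq!(rest.len(), 1);
}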
--------------------------------------------------------------------------------
/src/providers/anthropic/provider.rs:
--------------------------------------------------------------------------------
1 | use axum::async_trait;
2 | use axum::http::StatusCode;
3 | use reqwest::Client;
4 |
5 | use super::models::{AnthropicChatCompletionRequest, AnthropicChatCompletionResponse};
6 | use crate::config::models::{ModelConfig, Provider as ProviderConfig};
7 | use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse};
8 | use crate::models::completion::{CompletionRequest, CompletionResponse};
9 | use crate::models::embeddings::{EmbeddingsRequest, EmbeddingsResponse};
10 | use crate::providers::provider::Provider;
11 |
12 | pub struct AnthropicProvider {
13 | api_key: String,
14 | config: ProviderConfig,
15 | http_client: Client,
16 | }
17 |
18 | #[async_trait]
19 | impl Provider for AnthropicProvider {
20 | fn new(config: &ProviderConfig) -> Self {
21 | Self {
22 | api_key: config.api_key.clone(),
23 | config: config.clone(),
24 | http_client: Client::new(),
25 | }
26 | }
27 |
28 | fn key(&self) -> String {
29 | self.config.key.clone()
30 | }
31 |
32 | fn r#type(&self) -> String {
33 | "anthropic".to_string()
34 | }
35 |
36 | async fn chat_completions(
37 | &self,
38 | payload: ChatCompletionRequest,
39 | _model_config: &ModelConfig,
40 |     ) -> Result<ChatCompletionResponse, StatusCode> {
41 | let request = AnthropicChatCompletionRequest::from(payload);
42 | let response = self
43 | .http_client
44 | .post("https://api.anthropic.com/v1/messages")
45 | .header("x-api-key", &self.api_key)
46 | .header("anthropic-version", "2023-06-01")
47 | .json(&request)
48 | .send()
49 | .await
50 | .map_err(|e| {
51 | eprintln!("Anthropic API request error: {}", e);
52 | StatusCode::INTERNAL_SERVER_ERROR
53 | })?;
54 |
55 | let status = response.status();
56 | if status.is_success() {
57 | if request.stream.unwrap_or(false) {
58 | unimplemented!()
59 | } else {
60 | let anthropic_response: AnthropicChatCompletionResponse = response
61 | .json()
62 | .await
63 | .expect("Failed to parse Anthropic response");
64 | Ok(ChatCompletionResponse::NonStream(anthropic_response.into()))
65 | }
66 | } else {
67 | eprintln!(
68 | "Anthropic API request error: {}",
69 | response.text().await.unwrap()
70 | );
71 | Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
72 | }
73 | }
74 |
75 | async fn completions(
76 | &self,
77 | _payload: CompletionRequest,
78 | _model_config: &ModelConfig,
79 |     ) -> Result<CompletionResponse, StatusCode> {
80 | unimplemented!()
81 | }
82 |
83 | async fn embeddings(
84 | &self,
85 | _payload: EmbeddingsRequest,
86 | _model_config: &ModelConfig,
87 |     ) -> Result<EmbeddingsResponse, StatusCode> {
88 | unimplemented!()
89 | }
90 | }
91 |
--------------------------------------------------------------------------------
/src/providers/azure/mod.rs:
--------------------------------------------------------------------------------
1 | mod provider;
2 |
3 | pub use provider::AzureProvider;
4 |
--------------------------------------------------------------------------------
/src/providers/azure/provider.rs:
--------------------------------------------------------------------------------
1 | use axum::async_trait;
2 | use axum::http::StatusCode;
3 | use reqwest_streams::JsonStreamResponse;
4 |
5 | use crate::config::constants::stream_buffer_size_bytes;
6 | use crate::config::models::{ModelConfig, Provider as ProviderConfig};
7 | use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse};
8 | use crate::models::completion::{CompletionRequest, CompletionResponse};
9 | use crate::models::embeddings::{EmbeddingsRequest, EmbeddingsResponse};
10 | use crate::models::streaming::ChatCompletionChunk;
11 | use crate::providers::provider::Provider;
12 | use reqwest::Client;
13 | use tracing::info;
14 |
15 | pub struct AzureProvider {
16 | config: ProviderConfig,
17 | http_client: Client,
18 | }
19 |
20 | impl AzureProvider {
21 | fn endpoint(&self) -> String {
22 | if let Some(base_url) = self.config.params.get("base_url") {
23 | base_url.clone()
24 | } else {
25 | format!(
26 | "https://{}.openai.azure.com/openai/deployments",
27 | self.config.params.get("resource_name").unwrap(),
28 | )
29 | }
30 | }
31 | fn api_version(&self) -> String {
32 | self.config.params.get("api_version").unwrap().clone()
33 | }
34 | }
35 |
36 | #[async_trait]
37 | impl Provider for AzureProvider {
38 | fn new(config: &ProviderConfig) -> Self {
39 | Self {
40 | config: config.clone(),
41 | http_client: Client::new(),
42 | }
43 | }
44 |
45 | fn key(&self) -> String {
46 | self.config.key.clone()
47 | }
48 |
49 | fn r#type(&self) -> String {
50 | "azure".to_string()
51 | }
52 |
53 | async fn chat_completions(
54 | &self,
55 | payload: ChatCompletionRequest,
56 | model_config: &ModelConfig,
57 |     ) -> Result<ChatCompletionResponse, StatusCode> {
58 | let deployment = model_config.params.get("deployment").unwrap();
59 | let api_version = self.api_version();
60 | let url = format!(
61 | "{}/{}/chat/completions?api-version={}",
62 | self.endpoint(),
63 | deployment,
64 | api_version
65 | );
66 |
67 | let response = self
68 | .http_client
69 | .post(&url)
70 | .header("api-key", &self.config.api_key)
71 | .json(&payload)
72 | .send()
73 | .await
74 | .map_err(|e| {
75 | eprintln!("Azure OpenAI API request error: {}", e);
76 | StatusCode::INTERNAL_SERVER_ERROR
77 | })?;
78 |
79 | let status = response.status();
80 | if status.is_success() {
81 | if payload.stream.unwrap_or(false) {
82 | let stream =
83 |                 response.json_array_stream::<ChatCompletionChunk>(stream_buffer_size_bytes());
84 | Ok(ChatCompletionResponse::Stream(stream))
85 | } else {
86 | response
87 | .json()
88 | .await
89 | .map(ChatCompletionResponse::NonStream)
90 | .map_err(|e| {
91 | eprintln!("Azure OpenAI API response error: {}", e);
92 | StatusCode::INTERNAL_SERVER_ERROR
93 | })
94 | }
95 | } else {
96 | info!(
97 | "Azure OpenAI API request error: {}",
98 | response.text().await.unwrap()
99 | );
100 | Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
101 | }
102 | }
103 |
104 | async fn completions(
105 | &self,
106 | payload: CompletionRequest,
107 | model_config: &ModelConfig,
108 |     ) -> Result<CompletionResponse, StatusCode> {
109 | let deployment = model_config.params.get("deployment").unwrap();
110 | let api_version = self.api_version();
111 | let url = format!(
112 | "{}/{}/completions?api-version={}",
113 | self.endpoint(),
114 | deployment,
115 | api_version
116 | );
117 |
118 | let response = self
119 | .http_client
120 | .post(&url)
121 | .header("api-key", &self.config.api_key)
122 | .json(&payload)
123 | .send()
124 | .await
125 | .map_err(|e| {
126 | eprintln!("Azure OpenAI API request error: {}", e);
127 | StatusCode::INTERNAL_SERVER_ERROR
128 | })?;
129 |
130 | let status = response.status();
131 | if status.is_success() {
132 | response.json().await.map_err(|e| {
133 | eprintln!("Azure OpenAI API response error: {}", e);
134 | StatusCode::INTERNAL_SERVER_ERROR
135 | })
136 | } else {
137 | eprintln!(
138 | "Azure OpenAI API request error: {}",
139 | response.text().await.unwrap()
140 | );
141 | Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
142 | }
143 | }
144 |
145 | async fn embeddings(
146 | &self,
147 | payload: EmbeddingsRequest,
148 | model_config: &ModelConfig,
149 |     ) -> Result<EmbeddingsResponse, StatusCode> {
150 | let deployment = model_config.params.get("deployment").unwrap();
151 | let api_version = self.api_version();
152 |
153 | let url = format!(
154 | "{}/{}/embeddings?api-version={}",
155 | self.endpoint(),
156 | deployment,
157 | api_version
158 | );
159 |
160 | let response = self
161 | .http_client
162 | .post(&url)
163 | .header("api-key", &self.config.api_key)
164 | .json(&payload)
165 | .send()
166 | .await
167 | .map_err(|e| {
168 | eprintln!("Azure OpenAI API request error: {}", e);
169 | StatusCode::INTERNAL_SERVER_ERROR
170 | })?;
171 |
172 | let status = response.status();
173 | if status.is_success() {
174 | response.json().await.map_err(|e| {
175 | eprintln!("Azure OpenAI Embeddings API response error: {}", e);
176 | StatusCode::INTERNAL_SERVER_ERROR
177 | })
178 | } else {
179 | eprintln!(
180 | "Azure OpenAI Embeddings API request error: {}",
181 | response.text().await.unwrap()
182 | );
183 | Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
184 | }
185 | }
186 | }
187 |
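All three Azure methods compose URLs the same way: `{endpoint}/{deployment}/{route}?api-version={version}`, where the endpoint is either the configured `base_url` or derived from `resource_name`. A sketch of the composition for the chat route (the api-version value below is an arbitrary example; the crate reads it from provider params rather than pinning one):

fn azure_chat_url(endpoint: &str, deployment: &str, api_version: &str) -> String {
    // Mirrors the format! call in chat_completions above.
    format!("{}/{}/chat/completions?api-version={}", endpoint, deployment, api_version)
}

fn main() {
    let url = azure_chat_url(
        "https://my-resource.openai.azure.com/openai/deployments",
        "gpt-4o",
        "2024-02-01", // illustrative api-version only
    );
    assert_eq!(
        url,
        "https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-02-01"
    );
}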
--------------------------------------------------------------------------------
/src/providers/bedrock/logs/ai21_j2_mid_v1_completions.json:
--------------------------------------------------------------------------------
1 | {"id":1234,"prompt":{"text":"Tell me a joke","tokens":[{"generatedToken":{"token":"▁Tell▁me","logprob":-15.517871856689453,"raw_logprob":-15.517871856689453},"topTokens":null,"textRange":{"start":0,"end":7}},{"generatedToken":{"token":"▁a▁joke","logprob":-6.4708662033081055,"raw_logprob":-6.4708662033081055},"topTokens":null,"textRange":{"start":7,"end":14}}]},"completions":[{"data":{"text":"\nWell, if you break the internet, will you be held responsible? \n \nBecause... \n \ndownsizing is coming!","tokens":[{"generatedToken":{"token":"<|newline|>","logprob":-4.389514506328851E-4,"raw_logprob":-4.389514506328851E-4},"topTokens":null,"textRange":{"start":0,"end":1}},{"generatedToken":{"token":"▁Well","logprob":-4.107903003692627,"raw_logprob":-4.107903003692627},"topTokens":null,"textRange":{"start":1,"end":5}},{"generatedToken":{"token":",","logprob":-0.020943328738212585,"raw_logprob":-0.020943328738212585},"topTokens":null,"textRange":{"start":5,"end":6}},{"generatedToken":{"token":"▁if▁you","logprob":-1.9317010641098022,"raw_logprob":-1.9317010641098022},"topTokens":null,"textRange":{"start":6,"end":13}},{"generatedToken":{"token":"▁break","logprob":-1.2613706588745117,"raw_logprob":-1.2613706588745117},"topTokens":null,"textRange":{"start":13,"end":19}},{"generatedToken":{"token":"▁the▁internet","logprob":-0.8270013928413391,"raw_logprob":-0.8270013928413391},"topTokens":null,"textRange":{"start":19,"end":32}},{"generatedToken":{"token":",","logprob":-0.0465129017829895,"raw_logprob":-0.0465129017829895},"topTokens":null,"textRange":{"start":32,"end":33}},{"generatedToken":{"token":"▁will▁you","logprob":-5.168251991271973,"raw_logprob":-5.168251991271973},"topTokens":null,"textRange":{"start":33,"end":42}},{"generatedToken":{"token":"▁be▁held▁responsible","logprob":-1.7928266525268555,"raw_logprob":-1.7928266525268555},"topTokens":null,"textRange":{"start":42,"end":62}},{"generatedToken":{"token":"?","logprob":-0.38259002566337585,"raw_logprob":-0.38259002566337585},"topTokens":null,"textRange":{"start":62,"end":63}},{"generatedToken":{"token":"▁","logprob":-0.5640338063240051,"raw_logprob":-0.5640338063240051},"topTokens":null,"textRange":{"start":63,"end":64}},{"generatedToken":{"token":"<|newline|>","logprob":-0.8146487474441528,"raw_logprob":-0.8146487474441528},"topTokens":null,"textRange":{"start":64,"end":65}},{"generatedToken":{"token":"▁▁","logprob":-1.8104121685028076,"raw_logprob":-1.8104121685028076},"topTokens":null,"textRange":{"start":65,"end":66}},{"generatedToken":{"token":"<|newline|>","logprob":-0.11257430166006088,"raw_logprob":-0.11257430166006088},"topTokens":null,"textRange":{"start":66,"end":67}},{"generatedToken":{"token":"▁Because","logprob":-0.9798339605331421,"raw_logprob":-0.9798339605331421},"topTokens":null,"textRange":{"start":67,"end":74}},{"generatedToken":{"token":"...","logprob":-3.552177906036377,"raw_logprob":-3.552177906036377},"topTokens":null,"textRange":{"start":74,"end":77}},{"generatedToken":{"token":"▁","logprob":-1.0605108737945557,"raw_logprob":-1.0605108737945557},"topTokens":null,"textRange":{"start":77,"end":78}},{"generatedToken":{"token":"<|newline|>","logprob":-0.04952247813344002,"raw_logprob":-0.04952247813344002},"topTokens":null,"textRange":{"start":78,"end":79}},{"generatedToken":{"token":"▁▁","logprob":-0.054540783166885376,"raw_logprob":-0.054540783166885376},"topTokens":null,"textRange":{"start":79,"end":80}},{"generatedToken":{"token":"<|newline|>","logprob":-3.5523738915799186E-5,"raw_logprob":-3.5523738915799186E-5},"to
pTokens":null,"textRange":{"start":80,"end":81}},{"generatedToken":{"token":"▁downsizing","logprob":-10.688655853271484,"raw_logprob":-10.688655853271484},"topTokens":null,"textRange":{"start":81,"end":91}},{"generatedToken":{"token":"▁is▁coming","logprob":-3.518221378326416,"raw_logprob":-3.518221378326416},"topTokens":null,"textRange":{"start":91,"end":101}},{"generatedToken":{"token":"!","logprob":-0.8572316765785217,"raw_logprob":-0.8572316765785217},"topTokens":null,"textRange":{"start":101,"end":102}},{"generatedToken":{"token":"<|endoftext|>","logprob":-0.002994698006659746,"raw_logprob":-0.002994698006659746},"topTokens":null,"textRange":{"start":102,"end":102}}]},"finishReason":{"reason":"endoftext"}}]}
--------------------------------------------------------------------------------
/src/providers/bedrock/logs/ai21_jamba_1_5_mini_v1_0_chat_completions.json:
--------------------------------------------------------------------------------
1 | {"id":"chatcmpl-7c961099b52f4798bede41e6bb087b37","choices":[{"index":0,"message":{"role":"assistant","content":" Why don't skeletons fight each other?\n\nThey don't have the guts.","tool_calls":null},"finish_reason":"stop"}],"usage":{"prompt_tokens":15,"completion_tokens":21,"total_tokens":36},"meta":{"requestDurationMillis":173},"model":"jamba-1.5-mini"}
--------------------------------------------------------------------------------
/src/providers/bedrock/logs/anthropic_claude_3_haiku_20240307_v1_0_chat_completion.json:
--------------------------------------------------------------------------------
1 | {"id":"msg_bdrk_01EUKDUQgXoPmpVJR79cggfA","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[{"type":"text","text":"Here's a short joke for you:\n\nWhy can't a bicycle stand up on its own? Because it's two-tired!"}],"stop_reason":"end_turn","stop_sequence":null,"usage":{"input_tokens":12,"output_tokens":30}}
--------------------------------------------------------------------------------
/src/providers/bedrock/logs/us_amazon_nova_lite_v1_0_chat_completion.json:
--------------------------------------------------------------------------------
1 | {"output":{"message":{"content":[{"text":"Paris"}],"role":"assistant"}},"stopReason":"end_turn","usage":{"inputTokens":12,"outputTokens":1,"totalTokens":13,"cacheReadInputTokenCount":0,"cacheWriteInputTokenCount":0}}
--------------------------------------------------------------------------------
/src/providers/bedrock/mod.rs:
--------------------------------------------------------------------------------
1 | mod models;
2 |
3 | mod provider;
4 | #[cfg(test)]
5 | mod test;
6 |
7 | pub use provider::BedrockProvider;
8 |
--------------------------------------------------------------------------------
/src/providers/bedrock/models.rs:
--------------------------------------------------------------------------------
1 | use crate::config::constants::{
2 | default_embedding_dimension, default_embedding_normalize, default_max_tokens,
3 | };
4 | use crate::models::chat::{ChatCompletion, ChatCompletionChoice, ChatCompletionRequest};
5 | use crate::models::completion::{
6 | CompletionChoice, CompletionRequest, CompletionResponse, LogProbs,
7 | };
8 | use crate::models::content::{ChatCompletionMessage, ChatMessageContent};
9 | use crate::models::embeddings::{
10 | Embedding, Embeddings, EmbeddingsInput, EmbeddingsRequest, EmbeddingsResponse,
11 | };
12 | use crate::models::usage::{EmbeddingUsage, Usage};
13 | use serde::{Deserialize, Serialize};
14 |
15 | /**
16 | * Titan models
17 | */
18 |
19 | #[derive(Serialize, Deserialize, Clone)]
20 | pub struct TitanMessageContent {
21 | pub text: String,
22 | }
23 |
24 | #[derive(Serialize, Deserialize, Clone)]
25 | pub struct TitanMessage {
26 | pub role: String,
27 |     pub content: Vec<TitanMessageContent>,
28 | }
29 |
30 | #[derive(Serialize, Deserialize, Clone)]
31 | pub struct TitanInferenceConfig {
32 | pub max_new_tokens: u32,
33 | }
34 |
35 | #[derive(Serialize, Deserialize, Clone)]
36 | pub struct TitanChatCompletionRequest {
37 | #[serde(rename = "inferenceConfig")]
38 | pub inference_config: TitanInferenceConfig,
39 |     pub messages: Vec<TitanMessage>,
40 | }
41 |
42 | #[derive(Deserialize, Serialize)]
43 | pub struct TitanChatCompletionResponse {
44 | pub output: TitanOutput,
45 | #[serde(rename = "stopReason")]
46 | pub stop_reason: String,
47 | pub usage: TitanUsage,
48 | }
49 |
50 | #[derive(Deserialize, Serialize)]
51 | pub struct TitanOutput {
52 | pub message: TitanMessage,
53 | }
54 |
55 | #[derive(Deserialize, Serialize)]
56 | pub struct TitanUsage {
57 | #[serde(rename = "inputTokens")]
58 | pub input_tokens: u32,
59 | #[serde(rename = "outputTokens")]
60 | pub output_tokens: u32,
61 | #[serde(rename = "totalTokens")]
62 | pub total_tokens: u32,
63 | }
64 |
65 | impl From<ChatCompletionRequest> for TitanChatCompletionRequest {
66 | fn from(request: ChatCompletionRequest) -> Self {
67 | let messages = request
68 | .messages
69 | .into_iter()
70 | .map(|msg| {
71 | let content_text = match msg.content {
72 | Some(ChatMessageContent::String(text)) => text,
73 | Some(ChatMessageContent::Array(parts)) => parts
74 | .into_iter()
75 | .filter(|part| part.r#type == "text")
76 | .map(|part| part.text)
77 |                         .collect::<Vec<String>>()
78 | .join(" "),
79 | None => String::new(),
80 | };
81 |
82 | TitanMessage {
83 | role: msg.role,
84 | content: vec![TitanMessageContent { text: content_text }],
85 | }
86 | })
87 | .collect();
88 |
89 | TitanChatCompletionRequest {
90 | inference_config: TitanInferenceConfig {
91 | max_new_tokens: request.max_tokens.unwrap_or(default_max_tokens()),
92 | },
93 | messages,
94 | }
95 | }
96 | }
97 |
98 | impl From<TitanChatCompletionResponse> for ChatCompletion {
99 | fn from(response: TitanChatCompletionResponse) -> Self {
100 | let message = ChatCompletionMessage {
101 | role: response.output.message.role,
102 | content: Some(ChatMessageContent::String(
103 | response
104 | .output
105 | .message
106 | .content
107 | .into_iter()
108 | .map(|c| c.text)
109 |                     .collect::<Vec<String>>()
110 | .join(" "),
111 | )),
112 | name: None,
113 | tool_calls: None,
114 |             refusal: None, // not returned by Titan as of 1/04/2025
115 | };
116 |
117 | ChatCompletion {
118 |             id: "".to_string(), // the response id is private in the AWS SDK and cannot be accessed
119 | object: None,
120 | created: None,
121 | model: "".to_string(),
122 | choices: vec![ChatCompletionChoice {
123 | index: 0,
124 | message,
125 | finish_reason: Some(response.stop_reason),
126 | logprobs: None,
127 | }],
128 | usage: Usage {
129 | prompt_tokens: response.usage.input_tokens,
130 | completion_tokens: response.usage.output_tokens,
131 | total_tokens: response.usage.total_tokens,
132 | completion_tokens_details: None,
133 | prompt_tokens_details: None,
134 | },
135 | system_fingerprint: None,
136 | }
137 | }
138 | }
139 |
140 | #[derive(Debug, Serialize, Deserialize)]
141 | pub struct TitanEmbeddingRequest {
142 | #[serde(rename = "inputText")]
143 | pub input_text: String,
144 | pub dimensions: u32,
145 | pub normalize: bool,
146 | }
147 |
148 | #[derive(Debug, Serialize, Deserialize)]
149 | pub struct TitanEmbeddingResponse {
150 |     pub embedding: Vec<f32>,
151 | #[serde(rename = "embeddingsByType")]
152 | pub embeddings_by_type: EmbeddingsByType,
153 | #[serde(rename = "inputTextTokenCount")]
154 | pub input_text_token_count: u32,
155 | }
156 |
157 | #[derive(Debug, Serialize, Deserialize)]
158 | pub struct EmbeddingsByType {
159 | pub float: Vec<f32>,
160 | }
161 |
162 | impl From<EmbeddingsRequest> for TitanEmbeddingRequest {
163 | fn from(request: EmbeddingsRequest) -> Self {
164 | let input_text = match request.input {
165 | EmbeddingsInput::Single(text) => text,
166 | EmbeddingsInput::Multiple(texts) => {
167 | texts.first().map(|s| s.to_string()).unwrap_or_default()
168 | }
169 | EmbeddingsInput::SingleTokenIds(token_ids) => token_ids
170 | .iter()
171 | .map(|id| id.to_string())
172 | .collect::<Vec<String>>()
173 | .join(" "),
174 | EmbeddingsInput::MultipleTokenIds(all_token_ids) => all_token_ids
175 | .first()
176 | .map(|token_ids| {
177 | token_ids
178 | .iter()
179 | .map(|id| id.to_string())
180 | .collect::<Vec<String>>()
181 | .join(" ")
182 | })
183 | .unwrap_or_default(),
184 | };
185 |
186 | TitanEmbeddingRequest {
187 | input_text,
188 | dimensions: default_embedding_dimension(),
189 | normalize: default_embedding_normalize(),
190 | }
191 | }
192 | }
193 |
194 | impl From<TitanEmbeddingResponse> for EmbeddingsResponse {
195 | fn from(response: TitanEmbeddingResponse) -> Self {
196 | EmbeddingsResponse {
197 | object: "list".to_string(),
198 | data: vec![Embeddings {
199 | object: "embedding".to_string(),
200 | embedding: Embedding::Float(response.embedding),
201 | index: 0,
202 | }],
203 | model: "".to_string(),
204 | usage: EmbeddingUsage {
205 | prompt_tokens: Some(response.input_text_token_count),
206 | total_tokens: Some(response.input_text_token_count),
207 | },
208 | }
209 | }
210 | }
211 |
212 | /*
213 | Ai21 models
214 | */
215 |
216 | #[derive(Debug, Deserialize, Serialize, Clone)]
217 | pub struct Ai21Message {
218 | pub role: String,
219 | pub content: String,
220 | }
221 |
222 | #[derive(Debug, Deserialize, Serialize, Clone)]
223 | pub struct Ai21ChatCompletionRequest {
224 | pub messages: Vec<Ai21Message>,
225 | pub max_tokens: u32,
226 | #[serde(skip_serializing_if = "Option::is_none")]
227 | pub temperature: Option<f32>,
228 | #[serde(skip_serializing_if = "Option::is_none")]
229 | pub top_p: Option<f32>,
230 | }
231 |
232 | #[derive(Deserialize, Serialize, Clone)]
233 | pub struct Ai21ChatCompletionResponse {
234 | pub id: String,
235 | pub choices: Vec<Ai21Choice>,
236 | pub model: String,
237 | pub usage: Ai21Usage,
238 | pub meta: Ai21Meta,
239 | }
240 |
241 | #[derive(Deserialize, Serialize, Clone)]
242 | pub struct Ai21Choice {
243 | pub finish_reason: String,
244 | pub index: u32,
245 | pub message: Ai21Message,
246 | }
247 |
248 | #[derive(Deserialize, Serialize, Clone)]
249 | pub struct Ai21Meta {
250 | #[serde(rename = "requestDurationMillis")]
251 | pub request_duration_millis: u64,
252 | }
253 |
254 | #[derive(Deserialize, Serialize, Clone)]
255 | pub struct Ai21Usage {
256 | pub completion_tokens: u32,
257 | pub prompt_tokens: u32,
258 | pub total_tokens: u32,
259 | }
260 |
261 | impl From<ChatCompletionRequest> for Ai21ChatCompletionRequest {
262 | fn from(request: ChatCompletionRequest) -> Self {
263 | let messages = request
264 | .messages
265 | .into_iter()
266 | .map(|msg| {
267 | let content = match msg.content {
268 | Some(ChatMessageContent::String(text)) => text,
269 | Some(ChatMessageContent::Array(parts)) => parts
270 | .into_iter()
271 | .filter(|part| part.r#type == "text")
272 | .map(|part| part.text)
273 | .collect::<Vec<String>>()
274 | .join(" "),
275 | None => String::new(),
276 | };
277 |
278 | Ai21Message {
279 | role: msg.role,
280 | content,
281 | }
282 | })
283 | .collect();
284 |
285 | Ai21ChatCompletionRequest {
286 | messages,
287 | max_tokens: request.max_tokens.unwrap_or(default_max_tokens()),
288 | temperature: request.temperature,
289 | top_p: request.top_p,
290 | }
291 | }
292 | }
293 |
294 | impl From<Ai21ChatCompletionResponse> for ChatCompletion {
295 | fn from(response: Ai21ChatCompletionResponse) -> Self {
296 | ChatCompletion {
297 | id: response.id,
298 | object: None,
299 | created: None,
300 | model: response.model,
301 | choices: response
302 | .choices
303 | .into_iter()
304 | .map(|choice| ChatCompletionChoice {
305 | index: choice.index,
306 | message: ChatCompletionMessage {
307 | role: choice.message.role,
308 | content: Some(ChatMessageContent::String(choice.message.content)),
309 | name: None,
310 | tool_calls: None,
311 | refusal: None, // AI21 does not return this field as of April 2025
312 | },
313 | finish_reason: Some(choice.finish_reason),
314 | logprobs: None,
315 | })
316 | .collect(),
317 | usage: Usage {
318 | prompt_tokens: response.usage.prompt_tokens,
319 | completion_tokens: response.usage.completion_tokens,
320 | total_tokens: response.usage.total_tokens,
321 | completion_tokens_details: None,
322 | prompt_tokens_details: None,
323 | },
324 | system_fingerprint: None,
325 | }
326 | }
327 | }
328 |
329 | #[derive(Debug, Serialize, Deserialize, Clone)]
330 | pub struct Ai21CompletionsRequest {
331 | pub prompt: String,
332 | #[serde(rename = "maxTokens")]
333 | pub max_tokens: u32,
334 | #[serde(skip_serializing_if = "Option::is_none")]
335 | pub temperature: Option<f32>,
336 | #[serde(rename = "topP", skip_serializing_if = "Option::is_none")]
337 | pub top_p: Option<f32>,
338 | #[serde(rename = "stopSequences")]
339 | pub stop_sequences: Vec<String>,
340 | #[serde(rename = "countPenalty")]
341 | pub count_penalty: PenaltyConfig,
342 | #[serde(rename = "presencePenalty")]
343 | pub presence_penalty: PenaltyConfig,
344 | #[serde(rename = "frequencyPenalty")]
345 | pub frequency_penalty: PenaltyConfig,
346 | }
347 |
348 | #[derive(Debug, Serialize, Deserialize, Clone)]
349 | pub struct PenaltyConfig {
350 | pub scale: i32,
351 | }
352 |
353 | #[derive(Debug, Serialize, Deserialize, Clone)]
354 | pub struct Ai21CompletionsResponse {
355 | pub id: i64,
356 | pub prompt: Ai21Prompt,
357 | pub completions: Vec<Ai21CompletionWrapper>,
358 | }
359 |
360 | #[derive(Debug, Serialize, Deserialize, Clone)]
361 | pub struct Ai21CompletionWrapper {
362 | pub data: Ai21CompletionData,
363 | #[serde(rename = "finishReason")]
364 | pub finish_reason: Ai21FinishReason,
365 | }
366 |
367 | #[derive(Debug, Serialize, Deserialize, Clone)]
368 | pub struct Ai21Prompt {
369 | pub text: String,
370 | pub tokens: Vec<Ai21Token>,
371 | }
372 |
373 | #[derive(Debug, Serialize, Deserialize, Clone)]
374 | pub struct Ai21CompletionData {
375 | pub text: String,
376 | pub tokens: Vec<Ai21Token>,
377 | }
378 |
379 | #[derive(Debug, Serialize, Deserialize, Clone)]
380 | pub struct Ai21Token {
381 | #[serde(rename = "generatedToken")]
382 | pub generated_token: Option<GeneratedToken>,
383 | #[serde(rename = "textRange")]
384 | pub text_range: TextRange,
385 | #[serde(rename = "topTokens")]
386 | pub top_tokens: Option<Vec<TopToken>>,
387 | }
388 |
389 | #[derive(Debug, Serialize, Deserialize, Clone)]
390 | pub struct GeneratedToken {
391 | pub token: String,
392 | #[serde(rename = "logprob")]
393 | pub log_prob: f64,
394 | #[serde(rename = "raw_logprob")]
395 | pub raw_log_prob: f64,
396 | }
397 |
398 | #[derive(Debug, Serialize, Deserialize, Clone)]
399 | pub struct TextRange {
400 | pub start: i32,
401 | pub end: i32,
402 | }
403 |
404 | #[derive(Debug, Serialize, Deserialize, Clone)]
405 | pub struct TopToken {
406 | pub token: String,
407 | pub logprob: f64,
408 | }
409 |
410 | #[derive(Debug, Serialize, Deserialize, Clone)]
411 | pub struct Ai21FinishReason {
412 | pub reason: String,
413 | }
414 |
415 | impl From<CompletionRequest> for Ai21CompletionsRequest {
416 | fn from(request: CompletionRequest) -> Self {
417 | Self {
418 | prompt: request.prompt,
419 | max_tokens: request.max_tokens.unwrap_or(default_max_tokens()),
420 | temperature: request.temperature,
421 | top_p: request.top_p,
422 | stop_sequences: request.stop.unwrap_or_default(),
423 | count_penalty: PenaltyConfig { scale: 0 },
424 | presence_penalty: PenaltyConfig {
425 | scale: if let Some(penalty) = request.presence_penalty {
426 | penalty as i32
427 | } else {
428 | 0
429 | },
430 | },
431 | frequency_penalty: PenaltyConfig {
432 | scale: if let Some(penalty) = request.frequency_penalty {
433 | penalty as i32
434 | } else {
435 | 0
436 | },
437 | },
438 | }
439 | }
440 | }
441 |
442 | impl From<Ai21CompletionsResponse> for CompletionResponse {
443 | fn from(response: Ai21CompletionsResponse) -> Self {
444 | let total_prompt_tokens = response.prompt.tokens.len() as u32;
445 | let total_completion_tokens = response
446 | .completions
447 | .iter()
448 | .map(|c| c.data.tokens.len() as u32)
449 | .sum();
450 |
451 | CompletionResponse {
452 | id: response.id.to_string(),
453 | object: "".to_string(),
454 | created: chrono::Utc::now().timestamp() as u64,
455 | model: "".to_string(),
456 | choices: response
457 | .completions
458 | .into_iter()
459 | .enumerate()
460 | .map(|(index, completion)| CompletionChoice {
461 | text: completion.data.text,
462 | index: index as u32,
463 | logprobs: Some(LogProbs {
464 | tokens: completion
465 | .data
466 | .tokens
467 | .iter()
468 | .filter_map(|t| t.generated_token.as_ref().map(|gt| gt.token.clone()))
469 | .collect(),
470 | token_logprobs: completion
471 | .data
472 | .tokens
473 | .iter()
474 | .filter_map(|t| t.generated_token.as_ref().map(|gt| gt.log_prob as f32))
475 | .collect(),
476 | top_logprobs: completion
477 | .data
478 | .tokens
479 | .iter()
480 | .map(|t| {
481 | t.top_tokens
482 | .clone()
483 | .map(|tt| {
484 | tt.into_iter()
485 | .map(|top| (top.token, top.logprob as f32))
486 | .collect()
487 | })
488 | .unwrap_or_default()
489 | })
490 | .collect(),
491 | text_offset: completion
492 | .data
493 | .tokens
494 | .iter()
495 | .map(|t| t.text_range.start as usize)
496 | .collect(),
497 | }),
498 | finish_reason: Some(completion.finish_reason.reason),
499 | })
500 | .collect(),
501 | usage: Usage {
502 | prompt_tokens: total_prompt_tokens,
503 | completion_tokens: total_completion_tokens,
504 | total_tokens: total_prompt_tokens + total_completion_tokens,
505 | completion_tokens_details: None,
506 | prompt_tokens_details: None,
507 | },
508 | }
509 | }
510 | }
511 |
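512 | #[cfg(test)]
513 | mod conversion_tests {
514 | use super::*;
515 |
516 | // Minimal sketch of the AI21 -> OpenAI response mapping: role, content,
517 | // finish reason, and token usage carry over. All field values here are
518 | // placeholders for illustration only.
519 | #[test]
520 | fn ai21_response_maps_to_chat_completion() {
521 | let response = Ai21ChatCompletionResponse {
522 | id: "resp-1".to_string(),
523 | choices: vec![Ai21Choice {
524 | finish_reason: "stop".to_string(),
525 | index: 0,
526 | message: Ai21Message {
527 | role: "assistant".to_string(),
528 | content: "Paris".to_string(),
529 | },
530 | }],
531 | model: "jamba-1.5-mini".to_string(),
532 | usage: Ai21Usage {
533 | completion_tokens: 1,
534 | prompt_tokens: 5,
535 | total_tokens: 6,
536 | },
537 | meta: Ai21Meta {
538 | request_duration_millis: 42,
539 | },
540 | };
541 |
542 | let completion = ChatCompletion::from(response);
543 | assert_eq!(completion.id, "resp-1");
544 | assert_eq!(completion.choices.len(), 1);
545 | assert_eq!(completion.choices[0].finish_reason.as_deref(), Some("stop"));
546 | assert_eq!(completion.usage.total_tokens, 6);
547 | }
548 | }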
--------------------------------------------------------------------------------
/src/providers/bedrock/provider.rs:
--------------------------------------------------------------------------------
1 | use axum::async_trait;
2 | use axum::http::StatusCode;
3 | use std::error::Error;
4 |
5 | use aws_sdk_bedrockruntime::Client as BedrockRuntimeClient;
6 |
7 | use crate::config::models::{ModelConfig, Provider as ProviderConfig};
8 | use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse};
9 | use crate::models::completion::{CompletionRequest, CompletionResponse};
10 | use crate::models::embeddings::{EmbeddingsRequest, EmbeddingsResponse};
11 | use crate::providers::provider::Provider;
12 |
13 | use crate::providers::anthropic::{
14 | AnthropicChatCompletionRequest, AnthropicChatCompletionResponse,
15 | };
16 | use crate::providers::bedrock::models::{
17 | Ai21ChatCompletionRequest, Ai21ChatCompletionResponse, Ai21CompletionsRequest,
18 | Ai21CompletionsResponse, TitanChatCompletionRequest, TitanChatCompletionResponse,
19 | TitanEmbeddingRequest, TitanEmbeddingResponse,
20 | };
21 | use aws_sdk_bedrockruntime::primitives::Blob;
22 |
23 | struct AI21Implementation;
24 | struct TitanImplementation;
25 | struct AnthropicImplementation;
26 |
27 | pub struct BedrockProvider {
28 | pub(crate) config: ProviderConfig,
29 | }
30 |
31 | pub trait ClientProvider {
32 | async fn create_client(&self) -> Result<BedrockRuntimeClient, Box<dyn Error>>;
33 | }
34 |
35 | #[cfg(not(test))]
36 | impl ClientProvider for BedrockProvider {
37 | async fn create_client(&self) -> Result<BedrockRuntimeClient, Box<dyn Error>> {
38 | use aws_config::BehaviorVersion;
39 | use aws_config::Region;
40 | use aws_credential_types::Credentials;
41 |
42 | let region = self.config.params.get("region").unwrap().clone();
43 | let access_key_id = self.config.params.get("AWS_ACCESS_KEY_ID").unwrap().clone();
44 | let secret_access_key = self
45 | .config
46 | .params
47 | .get("AWS_SECRET_ACCESS_KEY")
48 | .unwrap()
49 | .clone();
50 | let session_token = self.config.params.get("AWS_SESSION_TOKEN").cloned();
51 |
52 | let credentials = Credentials::from_keys(access_key_id, secret_access_key, session_token);
53 |
54 | let sdk_config = aws_config::defaults(BehaviorVersion::latest())
55 | .region(Region::new(region))
56 | .credentials_provider(credentials)
57 | .load()
58 | .await;
59 |
60 | Ok(BedrockRuntimeClient::new(&sdk_config))
61 | }
62 | }
63 |
64 | impl BedrockProvider {
65 | fn get_provider_implementation(
66 | &self,
67 | model_config: &ModelConfig,
68 | ) -> Box<dyn BedrockModelImplementation> {
69 | let bedrock_model_provider = model_config.params.get("model_provider").unwrap();
70 |
71 | let provider_implementation: Box<dyn BedrockModelImplementation> =
72 | match bedrock_model_provider.as_str() {
73 | "ai21" => Box::new(AI21Implementation),
74 | "titan" => Box::new(TitanImplementation),
75 | "anthropic" => Box::new(AnthropicImplementation),
76 | _ => panic!("Invalid bedrock model provider"),
77 | };
78 |
79 | provider_implementation
80 | }
81 | }
82 |
83 | #[async_trait]
84 | impl Provider for BedrockProvider {
85 | fn new(config: &ProviderConfig) -> Self {
86 | Self {
87 | config: config.clone(),
88 | }
89 | }
90 |
91 | fn key(&self) -> String {
92 | self.config.key.clone()
93 | }
94 |
95 | fn r#type(&self) -> String {
96 | "bedrock".to_string()
97 | }
98 |
99 | async fn chat_completions(
100 | &self,
101 | payload: ChatCompletionRequest,
102 | model_config: &ModelConfig,
103 | ) -> Result<ChatCompletionResponse, StatusCode> {
104 | let client = self.create_client().await.map_err(|e| {
105 | eprintln!("Failed to create Bedrock client: {}", e);
106 | StatusCode::INTERNAL_SERVER_ERROR
107 | })?;
108 |
109 | // Build the Bedrock model id as "[inference_profile.]model_provider.model-model_version", e.g. "us.anthropic.claude-3-haiku-20240307-v1:0"
110 | let model_provider = model_config.params.get("model_provider").unwrap();
111 | let inference_profile_id = self.config.params.get("inference_profile_id");
112 | let mut transformed_payload = payload;
113 | let model_version = model_config
114 | .params
115 | .get("model_version")
116 | .map_or("v1:0", |s| &**s);
117 | transformed_payload.model = if let Some(profile_id) = inference_profile_id {
118 | format!(
119 | "{}.{}.{}-{}",
120 | profile_id, model_provider, transformed_payload.model, model_version
121 | )
122 | } else {
123 | format!(
124 | "{}.{}-{}",
125 | model_provider, transformed_payload.model, model_version
126 | )
127 | };
128 |
129 | self.get_provider_implementation(model_config)
130 | .chat_completion(&client, transformed_payload)
131 | .await
132 | }
133 |
134 | async fn completions(
135 | &self,
136 | payload: CompletionRequest,
137 | model_config: &ModelConfig,
138 | ) -> Result<CompletionResponse, StatusCode> {
139 | let client = self.create_client().await.map_err(|e| {
140 | eprintln!("Failed to create Bedrock client: {}", e);
141 | StatusCode::INTERNAL_SERVER_ERROR
142 | })?;
143 |
144 | self.get_provider_implementation(model_config)
145 | .completion(&client, payload)
146 | .await
147 | }
148 |
149 | async fn embeddings(
150 | &self,
151 | payload: EmbeddingsRequest,
152 | model_config: &ModelConfig,
153 | ) -> Result<EmbeddingsResponse, StatusCode> {
154 | let client = self.create_client().await.map_err(|e| {
155 | eprintln!("Failed to create Bedrock client: {}", e);
156 | StatusCode::INTERNAL_SERVER_ERROR
157 | })?;
158 |
159 | self.get_provider_implementation(model_config)
160 | .embedding(&client, payload)
161 | .await
162 | }
163 | }
164 |
165 | /**
166 | BEDROCK IMPLEMENTATION TEMPLATE - SERVES AS THE LAYOUT FOR THE PER-MODEL IMPLEMENTATIONS BELOW
167 | */
168 |
169 | #[async_trait]
170 | trait BedrockModelImplementation: Send + Sync {
171 | async fn chat_completion(
172 | &self,
173 | client: &BedrockRuntimeClient,
174 | payload: ChatCompletionRequest,
175 | ) -> Result<ChatCompletionResponse, StatusCode>;
176 |
177 | async fn completion(
178 | &self,
179 | _client: &BedrockRuntimeClient,
180 | _payload: CompletionRequest,
181 | ) -> Result<CompletionResponse, StatusCode> {
182 | Err(StatusCode::NOT_IMPLEMENTED)
183 | }
184 |
185 | async fn embedding(
186 | &self,
187 | _client: &BedrockRuntimeClient,
188 | _payload: EmbeddingsRequest,
189 | ) -> Result<EmbeddingsResponse, StatusCode> {
190 | Err(StatusCode::NOT_IMPLEMENTED)
191 | }
192 | }
193 |
194 | trait BedrockRequestHandler {
195 | async fn handle_bedrock_request<T, U>(
196 | &self,
197 | client: &BedrockRuntimeClient,
198 | model_id: &str,
199 | request: T,
200 | error_context: &str,
201 | ) -> Result<U, StatusCode>
202 | where
203 | T: serde::Serialize + std::marker::Send,
204 | U: for<'de> serde::Deserialize<'de>,
205 | {
206 | // Serialize request
207 | let request_json = serde_json::to_vec(&request).map_err(|e| {
208 | eprintln!("Failed to serialize {}: {}", error_context, e);
209 | StatusCode::INTERNAL_SERVER_ERROR
210 | })?;
211 |
212 | // Make API call
213 | let response = client
214 | .invoke_model()
215 | .body(Blob::new(request_json))
216 | .model_id(model_id)
217 | .send()
218 | .await
219 | .map_err(|e| {
220 | eprintln!("Bedrock API error for {}: {:?}", error_context, e);
221 | eprintln!(
222 | "Error details - Source: {}, Raw error: {:?}",
223 | e.source().unwrap_or(&e),
224 | e.raw_response()
225 | );
226 | StatusCode::INTERNAL_SERVER_ERROR
227 | })?;
228 |
229 | // Deserialize response
230 | serde_json::from_slice(&response.body.into_inner()).map_err(|e| {
231 | eprintln!("Failed to deserialize {} response: {}", error_context, e);
232 | StatusCode::INTERNAL_SERVER_ERROR
233 | })
234 | }
235 | }
236 |
237 | impl BedrockRequestHandler for AI21Implementation {}
238 | impl BedrockRequestHandler for TitanImplementation {}
239 | impl BedrockRequestHandler for AnthropicImplementation {}
240 |
241 | /**
242 | AI21 IMPLEMENTATION
243 | */
244 |
245 | #[async_trait]
246 | impl BedrockModelImplementation for AI21Implementation {
247 | async fn chat_completion(
248 | &self,
249 | client: &BedrockRuntimeClient,
250 | payload: ChatCompletionRequest,
251 | ) -> Result {
252 | let ai21_request = Ai21ChatCompletionRequest::from(payload.clone());
253 | let ai21_response: Ai21ChatCompletionResponse = self
254 | .handle_bedrock_request(client, &payload.model, ai21_request, "AI21 chat completion")
255 | .await?;
256 |
257 | Ok(ChatCompletionResponse::NonStream(ai21_response.into()))
258 | }
259 |
260 | async fn completion(
261 | &self,
262 | client: &BedrockRuntimeClient,
263 | payload: CompletionRequest,
264 | ) -> Result {
265 | // Bedrock AI21 legacy models support the completions API, similar to OpenAI
266 | let ai21_request = Ai21CompletionsRequest::from(payload.clone());
267 | let ai21_response: Ai21CompletionsResponse = self
268 | .handle_bedrock_request(client, &payload.model, ai21_request, "AI21 completion")
269 | .await?;
270 |
271 | Ok(CompletionResponse::from(ai21_response))
272 | }
273 | }
274 |
275 | /**
276 | TITAN IMPLEMENTATION
277 | */
278 |
279 | #[async_trait]
280 | impl BedrockModelImplementation for TitanImplementation {
281 | async fn chat_completion(
282 | &self,
283 | client: &BedrockRuntimeClient,
284 | payload: ChatCompletionRequest,
285 | ) -> Result {
286 | let titan_request = TitanChatCompletionRequest::from(payload.clone());
287 | let titan_response: TitanChatCompletionResponse = self
288 | .handle_bedrock_request(
289 | client,
290 | &payload.model,
291 | titan_request,
292 | "Titan chat completion",
293 | )
294 | .await?;
295 |
296 | Ok(ChatCompletionResponse::NonStream(titan_response.into()))
297 | }
298 |
299 | async fn embedding(
300 | &self,
301 | client: &BedrockRuntimeClient,
302 | payload: EmbeddingsRequest,
303 | ) -> Result {
304 | let titan_request = TitanEmbeddingRequest::from(payload.clone());
305 | let titan_response: TitanEmbeddingResponse = self
306 | .handle_bedrock_request(client, &payload.model, titan_request, "Titan embedding")
307 | .await?;
308 |
309 | Ok(EmbeddingsResponse::from(titan_response))
310 | }
311 | }
312 |
313 | /**
314 | ANTHROPIC IMPLEMENTATION
315 | */
316 |
317 | #[async_trait]
318 | impl BedrockModelImplementation for AnthropicImplementation {
319 | async fn chat_completion(
320 | &self,
321 | client: &BedrockRuntimeClient,
322 | payload: ChatCompletionRequest,
323 | ) -> Result {
324 | let anthropic_request = AnthropicChatCompletionRequest::from(payload.clone());
325 |
326 | // Convert to Value for Bedrock-specific modifications
327 | let mut request_value = serde_json::to_value(&anthropic_request).map_err(|e| {
328 | eprintln!("Failed to serialize Anthropic request: {}", e);
329 | StatusCode::INTERNAL_SERVER_ERROR
330 | })?;
331 |
332 | if let serde_json::Value::Object(ref mut map) = request_value {
333 | map.remove("model");
334 | map.insert(
335 | "anthropic_version".to_string(),
336 | serde_json::Value::String("bedrock-2023-05-31".to_string()),
337 | );
338 | }
339 |
340 | let anthropic_response: AnthropicChatCompletionResponse = self
341 | .handle_bedrock_request(
342 | client,
343 | &payload.model,
344 | request_value,
345 | "Anthropic chat completion",
346 | )
347 | .await?;
348 |
349 | Ok(ChatCompletionResponse::NonStream(anthropic_response.into()))
350 | }
351 | }
352 |
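353 | #[cfg(test)]
354 | mod provider_selection_tests {
355 | use super::*;
356 | use std::collections::HashMap;
357 |
358 | // Minimal sketch of how the "model_provider" param selects the per-model
359 | // implementation; an unrecognized value panics. All config values below
360 | // are placeholders, not real credentials or model ids.
361 | #[test]
362 | #[should_panic(expected = "Invalid bedrock model provider")]
363 | fn unknown_model_provider_panics() {
364 | let provider = BedrockProvider::new(&ProviderConfig {
365 | key: "bedrock".to_string(),
366 | r#type: "bedrock".to_string(),
367 | api_key: String::new(),
368 | params: HashMap::new(),
369 | });
370 |
371 | let mut params = HashMap::new();
372 | params.insert("model_provider".to_string(), "cohere".to_string());
373 | let model_config = ModelConfig {
374 | key: "test-model".to_string(),
375 | r#type: "chat".to_string(),
376 | provider: "bedrock".to_string(),
377 | params,
378 | };
379 |
380 | let _ = provider.get_provider_implementation(&model_config);
381 | }
382 | }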
--------------------------------------------------------------------------------
/src/providers/bedrock/test.rs:
--------------------------------------------------------------------------------
1 | #[cfg(test)]
2 | impl crate::providers::bedrock::provider::ClientProvider
3 | for crate::providers::bedrock::BedrockProvider
4 | {
5 | // Comment out this impl (or switch your environment from test to prod)
6 | // to run these tests against real AWS services instead of the mock client.
7 | async fn create_client(&self) -> Result<aws_sdk_bedrockruntime::Client, Box<dyn std::error::Error>> {
8 | let handler = self
9 | .config
10 | .params
11 | .get("test_response_handler")
12 | .map(|s| s.as_str());
13 | let mock_responses = match handler {
14 | Some("anthropic_chat_completion") => vec![dummy_anthropic_chat_completion_response()],
15 | Some("ai21_chat_completion") => vec![dummy_ai21_chat_completion_response()],
16 | Some("ai21_completion") => vec![dummy_ai21_completion_response()],
17 | Some("titan_chat_completion") => vec![dummy_titan_chat_completion_response()],
18 | Some("titan_embedding") => vec![dummy_titan_embedding_response()],
19 | _ => vec![],
20 | };
21 | let test_client = create_test_bedrock_client(mock_responses).await;
22 | Ok(test_client)
23 | }
24 | }
25 |
26 | #[cfg(test)]
27 | fn get_test_provider_config(
28 | region: &str,
29 | test_response_handler: &str,
30 | ) -> crate::config::models::Provider {
31 | use std::collections::HashMap;
32 |
33 | let mut params = HashMap::new();
34 | params.insert("region".to_string(), region.to_string());
35 |
36 | let aws_access_key_id = std::env::var("AWS_ACCESS_KEY_ID").unwrap_or("test_id".to_string());
37 | let aws_secret_access_key =
38 | std::env::var("AWS_SECRET_ACCESS_KEY").unwrap_or("test_key".to_string());
39 |
40 | params.insert("AWS_ACCESS_KEY_ID".to_string(), aws_access_key_id);
41 | params.insert("AWS_SECRET_ACCESS_KEY".to_string(), aws_secret_access_key);
42 |
43 | params.insert(
44 | "test_response_handler".to_string(),
45 | test_response_handler.to_string(),
46 | );
47 |
48 | crate::config::models::Provider {
49 | key: "test_key".to_string(),
50 | r#type: "".to_string(),
51 | api_key: "".to_string(),
52 | params,
53 | }
54 | }
55 | #[cfg(test)]
56 | fn get_test_model_config(
57 | model_type: &str,
58 | provider_type: &str,
59 | ) -> crate::config::models::ModelConfig {
60 | use std::collections::HashMap;
61 |
62 | let mut params = HashMap::new();
63 | params.insert("model_provider".to_string(), provider_type.to_string());
64 |
65 | crate::config::models::ModelConfig {
66 | key: "test-model".to_string(),
67 | r#type: model_type.to_string(),
68 | provider: "bedrock".to_string(),
69 | params,
70 | }
71 | }
72 |
73 | #[cfg(test)]
74 | mod anthropic_tests {
75 | use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse};
76 | use crate::models::content::{ChatCompletionMessage, ChatMessageContent};
77 | use crate::providers::bedrock::test::{get_test_model_config, get_test_provider_config};
78 | use crate::providers::bedrock::BedrockProvider;
79 | use crate::providers::provider::Provider;
80 |
81 | #[test]
82 | fn test_bedrock_provider_new() {
83 | let config = get_test_provider_config("us-east-1", "");
84 | let provider = BedrockProvider::new(&config);
85 |
86 | assert_eq!(provider.key(), "test_key");
87 | assert_eq!(provider.r#type(), "bedrock");
88 | }
89 |
90 | #[tokio::test]
91 | async fn test_bedrock_provider_chat_completions() {
92 | let config = get_test_provider_config("us-east-2", "anthropic_chat_completion");
93 | let provider = BedrockProvider::new(&config);
94 |
95 | let model_config =
96 | get_test_model_config("us.anthropic.claude-3-haiku-20240307-v1:0", "anthropic");
97 |
98 | let payload = ChatCompletionRequest {
99 | model: "us.anthropic.claude-3-haiku-20240307-v1:0".to_string(),
100 | messages: vec![ChatCompletionMessage {
101 | role: "user".to_string(),
102 | content: Some(ChatMessageContent::String(
103 | "Tell me a short joke".to_string(),
104 | )),
105 | name: None,
106 | tool_calls: None,
107 | refusal: None,
108 | }],
109 | temperature: None,
110 | top_p: None,
111 | n: None,
112 | stream: None,
113 | stop: None,
114 | max_tokens: None,
115 | max_completion_tokens: None,
116 | parallel_tool_calls: None,
117 | presence_penalty: None,
118 | frequency_penalty: None,
119 | logit_bias: None,
120 | tool_choice: None,
121 | tools: None,
122 | user: None,
123 | logprobs: None,
124 | top_logprobs: None,
125 | response_format: None,
126 | };
127 |
128 | let result = provider.chat_completions(payload, &model_config).await;
129 |
130 | assert!(result.is_ok(), "Chat completion failed: {:?}", result.err());
131 |
132 | if let Ok(ChatCompletionResponse::NonStream(completion)) = result {
133 | assert!(!completion.choices.is_empty(), "Expected non-empty choices");
134 | assert!(
135 | completion.usage.total_tokens > 0,
136 | "Expected non-zero token usage"
137 | );
138 |
139 | let first_choice = &completion.choices[0];
140 | assert!(
141 | first_choice.message.content.is_some(),
142 | "Expected message content"
143 | );
144 | assert_eq!(
145 | first_choice.message.role, "assistant",
146 | "Expected assistant role"
147 | );
148 | }
149 | }
150 | }
151 |
152 | #[cfg(test)]
153 | mod titan_tests {
154 | use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse};
155 | use crate::models::content::{ChatCompletionMessage, ChatMessageContent};
156 | use crate::models::embeddings::EmbeddingsInput::Single;
157 | use crate::models::embeddings::{Embedding, EmbeddingsRequest};
158 | use crate::providers::bedrock::test::{get_test_model_config, get_test_provider_config};
159 | use crate::providers::bedrock::BedrockProvider;
160 | use crate::providers::provider::Provider;
161 |
162 | #[test]
163 | fn test_titan_provider_new() {
164 | let config = get_test_provider_config("us-east-2", "");
165 | let provider = BedrockProvider::new(&config);
166 |
167 | assert_eq!(provider.key(), "test_key");
168 | assert_eq!(provider.r#type(), "bedrock");
169 | }
170 |
171 | #[tokio::test]
172 | async fn test_embeddings() {
173 | let config = get_test_provider_config("us-east-2", "titan_embedding");
174 | let provider = BedrockProvider::new(&config);
175 | let model_config = get_test_model_config("amazon.titan-embed-text-v2:0", "titan");
176 |
177 | let payload = EmbeddingsRequest {
178 | model: "amazon.titan-embed-text-v2:0".to_string(),
179 | user: None,
180 | input: Single("this is where you place your input text".to_string()),
181 | encoding_format: None,
182 | };
183 |
184 | let result = provider.embeddings(payload, &model_config).await;
185 | assert!(
186 | result.is_ok(),
187 | "Titan Embeddings generation failed: {:?}",
188 | result.err()
189 | );
190 | let response = result.unwrap();
191 | assert!(
192 | !response.data.is_empty(),
193 | "Expected non-empty embeddings data"
194 | );
195 | assert!(
196 | matches!(&response.data[0].embedding, Embedding::Float(vec) if !vec.is_empty()),
197 | "Expected non-empty Float embedding vector",
198 | );
199 | assert!(
200 | response.usage.prompt_tokens > Some(0),
201 | "Expected non-zero token usage"
202 | );
203 | }
204 |
205 | #[tokio::test]
206 | async fn test_chat_completions() {
207 | let config = get_test_provider_config("us-east-2", "titan_chat_completion");
208 | let provider = BedrockProvider::new(&config);
209 |
210 | let model_config = get_test_model_config("amazon.titan-embed-text-v2:0", "titan");
211 |
212 | let payload = ChatCompletionRequest {
213 | model: "us.amazon.nova-lite-v1:0".to_string(),
214 | messages: vec![ChatCompletionMessage {
215 | role: "user".to_string(),
216 | content: Some(ChatMessageContent::String(
217 | "What is the capital of France? Answer in one word.".to_string(),
218 | )),
219 | name: None,
220 | tool_calls: None,
221 | refusal: None,
222 | }],
223 | temperature: None,
224 | top_p: None,
225 | n: None,
226 | stream: None,
227 | stop: None,
228 | max_tokens: None,
229 | max_completion_tokens: None,
230 | parallel_tool_calls: None,
231 | presence_penalty: None,
232 | frequency_penalty: None,
233 | logit_bias: None,
234 | tool_choice: None,
235 | tools: None,
236 | user: None,
237 | logprobs: None,
238 | top_logprobs: None,
239 | response_format: None,
240 | };
241 |
242 | let result = provider.chat_completions(payload, &model_config).await;
243 | assert!(result.is_ok(), "Chat completion failed: {:?}", result.err());
244 |
245 | if let Ok(ChatCompletionResponse::NonStream(completion)) = result {
246 | assert!(!completion.choices.is_empty(), "Expected non-empty choices");
247 | assert!(
248 | completion.usage.total_tokens > 0,
249 | "Expected non-zero token usage"
250 | );
251 |
252 | let first_choice = &completion.choices[0];
253 | assert!(
254 | first_choice.message.content.is_some(),
255 | "Expected message content"
256 | );
257 | assert_eq!(
258 | first_choice.message.role, "assistant",
259 | "Expected assistant role"
260 | );
261 | }
262 | }
263 | }
264 |
265 | #[cfg(test)]
266 | mod ai21_tests {
267 | use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse};
268 | use crate::models::completion::CompletionRequest;
269 | use crate::models::content::{ChatCompletionMessage, ChatMessageContent};
270 | use crate::providers::bedrock::test::{get_test_model_config, get_test_provider_config};
271 | use crate::providers::bedrock::BedrockProvider;
272 | use crate::providers::provider::Provider;
273 |
274 | #[test]
275 | fn test_ai21_provider_new() {
276 | let config = get_test_provider_config("us-east-1", "");
277 | let provider = BedrockProvider::new(&config);
278 |
279 | assert_eq!(provider.key(), "test_key");
280 | assert_eq!(provider.r#type(), "bedrock");
281 | }
282 |
283 | #[tokio::test]
284 | async fn test_ai21_provider_completions() {
285 | let config = get_test_provider_config("us-east-1", "ai21_completion");
286 | let provider = BedrockProvider::new(&config);
287 |
288 | let model_config = get_test_model_config("ai21.j2-mid-v1", "ai21");
289 |
290 | let payload = CompletionRequest {
291 | model: "ai21.j2-mid-v1".to_string(),
292 | prompt: "Tell me a joke".to_string(),
293 | suffix: None,
294 | max_tokens: Some(400),
295 | temperature: None,
296 | top_p: None,
297 | n: None,
298 | stream: None,
299 | logprobs: None,
300 | echo: None,
301 | stop: None,
302 | presence_penalty: None,
303 | frequency_penalty: None,
304 | best_of: None,
305 | logit_bias: None,
306 | user: None,
307 | };
308 |
309 | let result = provider.completions(payload, &model_config).await;
310 | assert!(result.is_ok(), "Completion failed: {:?}", result.err());
311 |
312 | let response = result.unwrap();
313 | assert!(!response.choices.is_empty(), "Expected non-empty choices");
314 | assert!(
315 | response.usage.total_tokens > 0,
316 | "Expected non-zero token usage"
317 | );
318 |
319 | let first_choice = &response.choices[0];
320 | assert!(
321 | !first_choice.text.is_empty(),
322 | "Expected non-empty completion text"
323 | );
324 | assert!(
325 | first_choice.logprobs.is_some(),
326 | "Expected logprobs to be present"
327 | );
328 | }
329 |
330 | #[tokio::test]
331 | async fn test_ai21_provider_chat_completions() {
332 | let config = get_test_provider_config("us-east-1", "ai21_chat_completion");
333 | let provider = BedrockProvider::new(&config);
334 |
335 | let model_config = get_test_model_config("ai21.jamba-1-5-mini-v1:0", "ai21");
336 |
337 | let payload = ChatCompletionRequest {
338 | model: "ai21.jamba-1-5-mini-v1:0".to_string(),
339 | messages: vec![ChatCompletionMessage {
340 | role: "user".to_string(),
341 | content: Some(ChatMessageContent::String(
342 | "Tell me a short joke".to_string(),
343 | )),
344 | name: None,
345 | tool_calls: None,
346 | refusal: None,
347 | }],
348 | temperature: Some(0.8),
349 | top_p: Some(0.8),
350 | n: None,
351 | stream: None,
352 | stop: None,
353 | max_tokens: None,
354 | max_completion_tokens: None,
355 | parallel_tool_calls: None,
356 | presence_penalty: None,
357 | frequency_penalty: None,
358 | logit_bias: None,
359 | tool_choice: None,
360 | tools: None,
361 | user: None,
362 | logprobs: None,
363 | top_logprobs: None,
364 | response_format: None,
365 | };
366 |
367 | let result = provider.chat_completions(payload, &model_config).await;
368 | assert!(result.is_ok(), "Chat completion failed: {:?}", result.err());
369 |
370 | if let Ok(ChatCompletionResponse::NonStream(completion)) = result {
371 | assert!(!completion.choices.is_empty(), "Expected non-empty choices");
372 | assert!(
373 | completion.usage.total_tokens > 0,
374 | "Expected non-zero token usage"
375 | );
376 |
377 | let first_choice = &completion.choices[0];
378 | assert!(
379 | first_choice.message.content.is_some(),
380 | "Expected message content"
381 | );
382 | assert_eq!(
383 | first_choice.message.role, "assistant",
384 | "Expected assistant role"
385 | );
386 | }
387 | }
388 | }
389 | /**
390 |
391 | Helper functions for creating test clients and mock responses
392 |
393 | */
394 | #[cfg(test)]
395 | async fn create_test_bedrock_client(
396 | mock_responses: Vec<aws_smithy_runtime::client::http::test_util::ReplayEvent>,
397 | ) -> aws_sdk_bedrockruntime::Client {
398 | use aws_config::BehaviorVersion;
399 | use aws_credential_types::provider::SharedCredentialsProvider;
400 | use aws_credential_types::Credentials;
401 | use aws_smithy_runtime::client::http::test_util::StaticReplayClient;
402 | use aws_types::region::Region;
403 |
404 | let replay_client = StaticReplayClient::new(mock_responses);
405 |
406 | let credentials = Credentials::new("test-key", "test-secret", None, None, "testing");
407 | let credentials_provider = SharedCredentialsProvider::new(credentials);
408 |
409 | let config = aws_config::SdkConfig::builder()
410 | .behavior_version(BehaviorVersion::latest())
411 | .region(Region::new("us-east-1".to_string()))
412 | .credentials_provider(credentials_provider)
413 | .http_client(replay_client)
414 | .build();
415 |
416 | aws_sdk_bedrockruntime::Client::new(&config)
417 | }
418 | #[cfg(test)]
419 | fn read_response_file(filename: &str) -> Result<String, std::io::Error> {
420 | use std::fs;
421 | use std::io::Read;
422 | use std::path::Path;
423 |
424 | let log_dir = Path::new("src/providers/bedrock/logs");
425 | let file_path = log_dir.join(filename);
426 |
427 | let mut file = fs::File::open(file_path)?;
428 | let mut contents = String::new();
429 | file.read_to_string(&mut contents)?;
430 |
431 | Ok(contents)
432 | }
433 |
434 | /**
435 |
436 | Mock responses for the Bedrock API
437 |
438 | */
439 | #[cfg(test)]
440 | fn dummy_anthropic_chat_completion_response(
441 | ) -> aws_smithy_runtime::client::http::test_util::ReplayEvent {
442 | use aws_smithy_types::body::SdkBody;
443 |
444 | aws_smithy_runtime::client::http::test_util::ReplayEvent::new(
445 | http::Request::builder()
446 | .method("POST")
447 | .uri("https://bedrock-runtime.us-east-2.amazonaws.com/model/us.anthropic.claude-3-haiku-20240307-v1:0/invoke")
448 | .body(SdkBody::empty())
449 | .unwrap(),
450 | http::Response::builder()
451 | .status(200)
452 | .body(SdkBody::from(read_response_file("anthropic_claude_3_haiku_20240307_v1_0_chat_completion.json").unwrap()))
453 | .unwrap(),
454 | )
455 | }
456 | #[cfg(test)]
457 | fn dummy_ai21_chat_completion_response() -> aws_smithy_runtime::client::http::test_util::ReplayEvent
458 | {
459 | use aws_smithy_types::body::SdkBody;
460 |
461 | aws_smithy_runtime::client::http::test_util::ReplayEvent::new(
462 | http::Request::builder()
463 | .method("POST")
464 | .uri("https://bedrock-runtime.us-east-1.amazonaws.com/model/ai21.jamba-1-5-mini-v1:0/invoke")
465 | .body(SdkBody::empty())
466 | .unwrap(),
467 | http::Response::builder()
468 | .status(200)
469 | .body(SdkBody::from(read_response_file("ai21_jamba_1_5_mini_v1_0_chat_completions.json").unwrap()))
470 | .unwrap(),
471 | )
472 | }
473 | #[cfg(test)]
474 | fn dummy_ai21_completion_response() -> aws_smithy_runtime::client::http::test_util::ReplayEvent {
475 | use aws_smithy_types::body::SdkBody;
476 |
477 | aws_smithy_runtime::client::http::test_util::ReplayEvent::new(
478 | http::Request::builder()
479 | .method("POST")
480 | .uri("https://bedrock-runtime.us-east-1.amazonaws.com/model/ai21.j2-mid-v1/invoke")
481 | .body(SdkBody::empty())
482 | .unwrap(),
483 | http::Response::builder()
484 | .status(200)
485 | .body(SdkBody::from(
486 | read_response_file("ai21_j2_mid_v1_completions.json").unwrap(),
487 | ))
488 | .unwrap(),
489 | )
490 | }
491 | #[cfg(test)]
492 | fn dummy_titan_embedding_response() -> aws_smithy_runtime::client::http::test_util::ReplayEvent {
493 | use aws_smithy_types::body::SdkBody;
494 |
495 | aws_smithy_runtime::client::http::test_util::ReplayEvent::new(
496 | http::Request::builder()
497 | .method("POST")
498 | .uri("https://bedrock-runtime.us-east-2.amazonaws.com/model/amazon.titan-embed-text-v2:0/invoke")
499 | .body(SdkBody::empty())
500 | .unwrap(),
501 | http::Response::builder()
502 | .status(200)
503 | .body(SdkBody::from(read_response_file("amazon_titan_embed_text_v2_0_embeddings.json").unwrap()))
504 | .unwrap(),
505 | )
506 | }
507 | #[cfg(test)]
508 | fn dummy_titan_chat_completion_response() -> aws_smithy_runtime::client::http::test_util::ReplayEvent
509 | {
510 | use aws_smithy_types::body::SdkBody;
511 |
512 | aws_smithy_runtime::client::http::test_util::ReplayEvent::new(
513 | http::Request::builder()
514 | .method("POST")
515 | .uri("https://bedrock-runtime.us-east-2.amazonaws.com/model/us.amazon.nova-lite-v1:0/invoke")
516 | .body(SdkBody::empty())
517 | .unwrap(),
518 | http::Response::builder()
519 | .status(200)
520 | .body(SdkBody::from(read_response_file("us_amazon_nova_lite_v1_0_chat_completion.json").unwrap()))
521 | .unwrap(),
522 | )
523 | }
524 |
--------------------------------------------------------------------------------
/src/providers/mod.rs:
--------------------------------------------------------------------------------
1 | pub mod anthropic;
2 | pub mod azure;
3 | pub mod bedrock;
4 | pub mod openai;
5 | pub mod provider;
6 | pub mod registry;
7 | pub mod vertexai;
8 |
--------------------------------------------------------------------------------
/src/providers/openai/mod.rs:
--------------------------------------------------------------------------------
1 | mod provider;
2 |
3 | pub use provider::OpenAIProvider;
4 |
--------------------------------------------------------------------------------
/src/providers/openai/provider.rs:
--------------------------------------------------------------------------------
1 | use crate::config::constants::stream_buffer_size_bytes;
2 | use crate::config::models::{ModelConfig, Provider as ProviderConfig};
3 | use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse};
4 | use crate::models::completion::{CompletionRequest, CompletionResponse};
5 | use crate::models::embeddings::{EmbeddingsRequest, EmbeddingsResponse};
6 | use crate::models::streaming::ChatCompletionChunk;
7 | use crate::providers::provider::Provider;
8 | use axum::async_trait;
9 | use axum::http::StatusCode;
10 | use reqwest::Client;
11 | use reqwest_streams::*;
12 | use tracing::info;
13 |
14 | pub struct OpenAIProvider {
15 | config: ProviderConfig,
16 | http_client: Client,
17 | }
18 |
19 | impl OpenAIProvider {
20 | fn base_url(&self) -> String {
21 | self.config
22 | .params
23 | .get("base_url")
24 | .unwrap_or(&String::from("https://api.openai.com/v1"))
25 | .to_string()
26 | }
27 | }
28 |
29 | #[async_trait]
30 | impl Provider for OpenAIProvider {
31 | fn new(config: &ProviderConfig) -> Self {
32 | Self {
33 | config: config.clone(),
34 | http_client: Client::new(),
35 | }
36 | }
37 |
38 | fn key(&self) -> String {
39 | self.config.key.clone()
40 | }
41 |
42 | fn r#type(&self) -> String {
43 | "openai".to_string()
44 | }
45 |
46 | async fn chat_completions(
47 | &self,
48 | payload: ChatCompletionRequest,
49 | _model_config: &ModelConfig,
50 | ) -> Result {
51 | let response = self
52 | .http_client
53 | .post(format!("{}/chat/completions", self.base_url()))
54 | .header("Authorization", format!("Bearer {}", self.config.api_key))
55 | .json(&payload)
56 | .send()
57 | .await
58 | .map_err(|e| {
59 | eprintln!("OpenAI API request error: {}", e);
60 | StatusCode::INTERNAL_SERVER_ERROR
61 | })?;
62 |
63 | let status = response.status();
64 | if status.is_success() {
65 | if payload.stream.unwrap_or(false) {
66 | let stream =
67 | response.json_array_stream::<ChatCompletionChunk>(stream_buffer_size_bytes());
68 | Ok(ChatCompletionResponse::Stream(stream))
69 | } else {
70 | response
71 | .json()
72 | .await
73 | .map(ChatCompletionResponse::NonStream)
74 | .map_err(|e| {
75 | eprintln!("OpenAI API response error: {}", e);
76 | StatusCode::INTERNAL_SERVER_ERROR
77 | })
78 | }
79 | } else {
80 | info!(
81 | "OpenAI API request error: {}",
82 | response.text().await.unwrap()
83 | );
84 | Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
85 | }
86 | }
87 |
88 | async fn completions(
89 | &self,
90 | payload: CompletionRequest,
91 | _model_config: &ModelConfig,
92 | ) -> Result {
93 | let response = self
94 | .http_client
95 | .post(format!("{}/completions", self.base_url()))
96 | .header("Authorization", format!("Bearer {}", self.config.api_key))
97 | .json(&payload)
98 | .send()
99 | .await
100 | .map_err(|e| {
101 | eprintln!("OpenAI API request error: {}", e);
102 | StatusCode::INTERNAL_SERVER_ERROR
103 | })?;
104 |
105 | let status = response.status();
106 | if status.is_success() {
107 | response.json().await.map_err(|e| {
108 | eprintln!("OpenAI API response error: {}", e);
109 | StatusCode::INTERNAL_SERVER_ERROR
110 | })
111 | } else {
112 | eprintln!(
113 | "OpenAI API request error: {}",
114 | response.text().await.unwrap()
115 | );
116 | Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
117 | }
118 | }
119 |
120 | async fn embeddings(
121 | &self,
122 | payload: EmbeddingsRequest,
123 | _model_config: &ModelConfig,
124 | ) -> Result {
125 | let response = self
126 | .http_client
127 | .post(format!("{}/embeddings", self.base_url()))
128 | .header("Authorization", format!("Bearer {}", self.config.api_key))
129 | .json(&payload)
130 | .send()
131 | .await
132 | .map_err(|e| {
133 | eprintln!("OpenAI API request error: {}", e);
134 | StatusCode::INTERNAL_SERVER_ERROR
135 | })?;
136 |
137 | let status = response.status();
138 | if status.is_success() {
139 | response.json().await.map_err(|e| {
140 | eprintln!("OpenAI API response error: {}", e);
141 | StatusCode::INTERNAL_SERVER_ERROR
142 | })
143 | } else {
144 | eprintln!(
145 | "OpenAI API request error: {}",
146 | response.text().await.unwrap()
147 | );
148 | Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
149 | }
150 | }
151 | }
152 |
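153 | #[cfg(test)]
154 | mod tests {
155 | use super::*;
156 | use std::collections::HashMap;
157 |
158 | // Minimal sketch: without a "base_url" param the provider targets the
159 | // public OpenAI endpoint; with one set it targets that URL instead. The
160 | // key and api_key values are placeholders.
161 | #[test]
162 | fn base_url_defaults_and_overrides() {
163 | let mut config = ProviderConfig {
164 | key: "openai".to_string(),
165 | r#type: "openai".to_string(),
166 | api_key: "test-api-key".to_string(),
167 | params: HashMap::new(),
168 | };
169 | let provider = OpenAIProvider::new(&config);
170 | assert_eq!(provider.base_url(), "https://api.openai.com/v1");
171 |
172 | config.params.insert(
173 | "base_url".to_string(),
174 | "http://localhost:8080/v1".to_string(),
175 | );
176 | let provider = OpenAIProvider::new(&config);
177 | assert_eq!(provider.base_url(), "http://localhost:8080/v1");
178 | }
179 | }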
--------------------------------------------------------------------------------
/src/providers/provider.rs:
--------------------------------------------------------------------------------
1 | use axum::async_trait;
2 | use axum::http::StatusCode;
3 |
4 | use crate::config::models::{ModelConfig, Provider as ProviderConfig};
5 | use crate::models::chat::{ChatCompletionRequest, ChatCompletionResponse};
6 | use crate::models::completion::{CompletionRequest, CompletionResponse};
7 | use crate::models::embeddings::{EmbeddingsRequest, EmbeddingsResponse};
8 |
9 | #[async_trait]
10 | pub trait Provider: Send + Sync {
11 | fn new(config: &ProviderConfig) -> Self
12 | where
13 | Self: Sized;
14 | fn key(&self) -> String;
15 | fn r#type(&self) -> String;
16 |
17 | async fn chat_completions(
18 | &self,
19 | payload: ChatCompletionRequest,
20 | model_config: &ModelConfig,
21 | ) -> Result<ChatCompletionResponse, StatusCode>;
22 |
23 | async fn completions(
24 | &self,
25 | payload: CompletionRequest,
26 | model_config: &ModelConfig,
27 | ) -> Result<CompletionResponse, StatusCode>;
28 |
29 | async fn embeddings(
30 | &self,
31 | payload: EmbeddingsRequest,
32 | model_config: &ModelConfig,
33 | ) -> Result<EmbeddingsResponse, StatusCode>;
34 | }
35 |
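36 | #[cfg(test)]
37 | mod tests {
38 | use super::*;
39 |
40 | // Minimal sketch of a stub satisfying the Provider trait surface; it can
41 | // serve as a template when adding a new provider. The "noop" name and the
42 | // NOT_IMPLEMENTED responses are placeholders.
43 | struct NoopProvider {
44 | config: ProviderConfig,
45 | }
46 |
47 | #[async_trait]
48 | impl Provider for NoopProvider {
49 | fn new(config: &ProviderConfig) -> Self {
50 | Self {
51 | config: config.clone(),
52 | }
53 | }
54 |
55 | fn key(&self) -> String {
56 | self.config.key.clone()
57 | }
58 |
59 | fn r#type(&self) -> String {
60 | "noop".to_string()
61 | }
62 |
63 | async fn chat_completions(
64 | &self,
65 | _payload: ChatCompletionRequest,
66 | _model_config: &ModelConfig,
67 | ) -> Result<ChatCompletionResponse, StatusCode> {
68 | Err(StatusCode::NOT_IMPLEMENTED)
69 | }
70 |
71 | async fn completions(
72 | &self,
73 | _payload: CompletionRequest,
74 | _model_config: &ModelConfig,
75 | ) -> Result<CompletionResponse, StatusCode> {
76 | Err(StatusCode::NOT_IMPLEMENTED)
77 | }
78 |
79 | async fn embeddings(
80 | &self,
81 | _payload: EmbeddingsRequest,
82 | _model_config: &ModelConfig,
83 | ) -> Result<EmbeddingsResponse, StatusCode> {
84 | Err(StatusCode::NOT_IMPLEMENTED)
85 | }
86 | }
87 |
88 | #[test]
89 | fn noop_provider_exposes_key_and_type() {
90 | let config = ProviderConfig {
91 | key: "noop".to_string(),
92 | r#type: "noop".to_string(),
93 | api_key: String::new(),
94 | params: Default::default(),
95 | };
96 | let provider = NoopProvider::new(&config);
97 | assert_eq!(provider.key(), "noop");
98 | assert_eq!(provider.r#type(), "noop");
99 | }
100 | }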
--------------------------------------------------------------------------------
/src/providers/registry.rs:
--------------------------------------------------------------------------------
1 | use anyhow::Result;
2 | use std::collections::HashMap;
3 | use std::sync::Arc;
4 |
5 | use crate::config::models::Provider as ProviderConfig;
6 | use crate::providers::{
7 | anthropic::AnthropicProvider, azure::AzureProvider, bedrock::BedrockProvider,
8 | openai::OpenAIProvider, provider::Provider, vertexai::VertexAIProvider,
9 | };
10 |
11 | pub struct ProviderRegistry {
12 | providers: HashMap<String, Arc<dyn Provider>>,
13 | }
14 |
15 | impl ProviderRegistry {
16 | pub fn new(provider_configs: &[ProviderConfig]) -> Result<Self> {
17 | let mut providers = HashMap::new();
18 |
19 | for config in provider_configs {
20 | let provider: Arc<dyn Provider> = match config.r#type.as_str() {
21 | "openai" => Arc::new(OpenAIProvider::new(config)),
22 | "anthropic" => Arc::new(AnthropicProvider::new(config)),
23 | "azure" => Arc::new(AzureProvider::new(config)),
24 | "bedrock" => Arc::new(BedrockProvider::new(config)),
25 | "vertexai" => Arc::new(VertexAIProvider::new(config)),
26 | _ => continue,
27 | };
28 | providers.insert(config.key.clone(), provider);
29 | }
30 |
31 | Ok(Self { providers })
32 | }
33 |
34 | pub fn get(&self, name: &str) -> Option<Arc<dyn Provider>> {
35 | self.providers.get(name).cloned()
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/src/providers/vertexai/mod.rs:
--------------------------------------------------------------------------------
1 | pub mod models;
2 | pub mod provider;
3 | #[cfg(test)]
4 | mod tests;
5 |
6 | pub use provider::VertexAIProvider;
7 |
--------------------------------------------------------------------------------
/src/providers/vertexai/models.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 | use serde_json::Value;
3 |
4 | use crate::models::chat::{ChatCompletion, ChatCompletionChoice, ChatCompletionRequest};
5 | use crate::models::content::{ChatCompletionMessage, ChatMessageContent};
6 | use crate::models::streaming::{ChatCompletionChunk, Choice, ChoiceDelta};
7 | use crate::models::tool_calls::{ChatMessageToolCall, FunctionCall};
8 | use crate::models::tool_choice::{SimpleToolChoice, ToolChoice};
9 | use crate::models::usage::Usage;
10 |
11 | #[derive(Debug, Serialize, Deserialize)]
12 | pub struct GeminiChatRequest {
13 | pub contents: Vec<GeminiContent>,
14 | #[serde(skip_serializing_if = "Option::is_none")]
15 | pub generation_config: Option<GenerationConfig>,
16 | #[serde(skip_serializing_if = "Option::is_none")]
17 | pub safety_settings: Option<Vec<SafetySetting>>,
18 | #[serde(skip_serializing_if = "Option::is_none")]
19 | pub tools: Option<Vec<GeminiTool>>,
20 | #[serde(skip_serializing_if = "Option::is_none")]
21 | pub tool_choice: Option<GeminiToolChoice>,
22 | }
23 |
24 | #[derive(Debug, Serialize, Deserialize)]
25 | pub struct GeminiTool {
26 | pub function_declarations: Vec<GeminiFunctionDeclaration>,
27 | }
28 |
29 | #[derive(Debug, Serialize, Deserialize)]
30 | pub struct GeminiFunctionDeclaration {
31 | pub name: String,
32 | pub description: Option<String>,
33 | pub parameters: Value,
34 | }
35 |
36 | #[derive(Debug, Serialize, Deserialize)]
37 | #[serde(rename_all = "snake_case")]
38 | pub enum GeminiToolChoice {
39 | None,
40 | Auto,
41 | Function(GeminiFunctionChoice),
42 | }
43 |
44 | #[derive(Debug, Serialize, Deserialize)]
45 | pub struct GeminiFunctionChoice {
46 | pub name: String,
47 | }
48 |
49 | #[derive(Debug, Serialize, Deserialize)]
50 | pub struct GeminiContent {
51 | pub role: String,
52 | pub parts: Vec<ContentPart>,
53 | }
54 |
55 | #[derive(Debug, Serialize, Deserialize)]
56 | pub struct ContentPart {
57 | #[serde(skip_serializing_if = "Option::is_none")]
58 | pub text: Option<String>,
59 | #[serde(rename = "functionCall", skip_serializing_if = "Option::is_none")]
60 | pub function_call: Option<GeminiFunctionCall>,
61 | }
62 |
63 | #[derive(Debug, Serialize, Deserialize)]
64 | pub struct GenerationConfig {
65 | #[serde(skip_serializing_if = "Option::is_none")]
66 | pub temperature: Option<f32>,
67 | #[serde(skip_serializing_if = "Option::is_none")]
68 | pub top_p: Option<f32>,
69 | #[serde(skip_serializing_if = "Option::is_none")]
70 | pub top_k: Option<u32>,
71 | #[serde(skip_serializing_if = "Option::is_none")]
72 | pub max_output_tokens: Option<u32>,
73 | #[serde(skip_serializing_if = "Option::is_none")]
74 | pub stop_sequences: Option<Vec<String>>,
75 | }
76 |
77 | #[derive(Debug, Serialize, Deserialize)]
78 | pub struct SafetySetting {
79 | pub category: String,
80 | pub threshold: String,
81 | }
82 |
83 | #[derive(Debug, Serialize, Deserialize)]
84 | pub struct GeminiChatResponse {
85 | pub candidates: Vec<GeminiCandidate>,
86 | pub usage_metadata: Option<UsageMetadata>,
87 | }
88 |
89 | #[derive(Debug, Serialize, Deserialize)]
90 | pub struct GeminiCandidate {
91 | pub content: GeminiContent,
92 | pub finish_reason: Option,
93 | pub safety_ratings: Option