├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── dependabot.yml └── workflows │ └── rust-check.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Cargo.toml ├── LICENSE ├── README.md ├── examples ├── count_tokens.rs ├── get_model.rs ├── get_models.rs ├── text_request.rs ├── text_request_json.rs ├── text_request_stream.rs ├── vertex_count_tokens.rs ├── vertex_text_request.rs └── vertex_text_request_stream.rs └── src ├── lib.rs └── v1 ├── api.rs ├── errors.rs ├── gemini.rs ├── mod.rs └── vertexai.rs /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Install '...' 16 | 2. Run '....' 17 | 3. See error 18 | 19 | **Expected behavior** 20 | A clear and concise description of what you expected to happen. 21 | 22 | **Screenshots** 23 | If applicable, add screenshots to help explain your problem. 24 | 25 | **Version:** 26 | - Version [e.g. 0.1.5] 27 | 28 | **Additional context** 29 | Add any other context about the problem here. 30 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 
15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # Please see the documentation for all configuration options: 2 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 3 | 4 | version: 2 5 | updates: 6 | - package-ecosystem: "cargo" # See documentation for possible values 7 | directory: "/" # Location of package manifests 8 | schedule: 9 | interval: "weekly" -------------------------------------------------------------------------------- /.github/workflows/rust-check.yml: -------------------------------------------------------------------------------- 1 | name: Rust Check 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | steps: 17 | - name: Checkout code 18 | uses: actions/checkout@v3 19 | 20 | - name: Install rust 21 | uses: actions-rs/toolchain@v1 22 | with: 23 | profile: minimal 24 | toolchain: stable 25 | override: true 26 | components: rustfmt 27 | 28 | - name: Check formatting 29 | run: cargo fmt -- --check 30 | - name: Check clippy 31 | run: cargo clippy --all-targets --all-features -- -D warnings 32 | - name: Run tests 33 | run: cargo test 34 | 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | debug/ 4 | target/ 5 | 6 | # Remove Cargo.lock from gitignore if 
creating an executable, leave it for libraries 7 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 8 | Cargo.lock 9 | 10 | # These are backup files generated by rustfmt 11 | **/*.rs.bk 12 | 13 | # MSVC Windows builds of rustc generate these, which store debugging information 14 | *.pdb 15 | 16 | # Sensitive or local environment variables 17 | .env 18 | .envrc 19 | 20 | # Local test files 21 | request.json 22 | 23 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 
14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 
55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | . 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 
99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to cosmonaut-code 2 | 3 | Whether you want to report a bug, suggest a new feature, or contribute code, I appreciate your input. 4 | 5 | Before you start contributing, please take a moment to review the following guidelines. 
6 | 7 | ## Code of Conduct 8 | 9 | I expect all contributors to abide by the [Code of Conduct](CODE_OF_CONDUCT.md) in all project-related interactions. 10 | 11 | ## Reporting Bugs 12 | 13 | If you encounter a bug while using google-generative-ai-rs, please [open an issue](https://github.com/avastmick/google-generative-ai-rs/issues/new) and provide as much information as possible, including: 14 | 15 | - A clear and descriptive title 16 | - A detailed description of the bug and the expected behavior 17 | - Steps to reproduce the bug 18 | - Any relevant error messages or screenshots 19 | 20 | ## Suggesting Features 21 | 22 | If you have an idea for a new feature or improvement, please [open an issue](https://github.com/avastmick/google-generative-ai-rs/issues/new) and provide as much information as possible, including: 23 | 24 | - A clear and descriptive title 25 | - A detailed description of the proposed feature or improvement 26 | - Any relevant examples or use cases 27 | 28 | ## Contributing Code 29 | 30 | If you want to contribute code to google-generative-ai-rs, please follow these steps: 31 | 32 | 1. [Fork](https://docs.github.com/en/get-started/quickstart/fork-a-repo) the repository to your GitHub account 33 | 2. Clone the forked repository to your local machine 34 | 3. Create a new branch for your changes 35 | 4. Make your changes, following our [code style guide](CODE_STYLE_GUIDE.md) 36 | 5. Commit your changes and push them to your forked repository 37 | 6. Ensure that the pipeline runs successfully 38 | 7. [Create a pull request](https://docs.github.com/en/github/collaborating-with-pull-requests/creating-a-pull-request) to the main repository 39 | 40 | When creating your pull request, please include: 41 | 42 | - A clear and descriptive title 43 | - A detailed description of the changes you made and the reasoning behind them 44 | - Any relevant screenshots or examples 45 | 46 | I will review your pull request as soon as possible and provide feedback. 
If your pull request requires any changes, we will let you know what needs to be done. 47 | 48 | ## Code Style Guide 49 | 50 | I follow the [Rust code style guide](https://doc.rust-lang.org/1.0.0/style/README.html) for all code contributed. 51 | 52 | Please ensure you run `cargo fmt --all --` frequently. 53 | 54 | Additionally, run `cargo clippy --all-targets --all-features -- -D warnings` similarly and resolve any issues you find with your code. 55 | 56 | I suggest using a pre-commit hook to do this automatically. 57 | 58 | ## License 59 | 60 | By contributing to google-generative-ai-rs, you agree that your contributions will be licensed under the [MIT License](LICENSE). 61 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "google-generative-ai-rs" 3 | version = "0.3.4" 4 | edition = "2021" 5 | authors = ["Mick Clarke "] 6 | license = "MIT" 7 | description = "An unofficial rust-based client library to interact with the Google Gemini generative AI API" 8 | repository = "https://github.com/avastmick/google-generative-ai-rs" 9 | readme = "README.md" 10 | keywords = ["google", "generative", "ai", "gemini", "client"] 11 | 12 | [features] 13 | beta = [] 14 | 15 | [dependencies] 16 | bytecount = "0.6.7" 17 | env_logger = { version = "0.11" } 18 | futures = { version = "0.3" } 19 | gcp_auth = { version = "0.12" } 20 | log = { version = "0.4.20" } 21 | reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "json"] } 22 | reqwest-streams = { version = "0.8.2", default-features = false, features = ["json"] } 23 | serde = { version = "1.0", features = ["derive"] } 24 | serde_json = { version = "1.0" } 25 | tokio = { version = "1.35", features = ["full"] } 26 | 27 | [package.metadata.docs.rs] 28 | all-features = true 29 | rustdoc-args = ["--cfg", "docsrs", "--generate-link-to-definition"] 30 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Mick Clarke 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Google Generative AI API client (unofficial) 2 | 3 | [![Rust Check](https://github.com/avastmick/google-generative-ai-rs/actions/workflows/rust-check.yml/badge.svg)](https://github.com/avastmick/google-generative-ai-rs/actions/workflows/rust-check.yml) 4 | An unofficial rust-based client library to interact with the Google generative AI API. 5 | 6 | The goal is to emulate the [Google AI Python SDK](https://github.com/google/generative-ai-python) but in Rust. 
7 | 8 | 9 | ## Usage 10 | 11 | Start point, gain familiarity with Google's Gemini generative AI. 12 | 13 | - For the public Gemini endpoint, see the [Gemini API Overview docs](https://ai.google.dev/docs/gemini_api_overview) 14 | 15 | - Similarly, for the Vertex AI endpoint, see the [Vertex AI Gemini API docs](https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#text_1) 16 | 17 | See [examples](examples) and follow the in-comment instructions. The code is (hopefully) easy and readable. 18 | 19 | ## Contributing 20 | 21 | Yes, please!! Create a fork and branch, make your contribution, and raise a PR. 22 | 23 | Please see [contributing](CONTRIBUTING.md) for the rules; they are standard though. 24 | 25 | ## Work status 26 | 27 | ``` 28 | google-generative-ai-rs = { version = "0.3.4", features = ["beta"] } 29 | ``` 30 | 31 | Using the `beta` feature will enable the following: 32 | 33 | - `gemini-1.5-pro-latest` 34 | - `gemini-1.0-pro` 35 | - `gemini-1.5-pro-latest` 36 | - `gemini-1.5-flash` 37 | - `gemini-1.5-flash-8b` 38 | - `gemini-2.0-flash-exp` 39 | - or custom `Model::Custom(name)` 40 | - system instructions 41 | - `json_mode` 42 | 43 | Note: `gemini-1.0-pro` is deprecated and will be unavailable from 15th February 2025. 44 | 45 | I do my best to release working code. 
46 | 47 | Status today is: *"Happy path for both public and Vertex AI endpoints work for Gemini."* 48 | 49 | ## Outline tasks 50 | 51 | - [X] Create request and response structs 52 | - [X] Create the public API happy path for Gemini 53 | - [X] Create the Vertex AI (private) API happy path for Gemini 54 | - [X] Create basic error handling 55 | - [X] get - see: "" and "" 56 | - [X] countTokens - see: "" 57 | - [ ] function - see "" 58 | - [ ] embedContent - see: "" 59 | -------------------------------------------------------------------------------- /examples/count_tokens.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | 3 | use google_generative_ai_rs::v1::{ 4 | api::Client, 5 | gemini::{request::Request, Content, Part, ResponseType, Role}, 6 | }; 7 | use log::info; 8 | 9 | /// Counts the tokens used in a prompt using the public API and an API key for authn 10 | /// See: `https://ai.google.dev/tutorials/rest_quickstart#count_tokens` 11 | /// 12 | /// To run: 13 | /// ``` 14 | /// API_KEY=[YOUR_API_KEY] RUST_LOG=info cargo run --package google-generative-ai-rs --example count_tokens 15 | /// `` 16 | #[tokio::main] 17 | async fn main() -> Result<(), Box> { 18 | env_logger::init(); 19 | 20 | let client = Client::new_from_response_type( 21 | ResponseType::CountTokens, 22 | env::var("API_KEY").unwrap().to_string(), 23 | ); 24 | 25 | let txt_request = Request { 26 | contents: vec![Content { 27 | role: Role::User, 28 | parts: vec![Part { 29 | text: Some("Write a story about a magic backpack.".to_string()), 30 | inline_data: None, 31 | file_data: None, 32 | video_metadata: None, 33 | }], 34 | }], 35 | tools: vec![], 36 | safety_settings: vec![], 37 | generation_config: None, 38 | 39 | #[cfg(feature = "beta")] 40 | system_instruction: None, 41 | }; 42 | 43 | let response = client.post(30, &txt_request).await?; 44 | 45 | info!("{:#?}", response); 46 | 47 | Ok(()) 48 | } 49 | 
-------------------------------------------------------------------------------- /examples/get_model.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | 3 | use google_generative_ai_rs::v1::{api::Client, gemini::ResponseType}; 4 | use log::info; 5 | 6 | /// Simple text request using the public API and an API key for authn 7 | /// To run: 8 | /// ``` 9 | /// API_KEY=[YOUR_API_KEY] RUST_LOG=info cargo run --package google-generative-ai-rs --example get_model 10 | /// `` 11 | #[tokio::main] 12 | async fn main() -> Result<(), Box> { 13 | env_logger::init(); 14 | 15 | let client = Client::new_from_response_type( 16 | ResponseType::GetModel, 17 | env::var("API_KEY").unwrap().to_string(), 18 | ); 19 | 20 | let response = client.get_model(30).await?; 21 | 22 | info!("{:#?}", response); 23 | 24 | Ok(()) 25 | } 26 | -------------------------------------------------------------------------------- /examples/get_models.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | 3 | use google_generative_ai_rs::v1::{api::Client, gemini::ResponseType}; 4 | use log::info; 5 | 6 | /// Simple text request using the public API and an API key for authn 7 | /// To run: 8 | /// ``` 9 | /// API_KEY=[YOUR_API_KEY] RUST_LOG=info cargo run --package google-generative-ai-rs --example get_models 10 | /// `` 11 | #[tokio::main] 12 | async fn main() -> Result<(), Box> { 13 | env_logger::init(); 14 | 15 | let client = Client::new_from_response_type( 16 | ResponseType::GetModelList, 17 | env::var("API_KEY").unwrap().to_string(), 18 | ); 19 | 20 | let response = client.get_model_list(30).await?; 21 | 22 | info!("{:#?}", response); 23 | 24 | Ok(()) 25 | } 26 | -------------------------------------------------------------------------------- /examples/text_request.rs: -------------------------------------------------------------------------------- 1 | use log::info; 2 | use std::env; 3 | 4 | use 
google_generative_ai_rs::v1::{ 5 | api::Client, 6 | gemini::{request::Request, Content, Part, Role}, 7 | }; 8 | 9 | /// Simple text request using the public API and an API key for authn 10 | /// To run: 11 | /// ``` 12 | /// API_KEY=[YOUR_API_KEY] RUST_LOG=info cargo run --package google-generative-ai-rs --example text_request 13 | /// `` 14 | #[tokio::main] 15 | async fn main() -> Result<(), Box> { 16 | env_logger::init(); 17 | 18 | // Either run as a standard text request or a stream generate content request 19 | let client = Client::new(env::var("API_KEY").unwrap().to_string()); 20 | 21 | let txt_request = Request { 22 | contents: vec![Content { 23 | role: Role::User, 24 | parts: vec![Part { 25 | text: Some("Give me a recipe for banana bread.".to_string()), 26 | inline_data: None, 27 | file_data: None, 28 | video_metadata: None, 29 | }], 30 | }], 31 | tools: vec![], 32 | safety_settings: vec![], 33 | generation_config: None, 34 | 35 | #[cfg(feature = "beta")] 36 | system_instruction: None, 37 | }; 38 | 39 | let response = client.post(30, &txt_request).await?; 40 | 41 | info!("{:#?}", response); 42 | 43 | Ok(()) 44 | } 45 | -------------------------------------------------------------------------------- /examples/text_request_json.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "beta")] 2 | use std::env; 3 | 4 | #[cfg(feature = "beta")] 5 | use google_generative_ai_rs::v1::gemini::request::GenerationConfig; 6 | 7 | #[cfg(feature = "beta")] 8 | use google_generative_ai_rs::v1::{ 9 | api::Client, 10 | gemini::{request::Request, Content, Model, Part, Role}, 11 | }; 12 | 13 | /// JSON-based text request using the public API and an API key for authn 14 | /// 15 | /// NOTE: Currently, only available on the v1beta API. 
16 | /// 17 | /// To run: 18 | /// ``` 19 | /// API_KEY=[YOUR_API_KEY] RUST_LOG=info cargo run -- features "beta" --package google-generative-ai-rs --example text_request_json 20 | /// `` 21 | #[tokio::main] 22 | async fn main() -> Result<(), Box> { 23 | env_logger::init(); 24 | 25 | #[cfg(not(feature = "beta"))] 26 | { 27 | log::error!("JSON-mode only works currently on Gemini 1.5 Pro and on 'beta'"); 28 | 29 | Ok(()) 30 | } 31 | 32 | #[cfg(feature = "beta")] 33 | { 34 | // Either run as a standard text request or a stream generate content request 35 | let client = Client::new_from_model( 36 | Model::Gemini1_5Pro, 37 | env::var("API_KEY").unwrap().to_string(), 38 | ); 39 | 40 | let prompt = r#"List 5 popular cookie recipes using this JSON schema: 41 | { "type": "object", "properties": { "recipe_name": { "type": "string" }}}"# 42 | .to_string(); 43 | 44 | log::info!("Prompt: {:#?}", prompt); 45 | 46 | let txt_request = Request { 47 | contents: vec![Content { 48 | role: Role::User, 49 | parts: vec![Part { 50 | text: Some(prompt), 51 | inline_data: None, 52 | file_data: None, 53 | video_metadata: None, 54 | }], 55 | }], 56 | tools: vec![], 57 | safety_settings: vec![], 58 | generation_config: Some(GenerationConfig { 59 | temperature: None, 60 | top_p: None, 61 | top_k: None, 62 | candidate_count: None, 63 | max_output_tokens: None, 64 | stop_sequences: None, 65 | response_mime_type: Some("application/json".to_string()), 66 | response_schema: None, 67 | }), 68 | 69 | system_instruction: None, 70 | }; 71 | 72 | let response = client.post(30, &txt_request).await?; 73 | 74 | log::info!("{:#?}", response); 75 | 76 | Ok(()) 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /examples/text_request_stream.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | use std::io::{stdout, Write}; 3 | 4 | use google_generative_ai_rs::v1::{ 5 | api::Client, 6 | 
gemini::{request::Request, response::GeminiResponse, Content, Part, ResponseType, Role}, 7 | }; 8 | 9 | /// Simple text request using the public API and an API key for authn 10 | /// To run: 11 | /// ``` 12 | /// API_KEY=[YOUR_API_KEY] RUST_LOG=info cargo run --package google-generative-ai-rs --example text_request 13 | /// `` 14 | #[tokio::main] 15 | async fn main() -> Result<(), Box> { 16 | env_logger::init(); 17 | 18 | let token = match env::var("API_KEY") { 19 | Ok(v) => v, 20 | Err(e) => { 21 | let msg = "$API_KEY not found".to_string(); 22 | panic!("{e:?}:{msg}"); 23 | } 24 | }; 25 | 26 | // Either run as a standard text request or a stream generate content request 27 | let client = Client::new_from_model_response_type( 28 | google_generative_ai_rs::v1::gemini::Model::Gemini1_0Pro, 29 | token.clone(), 30 | ResponseType::StreamGenerateContent, 31 | ); 32 | 33 | println!("token {:#?}", token); 34 | 35 | let txt_request = Request { 36 | contents: vec![Content { 37 | role: Role::User, 38 | parts: vec![Part { 39 | text: Some("Give me a recipe for banana bread.".to_string()), 40 | inline_data: None, 41 | file_data: None, 42 | video_metadata: None, 43 | }], 44 | }], 45 | tools: vec![], 46 | safety_settings: vec![], 47 | generation_config: None, 48 | 49 | #[cfg(feature = "beta")] 50 | system_instruction: None, 51 | }; 52 | 53 | let response = client.post(30, &txt_request).await?; 54 | 55 | println!("output streaming content"); 56 | 57 | if let Some(stream_response) = response.streamed() { 58 | if let Some(json_stream) = stream_response.response_stream { 59 | Client::for_each_async(json_stream, move |response: GeminiResponse| async move { 60 | let mut lock = stdout().lock(); 61 | write!( 62 | lock, 63 | "{}", 64 | response.candidates[0].content.parts[0] 65 | .text 66 | .clone() 67 | .unwrap() 68 | .as_str() 69 | ) 70 | .unwrap(); 71 | }) 72 | .await 73 | } 74 | } 75 | 76 | Ok(()) 77 | } 78 | 
-------------------------------------------------------------------------------- /examples/vertex_count_tokens.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | 3 | use google_generative_ai_rs::v1::{ 4 | api::Client, 5 | gemini::{request::Request, Content, Part, ResponseType, Role}, 6 | }; 7 | use log::info; 8 | 9 | /// Counts the tokens used in a prompt using the public API and an API key for authn 10 | /// See: `https://ai.google.dev/tutorials/rest_quickstart#count_tokens` 11 | /// You'll need to install the GCP cli tools and set up your GCP project and region. 12 | /// 13 | /// The ensure you locally authenticated with GCP using the following commands: 14 | /// ``` 15 | /// gcloud init 16 | /// gcloud auth application-default login 17 | /// ``` 18 | /// 19 | /// To run: 20 | /// ``` 21 | /// GCP_REGION_NAME=[THE REGION WHERE YOUR ENDPOINT IS HOSTED] GCP_PROJECT_ID=[YOUR GCP PROJECT_ID] RUST_LOG=info cargo run --package google-generative-ai-rs --example vertex_count_tokens 22 | /// `` 23 | #[tokio::main] 24 | async fn main() -> Result<(), Box> { 25 | env_logger::init(); 26 | let region = env::var("GCP_REGION_NAME").unwrap().to_string(); 27 | let project_id = env::var("GCP_PROJECT_ID").unwrap().to_string(); 28 | 29 | let client = Client::new_from_region_project_id_response_type( 30 | region.to_string(), 31 | project_id.to_string(), 32 | ResponseType::CountTokens, 33 | ); 34 | 35 | let txt_request = Request { 36 | contents: vec![Content { 37 | role: Role::User, 38 | parts: vec![Part { 39 | text: Some("Write a story about a magic backpack.".to_string()), 40 | inline_data: None, 41 | file_data: None, 42 | video_metadata: None, 43 | }], 44 | }], 45 | tools: vec![], 46 | safety_settings: vec![], 47 | generation_config: None, 48 | 49 | #[cfg(feature = "beta")] 50 | system_instruction: None, 51 | }; 52 | 53 | let response = client.post(30, &txt_request).await?; 54 | 55 | info!("{:#?}", response); 56 | 57 | Ok(()) 
58 | } 59 | -------------------------------------------------------------------------------- /examples/vertex_text_request.rs: -------------------------------------------------------------------------------- 1 | use log::info; 2 | use std::env; 3 | 4 | use google_generative_ai_rs::v1::{ 5 | api::Client, 6 | gemini::{request::Request, Content, Part, ResponseType, Role}, 7 | }; 8 | 9 | /// Simple text request using the public API and an API key for authn 10 | /// 11 | /// You'll need to install the GCP cli tools and set up your GCP project and region. 12 | /// 13 | /// The ensure you locally authenticated with GCP using the following commands: 14 | /// ``` 15 | /// gcloud init 16 | /// gcloud auth application-default login 17 | /// ``` 18 | /// 19 | /// To run: 20 | /// ``` 21 | /// GCP_REGION_NAME=[THE REGION WHERE YOUR ENDPOINT IS HOSTED] GCP_PROJECT_ID=[YOUR GCP PROJECT_ID] RUST_LOG=info cargo run --package google-generative-ai-rs --example vertex_text_request 22 | /// `` 23 | #[tokio::main] 24 | async fn main() -> Result<(), Box> { 25 | env_logger::init(); 26 | let region = env::var("GCP_REGION_NAME").unwrap().to_string(); 27 | let project_id = env::var("GCP_PROJECT_ID").unwrap().to_string(); 28 | 29 | let client = Client::new_from_region_project_id_response_type( 30 | region.to_string(), 31 | project_id.to_string(), 32 | ResponseType::GenerateContent, 33 | ); 34 | 35 | let txt_request = Request { 36 | contents: vec![Content { 37 | role: Role::User, 38 | parts: vec![Part { 39 | text: Some("Give me a recipe for banana bread.".to_string()), 40 | inline_data: None, 41 | file_data: None, 42 | video_metadata: None, 43 | }], 44 | }], 45 | tools: vec![], 46 | safety_settings: vec![], 47 | generation_config: None, 48 | 49 | #[cfg(feature = "beta")] 50 | system_instruction: None, 51 | }; 52 | 53 | let response = client.post(30, &txt_request).await?; 54 | 55 | info!("{:#?}", response); 56 | 57 | Ok(()) 58 | } 59 | 
-------------------------------------------------------------------------------- /examples/vertex_text_request_stream.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | use std::io::{stdout, Write}; 3 | 4 | use google_generative_ai_rs::v1::{ 5 | api::Client, 6 | gemini::{request::Request, response::GeminiResponse, Content, Part, Role}, 7 | }; 8 | 9 | /// Streamed text request using Vertex AI API endpoint and GCP application default credentials (ADC) authn 10 | /// 11 | /// You'll need to install the GCP cli tools and set up your GCP project and region. 12 | /// 13 | /// The ensure you locally authenticated with GCP using the following commands: 14 | /// ``` 15 | /// gcloud init 16 | /// gcloud auth application-default login 17 | /// ``` 18 | /// To run: 19 | /// ``` 20 | /// GCP_REGION_NAME=[THE REGION WHERE YOUR ENDPOINT IS HOSTED] GCP_PROJECT_ID=[YOUR GCP PROJECT_ID] RUST_LOG=info cargo run --package google-generative-ai-rs --example vertex_text_request 21 | /// `` 22 | #[tokio::main] 23 | async fn main() -> Result<(), Box> { 24 | env_logger::init(); 25 | let region = env::var("GCP_REGION_NAME").unwrap().to_string(); 26 | let project_id = env::var("GCP_PROJECT_ID").unwrap().to_string(); 27 | 28 | let client = Client::new_from_region_project_id(region.to_string(), project_id.to_string()); 29 | 30 | let txt_request = Request { 31 | contents: vec![Content { 32 | role: Role::User, 33 | parts: vec![Part { 34 | text: Some("Give me a recipe for banana bread.".to_string()), 35 | inline_data: None, 36 | file_data: None, 37 | video_metadata: None, 38 | }], 39 | }], 40 | tools: vec![], 41 | safety_settings: vec![], 42 | generation_config: None, 43 | 44 | #[cfg(feature = "beta")] 45 | system_instruction: None, 46 | }; 47 | 48 | let response = client.post(30, &txt_request).await?; 49 | 50 | println!("output streaming content"); 51 | 52 | if let Some(stream_response) = response.streamed() { 53 | if let Some(json_stream) = 
stream_response.response_stream { 54 | Client::for_each_async(json_stream, move |response: GeminiResponse| async move { 55 | let mut lock = stdout().lock(); 56 | write!( 57 | lock, 58 | "{}", 59 | response.candidates[0].content.parts[0] 60 | .text 61 | .clone() 62 | .unwrap() 63 | .as_str() 64 | ) 65 | .unwrap(); 66 | }) 67 | .await 68 | } 69 | } 70 | 71 | Ok(()) 72 | } 73 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![cfg_attr(docsrs, feature(doc_cfg))] 2 | 3 | pub mod v1; 4 | -------------------------------------------------------------------------------- /src/v1/api.rs: -------------------------------------------------------------------------------- 1 | //! Manages the interaction with the REST API for the Gemini API. 2 | use futures::prelude::*; 3 | use futures::stream::StreamExt; 4 | use reqwest::StatusCode; 5 | use reqwest_streams::error::StreamBodyError; 6 | use reqwest_streams::*; 7 | use serde_json; 8 | use std::pin::Pin; 9 | use std::sync::Arc; 10 | use std::time::Duration; 11 | use tokio::sync::Mutex; 12 | 13 | use crate::v1::errors::GoogleAPIError; 14 | use crate::v1::gemini::request::Request; 15 | use crate::v1::gemini::response::GeminiResponse; 16 | use crate::v1::gemini::Model; 17 | 18 | use super::gemini::response::{GeminiErrorResponse, StreamedGeminiResponse, TokenCount}; 19 | use super::gemini::{ModelInformation, ModelInformationList, ResponseType}; 20 | 21 | #[cfg(feature = "beta")] 22 | const PUBLIC_API_URL_BASE: &str = "https://generativelanguage.googleapis.com/v1beta"; 23 | 24 | #[cfg(not(feature = "beta"))] 25 | const PUBLIC_API_URL_BASE: &str = "https://generativelanguage.googleapis.com/v1"; 26 | 27 | /// Enables a streamed or non-streamed response to be returned from the API. 
28 | #[derive(Debug)] 29 | pub enum PostResult { 30 | Rest(GeminiResponse), 31 | Streamed(StreamedGeminiResponse), 32 | Count(TokenCount), 33 | } 34 | impl PostResult { 35 | pub fn rest(self) -> Option { 36 | match self { 37 | PostResult::Rest(response) => Some(response), 38 | _ => None, 39 | } 40 | } 41 | pub fn streamed(self) -> Option { 42 | match self { 43 | PostResult::Streamed(streamed_response) => Some(streamed_response), 44 | _ => None, 45 | } 46 | } 47 | pub fn count(self) -> Option { 48 | match self { 49 | PostResult::Count(response) => Some(response), 50 | _ => None, 51 | } 52 | } 53 | } 54 | 55 | /// Manages the specific API connection 56 | pub struct Client { 57 | pub url: String, 58 | pub model: Model, 59 | pub region: Option, 60 | pub project_id: Option, 61 | pub response_type: ResponseType, 62 | } 63 | 64 | /// Implements the functions for the API client. 65 | /// TODO: This is getting unwieldy. We need to refactor this into a more manageable state. 66 | /// See Issue #26 - 'Code tidy and improvement' 67 | impl Client { 68 | /// Creates a default new public API client. 69 | pub fn new(api_key: String) -> Self { 70 | let url = Url::new(&Model::default(), api_key, &ResponseType::GenerateContent); 71 | Self { 72 | url: url.url, 73 | model: Model::default(), 74 | region: None, 75 | project_id: None, 76 | response_type: ResponseType::GenerateContent, 77 | } 78 | } 79 | 80 | /// Creates a default new public API client for a specified response type. 81 | pub fn new_from_response_type(response_type: ResponseType, api_key: String) -> Self { 82 | let url = Url::new(&Model::default(), api_key, &response_type); 83 | Self { 84 | url: url.url, 85 | model: Model::default(), 86 | region: None, 87 | project_id: None, 88 | response_type, 89 | } 90 | } 91 | 92 | /// Create a new public API client for a specified model. 
93 | pub fn new_from_model(model: Model, api_key: String) -> Self { 94 | let url = Url::new(&model, api_key, &ResponseType::GenerateContent); 95 | Self { 96 | url: url.url, 97 | model, 98 | region: None, 99 | project_id: None, 100 | response_type: ResponseType::GenerateContent, 101 | } 102 | } 103 | 104 | /// Create a new public API client for a specified model. 105 | pub fn new_from_model_response_type( 106 | model: Model, 107 | api_key: String, 108 | response_type: ResponseType, 109 | ) -> Self { 110 | let url = Url::new(&model, api_key, &response_type); 111 | Self { 112 | url: url.url, 113 | model, 114 | region: None, 115 | project_id: None, 116 | response_type, 117 | } 118 | } 119 | 120 | // post 121 | pub async fn post( 122 | &self, 123 | timeout: u64, 124 | api_request: &Request, 125 | ) -> Result { 126 | let client: reqwest::Client = self.get_reqwest_client(timeout)?; 127 | match self.response_type { 128 | ResponseType::GenerateContent => { 129 | let result = self.get_post_result(client, api_request).await?; 130 | Ok(PostResult::Rest(result)) 131 | } 132 | ResponseType::StreamGenerateContent => { 133 | let result = self.get_streamed_post_result(client, api_request).await?; 134 | Ok(PostResult::Streamed(result)) 135 | } 136 | ResponseType::CountTokens => { 137 | let result = self.get_token_count(client, api_request).await?; 138 | Ok(PostResult::Count(result)) 139 | } 140 | _ => Err(GoogleAPIError { 141 | message: format!("Unsupported response type: {:?}", self.response_type), 142 | code: None, 143 | }), 144 | } 145 | } 146 | 147 | /// A standard post request, i.e., not streamed 148 | async fn get_post_result( 149 | &self, 150 | client: reqwest::Client, 151 | api_request: &Request, 152 | ) -> Result { 153 | let token_option = self.get_auth_token_option().await?; 154 | 155 | let result = self 156 | .get_post_response(client, api_request, token_option) 157 | .await; 158 | 159 | if let Ok(result) = result { 160 | match result.status() { 161 | 
reqwest::StatusCode::OK => { 162 | Ok(result.json::().await.map_err(|e|GoogleAPIError { 163 | message: format!( 164 | "Failed to deserialize API response into v1::gemini::response::GeminiResponse: {}", 165 | e 166 | ), 167 | code: None, 168 | })?) 169 | }, 170 | _ => { 171 | let status = result.status(); 172 | 173 | match result.json::().await { 174 | Ok(GeminiErrorResponse::Error { message, .. }) => Err(self.new_error_from_api_message(status, message)), 175 | Err(_) => Err(self.new_error_from_status_code(status)), 176 | } 177 | }, 178 | } 179 | } else { 180 | Err(self.new_error_from_reqwest_error(result.unwrap_err())) 181 | } 182 | } 183 | 184 | // Define the function that accepts the stream and the consumer 185 | /// A streamed post request 186 | async fn get_streamed_post_result( 187 | &self, 188 | client: reqwest::Client, 189 | api_request: &Request, 190 | ) -> Result { 191 | let token_option = self.get_auth_token_option().await?; 192 | 193 | let result = self 194 | .get_post_response(client, api_request, token_option) 195 | .await; 196 | 197 | match result { 198 | Ok(response) => match response.status() { 199 | reqwest::StatusCode::OK => { 200 | // Wire to enable introspection on the response stream 201 | let json_stream = response.json_array_stream::(2048); //TODO what is a good length?; 202 | 203 | Ok(StreamedGeminiResponse { 204 | response_stream: Some(json_stream), 205 | }) 206 | } 207 | _ => Err(self.new_error_from_status_code(response.status())), 208 | }, 209 | Err(e) => Err(self.new_error_from_reqwest_error(e)), 210 | } 211 | } 212 | 213 | /// Applies an asynchronous operation to each item in a stream, potentially concurrently. 214 | /// 215 | /// This function retrieves each item from the provided stream, processes it using the given 216 | /// consumer callback, and awaits the futures produced by the consumer. The concurrency level 217 | /// is unbounded, meaning items will be processed as soon as they are ready without a limit. 
218 | /// 219 | /// # Type Parameters 220 | /// 221 | /// - `F`: The type of the consumer closure. It must accept a `GeminiResponse` and return a future. 222 | /// - `Fut`: The future type returned by the `consumer` closure. It must resolve to `()`. 223 | /// 224 | /// # Parameters 225 | /// 226 | /// - `stream`: A `Pin>` that produces items of type `Result`. 227 | /// The stream already needs to be pinned and boxed when passed into this function. 228 | /// - `consumer`: A mutable closure that is called for each `GeminiResponse`. The results of the 229 | /// closure are futures which will be awaited to completion. This closure needs to be `Send` and 230 | /// `'static` to allow for concurrent and potentially multi-threaded execution. 231 | pub async fn for_each_async( 232 | stream: Pin> + Send>>, 233 | consumer: F, 234 | ) where 235 | F: FnMut(GeminiResponse) -> Fut + Send + 'static, 236 | Fut: Future, 237 | { 238 | // Since the stream is already boxed and pinned, you can directly use it 239 | let consumer = Arc::new(Mutex::new(consumer)); 240 | 241 | // Use the for_each_concurrent method to apply the consumer to each item 242 | // in the stream, handling each item as it's ready. 
Set `None` for unbounded concurrency, 243 | // or set a limit with `Some(n)` 244 | 245 | stream 246 | .for_each_concurrent(None, |item: Result| { 247 | let consumer = Arc::clone(&consumer); 248 | async move { 249 | let res = match item { 250 | Ok(result) => { 251 | Client::convert_json_value_to_response(&result).map_err(|e| { 252 | GoogleAPIError { 253 | message: format!( 254 | "Failed to get JSON stream from request: {}", 255 | e 256 | ), 257 | code: None, 258 | } 259 | }) 260 | } 261 | Err(e) => Err(GoogleAPIError { 262 | message: format!("Failed to get JSON stream from request: {}", e), 263 | code: None, 264 | }), 265 | }; 266 | 267 | if let Ok(response) = res { 268 | let mut consumer = consumer.lock().await; 269 | consumer(response).await; 270 | } 271 | } 272 | }) 273 | .await; 274 | } 275 | 276 | /// Gets a ['reqwest::GeminiResponse'] from a post request. 277 | /// Parameters: 278 | /// * client - the ['reqwest::Client'] to use 279 | /// * api_request - the ['Request'] to send 280 | /// * authn_token - an optional authn token to use 281 | async fn get_post_response( 282 | &self, 283 | client: reqwest::Client, 284 | api_request: &Request, 285 | authn_token: Option, 286 | ) -> Result { 287 | let mut request_builder = client 288 | .post(&self.url) 289 | .header(reqwest::header::USER_AGENT, env!("CARGO_CRATE_NAME")) 290 | .header(reqwest::header::CONTENT_TYPE, "application/json"); 291 | 292 | // If a GCP authn token is provided, use it 293 | if let Some(token) = authn_token { 294 | request_builder = request_builder.bearer_auth(token); 295 | } 296 | 297 | request_builder.json(&api_request).send().await 298 | } 299 | // Count Tokens - see: "https://ai.google.dev/tutorials/rest_quickstart#count_tokens" 300 | // 301 | /// Parameters: 302 | /// * timeout - the timeout in seconds 303 | /// * api_request - the request to send to check token count 304 | pub async fn get_token_count( 305 | &self, 306 | client: reqwest::Client, 307 | api_request: &Request, 308 | ) -> Result 
{ 309 | let token_option = self.get_auth_token_option().await?; 310 | 311 | let result = self 312 | .get_post_response(client, api_request, token_option) 313 | .await; 314 | 315 | match result { 316 | Ok(response) => match response.status() { 317 | reqwest::StatusCode::OK => Ok(response.json::().await.map_err(|e|GoogleAPIError { 318 | message: format!( 319 | "Failed to deserialize API response into v1::gemini::response::TokenCount: {}", 320 | e 321 | ), 322 | code: None, 323 | })?), 324 | _ => Err(self.new_error_from_status_code(response.status())), 325 | }, 326 | Err(e) => Err(self.new_error_from_reqwest_error(e)), 327 | } 328 | } 329 | 330 | /// Get for the url specified in 'self' 331 | async fn get( 332 | &self, 333 | timeout: u64, 334 | ) -> Result, GoogleAPIError> { 335 | let client: reqwest::Client = self.get_reqwest_client(timeout)?; 336 | let result = client 337 | .get(&self.url) 338 | .header(reqwest::header::USER_AGENT, env!("CARGO_CRATE_NAME")) 339 | .send() 340 | .await; 341 | Ok(result) 342 | } 343 | /// Gets a model - see: "https://ai.google.dev/tutorials/rest_quickstart#get_model" 344 | /// Parameters: 345 | /// * timeout - the timeout in seconds 346 | pub async fn get_model(&self, timeout: u64) -> Result { 347 | let result = self.get(timeout).await?; 348 | 349 | match result { 350 | Ok(response) => { 351 | match response.status() { 352 | reqwest::StatusCode::OK => Ok(response 353 | .json::() 354 | .await 355 | .map_err(|e| GoogleAPIError { 356 | message: format!( 357 | "Failed to deserialize API response into v1::gemini::ModelInformation: {}", 358 | e 359 | ), 360 | code: None, 361 | })?), 362 | _ => Err(self.new_error_from_status_code(response.status())), 363 | } 364 | } 365 | Err(e) => Err(self.new_error_from_reqwest_error(e)), 366 | } 367 | } 368 | /// Gets a list of models - see: "https://ai.google.dev/tutorials/rest_quickstart#list_models" 369 | /// Parameters: 370 | /// * timeout - the timeout in seconds 371 | pub async fn get_model_list( 372 
| &self, 373 | timeout: u64, 374 | ) -> Result { 375 | let result = self.get(timeout).await?; 376 | 377 | match result { 378 | Ok(response) => { 379 | match response.status() { 380 | reqwest::StatusCode::OK => Ok(response 381 | .json::() 382 | .await 383 | .map_err(|e| GoogleAPIError { 384 | message: format!( 385 | "Failed to deserialize API response into Vec: {}", 386 | e 387 | ), 388 | code: None, 389 | })?), 390 | _ => Err(self.new_error_from_status_code(response.status())), 391 | } 392 | } 393 | Err(e) => Err(self.new_error_from_reqwest_error(e)), 394 | } 395 | } 396 | 397 | // TODO function - see "https://cloud.google.com/vertex-ai/docs/generative-ai/multimodal/function-calling" 398 | 399 | // TODO embedContent - see: "https://ai.google.dev/tutorials/rest_quickstart#embedding" 400 | 401 | /// The current version of the Vertex API only supports streamed responses, so 402 | /// in order to handle any issues we use a serde_json::Value and then convert to a Gemini [`Candidate`]. 403 | fn convert_json_value_to_response( 404 | json_value: &serde_json::Value, 405 | ) -> Result { 406 | serde_json::from_value(json_value.clone()) 407 | } 408 | 409 | fn get_reqwest_client(&self, timeout: u64) -> Result { 410 | let client: reqwest::Client = reqwest::Client::builder() 411 | .timeout(Duration::from_secs(timeout)) 412 | .build() 413 | .map_err(|e| self.new_error_from_reqwest_error(e.without_url()))?; 414 | Ok(client) 415 | } 416 | /// Creates a new error from a status code. 417 | fn new_error_from_status_code(&self, code: reqwest::StatusCode) -> GoogleAPIError { 418 | let status_text = code.canonical_reason().unwrap_or("Unknown Status"); 419 | let message = format!("HTTP Error: {}: {}", code.as_u16(), status_text); 420 | 421 | GoogleAPIError { 422 | message, 423 | code: Some(code), 424 | } 425 | } 426 | 427 | /// Creates a new error from a status code. 
428 | fn new_error_from_api_message(&self, code: StatusCode, message: String) -> GoogleAPIError { 429 | let message = format!("API message: {message}."); 430 | 431 | GoogleAPIError { 432 | message, 433 | code: Some(code), 434 | } 435 | } 436 | 437 | /// Creates a new error from a reqwest error. 438 | fn new_error_from_reqwest_error(&self, mut e: reqwest::Error) -> GoogleAPIError { 439 | if let Some(url) = e.url_mut() { 440 | // Remove the API key from the URL, if any 441 | url.query_pairs_mut().clear(); 442 | } 443 | 444 | GoogleAPIError { 445 | message: format!("{}", e), 446 | code: e.status(), 447 | } 448 | } 449 | } 450 | 451 | /// There are two different URLs for the API, depending on whether the model is public or private. 452 | /// Authn for public models is via an API key, while authn for private models is via application default credentials (ADC). 453 | /// The public API URL is in the form of: https://generativelanguage.googleapis.com/v1/models/{model}:{generateContent|streamGenerateContent} 454 | /// The Vertex AI API URL is in the form of: https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{region}/publishers/google/models/{model}:{streamGenerateContent} 455 | #[derive(Debug)] 456 | pub(crate) struct Url { 457 | pub url: String, 458 | } 459 | impl Url { 460 | pub(crate) fn new(model: &Model, api_key: String, response_type: &ResponseType) -> Self { 461 | let base_url = PUBLIC_API_URL_BASE.to_owned(); 462 | match response_type { 463 | ResponseType::GenerateContent => Self { 464 | url: format!( 465 | "{}/models/{}:{}?key={}", 466 | base_url, model, response_type, api_key 467 | ), 468 | }, 469 | ResponseType::StreamGenerateContent => Self { 470 | url: format!( 471 | "{}/models/{}:{}?key={}", 472 | base_url, model, response_type, api_key 473 | ), 474 | }, 475 | ResponseType::GetModel => Self { 476 | url: format!("{}/models/{}?key={}", base_url, model, api_key), 477 | }, 478 | ResponseType::GetModelList => Self { 479 | url: 
format!("{}/models?key={}", base_url, api_key), 480 | }, 481 | ResponseType::CountTokens => Self { 482 | url: format!( 483 | "{}/models/{}:{}?key={}", 484 | base_url, model, response_type, api_key 485 | ), 486 | }, 487 | _ => panic!("Unsupported response type: {:?}", response_type), 488 | } 489 | } 490 | } 491 | 492 | #[cfg(test)] 493 | mod tests { 494 | use super::*; 495 | use reqwest::StatusCode; 496 | 497 | #[test] 498 | fn test_new_error_from_status_code() { 499 | let client = Client::new("my-api-key".to_string()); 500 | let status_code = StatusCode::BAD_REQUEST; 501 | 502 | let error = client.new_error_from_status_code(status_code); 503 | 504 | assert_eq!(error.message, "HTTP Error: 400: Bad Request"); 505 | assert_eq!(error.code, Some(status_code)); 506 | } 507 | 508 | #[test] 509 | fn test_url_new() { 510 | let model = Model::default(); 511 | let api_key = String::from("my-api-key"); 512 | let url = Url::new(&model, api_key.clone(), &ResponseType::GenerateContent); 513 | 514 | assert_eq!( 515 | url.url, 516 | format!( 517 | "{}/models/{}:generateContent?key={}", 518 | PUBLIC_API_URL_BASE, model, api_key 519 | ) 520 | ); 521 | } 522 | } 523 | -------------------------------------------------------------------------------- /src/v1/errors.rs: -------------------------------------------------------------------------------- 1 | use reqwest::StatusCode; 2 | use std::error::Error; 3 | use std::fmt; 4 | 5 | #[derive(Debug)] 6 | pub struct GoogleAPIError { 7 | pub message: String, 8 | pub code: Option, 9 | } 10 | impl fmt::Display for GoogleAPIError { 11 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 12 | write!( 13 | f, 14 | "GoogleAPIError - code: {:?} error: {}", 15 | self.code, self.message 16 | ) 17 | } 18 | } 19 | impl Error for GoogleAPIError {} 20 | -------------------------------------------------------------------------------- /src/v1/gemini.rs: -------------------------------------------------------------------------------- 1 | //! 
Handles the text interaction with the API 2 | use core::fmt; 3 | use serde::{Deserialize, Serialize}; 4 | 5 | use self::request::{FileData, InlineData, VideoMetadata}; 6 | /// Defines the type of response expected from the API. 7 | /// Used at the end of the API URL for the Gemini API. 8 | #[derive(Debug, Clone, Default, PartialEq)] 9 | pub enum ResponseType { 10 | #[default] 11 | GenerateContent, 12 | StreamGenerateContent, 13 | GetModel, 14 | GetModelList, 15 | CountTokens, 16 | EmbedContent, 17 | BatchEmbedContents, 18 | } 19 | impl fmt::Display for ResponseType { 20 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 21 | match self { 22 | ResponseType::GenerateContent => f.write_str("generateContent"), 23 | ResponseType::StreamGenerateContent => f.write_str("streamGenerateContent"), 24 | ResponseType::GetModel => f.write_str(""), // No display as its already in the URL 25 | ResponseType::GetModelList => f.write_str(""), // No display as its already in the URL 26 | ResponseType::CountTokens => f.write_str("countTokens"), 27 | ResponseType::EmbedContent => f.write_str("embedContent"), 28 | ResponseType::BatchEmbedContents => f.write_str("batchEmbedContents"), 29 | } 30 | } 31 | } 32 | /// Captures the information for a specific Google generative AI model. 
33 | /// 34 | /// ```json 35 | /// { 36 | /// "name": "models/gemini-pro", 37 | /// "version": "001", 38 | /// "displayName": "Gemini Pro", 39 | /// "description": "The best model for scaling across a wide range of tasks", 40 | /// "inputTokenLimit": 30720, 41 | /// "outputTokenLimit": 2048, 42 | /// "supportedGenerationMethods": [ 43 | /// "generateContent", 44 | /// "countTokens" 45 | /// ], 46 | /// "temperature": 0.9, 47 | /// "topP": 1, 48 | /// "topK": 100, 49 | /// } 50 | /// ``` 51 | #[derive(Debug, Default, Deserialize)] 52 | #[serde(rename_all = "camelCase")] 53 | #[serde(rename = "model")] 54 | pub struct ModelInformation { 55 | pub name: String, 56 | pub version: String, 57 | pub display_name: String, 58 | pub description: String, 59 | pub input_token_limit: i32, 60 | pub output_token_limit: i32, 61 | pub supported_generation_methods: Vec, 62 | pub temperature: Option, 63 | #[serde(skip_serializing_if = "Option::is_none")] 64 | pub top_p: Option, 65 | #[serde(skip_serializing_if = "Option::is_none")] 66 | pub top_k: Option, 67 | } 68 | /// Lists the available models for the Gemini API. 
69 | #[derive(Debug, Default, Deserialize)] 70 | #[serde(rename = "models")] 71 | pub struct ModelInformationList { 72 | pub models: Vec, 73 | } 74 | 75 | #[derive(Debug, Clone, Default, PartialEq, Serialize)] 76 | #[serde(rename_all = "kebab-case")] 77 | pub enum Model { 78 | #[default] 79 | Gemini1_0Pro, 80 | #[cfg(feature = "beta")] 81 | #[cfg_attr(docsrs, doc(cfg(feature = "beta")))] 82 | Gemini1_5Pro, 83 | #[cfg(feature = "beta")] 84 | #[cfg_attr(docsrs, doc(cfg(feature = "beta")))] 85 | Gemini1_5Flash, 86 | #[cfg(feature = "beta")] 87 | #[cfg_attr(docsrs, doc(cfg(feature = "beta")))] 88 | Gemini1_5Flash8B, 89 | #[cfg(feature = "beta")] 90 | #[cfg_attr(docsrs, doc(cfg(feature = "beta")))] 91 | Gemini2_0Flash, 92 | #[cfg(feature = "beta")] 93 | #[cfg_attr(docsrs, doc(cfg(feature = "beta")))] 94 | Custom(String), 95 | // TODO: Embedding004 96 | } 97 | impl fmt::Display for Model { 98 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 99 | match self { 100 | Model::Gemini1_0Pro => write!(f, "gemini-1.0-pro"), 101 | 102 | #[cfg(feature = "beta")] 103 | #[cfg_attr(docsrs, doc(cfg(feature = "beta")))] 104 | Model::Gemini1_5Pro => write!(f, "gemini-1.5-pro-latest"), 105 | #[cfg(feature = "beta")] 106 | #[cfg_attr(docsrs, doc(cfg(feature = "beta")))] 107 | Model::Gemini1_5Flash => write!(f, "gemini-1.5-flash"), 108 | #[cfg(feature = "beta")] 109 | #[cfg_attr(docsrs, doc(cfg(feature = "beta")))] 110 | Model::Gemini1_5Flash8B => write!(f, "gemini-1.5-flash-8b"), 111 | 112 | #[cfg(feature = "beta")] 113 | #[cfg_attr(docsrs, doc(cfg(feature = "beta")))] 114 | Model::Gemini2_0Flash => write!(f, "gemini-2.0-flash-exp"), 115 | 116 | #[cfg(feature = "beta")] 117 | #[cfg_attr(docsrs, doc(cfg(feature = "beta")))] 118 | Model::Custom(name) => write!(f, "{}", name), 119 | // TODO: Model::Embedding004 => write!(f, "text-embedding-004"), 120 | } 121 | } 122 | } 123 | 124 | #[derive(Debug, Clone, Deserialize, Serialize)] 125 | pub struct Content { 126 | pub role: Role, 127 | 
#[serde(default)] 128 | pub parts: Vec, 129 | } 130 | 131 | #[derive(Debug, Clone, Deserialize, Serialize)] 132 | #[serde(rename_all = "camelCase")] 133 | pub struct Part { 134 | #[serde(skip_serializing_if = "Option::is_none")] 135 | pub text: Option, 136 | #[serde(skip_serializing_if = "Option::is_none")] 137 | pub inline_data: Option, 138 | #[serde(skip_serializing_if = "Option::is_none")] 139 | pub file_data: Option, 140 | #[serde(skip_serializing_if = "Option::is_none")] 141 | pub video_metadata: Option, 142 | } 143 | 144 | #[derive(Debug, Clone, Deserialize, Serialize)] 145 | #[serde(rename_all = "lowercase")] 146 | pub enum Role { 147 | User, 148 | Model, 149 | } 150 | 151 | /// The request format follows the following structure: 152 | /// ```json 153 | /// { 154 | /// "contents": [ 155 | /// { 156 | /// "role": string, 157 | /// "parts": [ 158 | /// { 159 | /// /// Union field data can be only one of the following: 160 | /// "text": string, 161 | /// "inlineData": { 162 | /// "mimeType": string, 163 | /// "data": string 164 | /// }, 165 | /// "fileData": { 166 | /// "mimeType": string, 167 | /// "fileUri": string 168 | /// }, 169 | /// /// End of list of possible types for union field data. 
170 | /// "videoMetadata": { 171 | /// "startOffset": { 172 | /// "seconds": integer, 173 | /// "nanos": integer 174 | /// }, 175 | /// "endOffset": { 176 | /// "seconds": integer, 177 | /// "nanos": integer 178 | /// } 179 | /// } 180 | /// } 181 | /// ] 182 | /// } 183 | /// ], 184 | /// "tools": [ 185 | /// { 186 | /// "functionDeclarations": [ 187 | /// { 188 | /// "name": string, 189 | /// "description": string, 190 | /// "parameters": { 191 | /// object (OpenAPI Object Schema) 192 | /// } 193 | /// } 194 | /// ] 195 | /// } 196 | /// ], 197 | /// "safetySettings": [ 198 | /// { 199 | /// "category": enum (HarmCategory), 200 | /// "threshold": enum (HarmBlockThreshold) 201 | /// } 202 | /// ], 203 | /// "generationConfig": { 204 | /// "temperature": number, 205 | /// "topP": number, 206 | /// "topK": number, 207 | /// "candidateCount": integer, 208 | /// "maxOutputTokens": integer, 209 | /// "stopSequences": [ 210 | /// string 211 | /// ] 212 | /// } 213 | /// } 214 | /// ``` 215 | /// See https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini 216 | pub mod request { 217 | use serde::{Deserialize, Serialize}; 218 | 219 | use super::{ 220 | safety::{HarmBlockThreshold, HarmCategory}, 221 | Content, 222 | }; 223 | 224 | /// Holds the data to be used for a specific text request 225 | #[derive(Debug, Clone, Deserialize, Serialize)] 226 | pub struct Request { 227 | pub contents: Vec, 228 | #[serde(skip_serializing_if = "Vec::is_empty")] 229 | pub tools: Vec, 230 | #[serde(skip_serializing_if = "Vec::is_empty")] 231 | #[serde(default, rename = "safetySettings")] 232 | pub safety_settings: Vec, 233 | #[serde(skip_serializing_if = "Option::is_none")] 234 | #[serde(default, rename = "generationConfig")] 235 | pub generation_config: Option, 236 | 237 | #[cfg(feature = "beta")] 238 | #[cfg_attr(docsrs, doc(cfg(feature = "beta")))] 239 | #[serde(skip_serializing_if = "Option::is_none")] 240 | #[serde(default, rename = "system_instruction")] 241 | pub 
system_instruction: Option, 242 | } 243 | impl Request { 244 | pub fn new( 245 | contents: Vec, 246 | tools: Vec, 247 | safety_settings: Vec, 248 | generation_config: Option, 249 | ) -> Self { 250 | Request { 251 | contents, 252 | tools, 253 | safety_settings, 254 | generation_config, 255 | #[cfg(feature = "beta")] 256 | system_instruction: None, 257 | } 258 | } 259 | 260 | #[cfg(feature = "beta")] 261 | #[cfg_attr(docsrs, doc(cfg(feature = "beta")))] 262 | pub fn set_system_instruction(&mut self, instruction: SystemInstructionContent) { 263 | self.system_instruction = Some(instruction); 264 | } 265 | 266 | /// Gets the total character count of the prompt. 267 | /// As per the Gemini API, "Text input is charged by every 1,000 characters of input (prompt). 268 | /// Characters are counted by UTF-8 code points and white space is excluded from the count." 269 | /// See: https://cloud.google.com/vertex-ai/pricing 270 | /// 271 | /// Returns the total character count of the prompt as per the Gemini API. 
272 | pub fn get_prompt_character_count(&self) -> usize { 273 | let mut text_count = 0; 274 | for content in &self.contents { 275 | for part in &content.parts { 276 | if let Some(text) = &part.text { 277 | // Exclude white space from the count 278 | let num_chars = bytecount::num_chars(text.as_bytes()); 279 | let num_spaces = bytecount::count(text.as_bytes(), b' '); 280 | text_count += num_chars - num_spaces; 281 | } 282 | } 283 | } 284 | text_count 285 | } 286 | } 287 | #[derive(Debug, Clone, Deserialize, Serialize)] 288 | #[serde(rename_all = "camelCase")] 289 | pub struct InlineData { 290 | pub mime_type: String, 291 | pub data: String, 292 | } 293 | #[derive(Debug, Clone, Deserialize, Serialize)] 294 | #[serde(rename_all = "camelCase")] 295 | pub struct FileData { 296 | pub mime_type: String, 297 | pub file_uri: String, 298 | } 299 | #[derive(Debug, Clone, Deserialize, Serialize)] 300 | #[serde(rename_all = "camelCase")] 301 | pub struct VideoMetadata { 302 | pub start_offset: StartOffset, 303 | pub end_offset: EndOffset, 304 | } 305 | #[derive(Debug, Clone, Deserialize, Serialize)] 306 | pub struct StartOffset { 307 | pub seconds: i32, 308 | pub nanos: i32, 309 | } 310 | #[derive(Debug, Clone, Deserialize, Serialize)] 311 | pub struct EndOffset { 312 | pub seconds: i32, 313 | pub nanos: i32, 314 | } 315 | #[derive(Debug, Clone, Deserialize, Serialize)] 316 | pub struct Tools { 317 | #[serde(rename = "functionDeclarations")] 318 | pub function_declarations: Vec, 319 | } 320 | 321 | #[derive(Debug, Clone, Deserialize, Serialize)] 322 | pub struct FunctionDeclaration { 323 | pub name: String, 324 | pub description: String, 325 | pub parameters: serde_json::Value, 326 | } 327 | 328 | #[derive(Debug, Clone, Deserialize, Serialize)] 329 | pub struct SafetySettings { 330 | pub category: HarmCategory, 331 | pub threshold: HarmBlockThreshold, 332 | } 333 | #[derive(Debug, Clone, Deserialize, Serialize)] 334 | #[serde(rename_all = "camelCase")] 335 | pub struct 
GenerationConfig { 336 | pub temperature: Option, 337 | pub top_p: Option, 338 | pub top_k: Option, 339 | pub candidate_count: Option, 340 | pub max_output_tokens: Option, 341 | pub stop_sequences: Option>, 342 | 343 | #[cfg(feature = "beta")] 344 | #[cfg_attr(docsrs, doc(cfg(feature = "beta")))] 345 | pub response_mime_type: Option, 346 | 347 | #[cfg(feature = "beta")] 348 | #[cfg_attr(docsrs, doc(cfg(feature = "beta")))] 349 | pub response_schema: Option, 350 | } 351 | 352 | #[cfg(feature = "beta")] 353 | #[cfg_attr(docsrs, doc(cfg(feature = "beta")))] 354 | #[derive(Debug, Clone, Deserialize, Serialize)] 355 | pub struct SystemInstructionContent { 356 | #[serde(default)] 357 | pub parts: Vec, 358 | } 359 | 360 | #[cfg(feature = "beta")] 361 | #[cfg_attr(docsrs, doc(cfg(feature = "beta")))] 362 | #[derive(Debug, Clone, Deserialize, Serialize)] 363 | #[serde(rename_all = "camelCase")] 364 | pub struct SystemInstructionPart { 365 | #[serde(skip_serializing_if = "Option::is_none")] 366 | pub text: Option, 367 | } 368 | } 369 | 370 | /// The response format follows the following structure: 371 | /// ```json 372 | /// { 373 | /// "candidates": [ 374 | /// { 375 | /// "content": { 376 | /// "parts": [ 377 | /// { 378 | /// "text": string 379 | /// } 380 | /// ] 381 | /// }, 382 | /// "finishReason": enum (FinishReason), 383 | /// "safetyRatings": [ 384 | /// { 385 | /// "category": enum (HarmCategory), 386 | /// "probability": enum (HarmProbability), 387 | /// "blocked": boolean 388 | /// } 389 | /// ], 390 | /// "citationMetadata": { 391 | /// "citations": [ 392 | /// { 393 | /// "startIndex": integer, 394 | /// "endIndex": integer, 395 | /// "uri": string, 396 | /// "title": string, 397 | /// "license": string, 398 | /// "publicationDate": { 399 | /// "year": integer, 400 | /// "month": integer, 401 | /// "day": integer 402 | /// } 403 | /// } 404 | /// ] 405 | /// } 406 | /// } 407 | /// ], 408 | /// "usageMetadata": { 409 | /// "promptTokenCount": integer, 410 | 
/// "candidatesTokenCount": integer, 411 | /// "totalTokenCount": integer 412 | /// } 413 | /// } 414 | /// ``` 415 | pub mod response { 416 | use core::fmt; 417 | use futures::Stream; 418 | use reqwest_streams::error::StreamBodyError; 419 | use serde::Deserialize; 420 | use std::pin::Pin; 421 | 422 | use super::{ 423 | safety::{HarmCategory, HarmProbability}, 424 | Content, 425 | }; 426 | 427 | impl fmt::Debug for StreamedGeminiResponse { 428 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 429 | write!(f, "StreamedGeminiResponse {{ /* stream values */ }}") 430 | } 431 | } 432 | 433 | type ResponseJsonStream = 434 | Pin> + Send>>; 435 | 436 | /// The token count for a given prompt. 437 | #[derive(Debug, Default, Deserialize)] 438 | #[serde(rename_all = "camelCase")] 439 | pub struct TokenCount { 440 | pub total_tokens: u64, 441 | } 442 | 443 | // The streamGenerateContent response 444 | #[derive(Default)] 445 | pub struct StreamedGeminiResponse { 446 | pub response_stream: Option, 447 | } 448 | 449 | #[derive(Debug, Clone, Deserialize)] 450 | #[serde(rename_all = "camelCase")] 451 | pub struct GeminiResponse { 452 | pub candidates: Vec, 453 | pub prompt_feedback: Option, 454 | pub usage_metadata: Option, 455 | } 456 | #[derive(Debug, Clone, Deserialize)] 457 | #[serde(rename_all = "camelCase")] 458 | pub enum GeminiErrorResponse { 459 | Error { 460 | code: u16, 461 | message: String, 462 | status: String, 463 | }, 464 | } 465 | 466 | impl GeminiResponse { 467 | /// Returns the total character count of the response as per the Gemini API. 
468 | pub fn get_response_character_count(&self) -> usize { 469 | let mut text_count = 0; 470 | for candidate in &self.candidates { 471 | for content in &candidate.content.parts { 472 | if let Some(text) = &content.text { 473 | // Exclude white space from the count 474 | let num_chars = bytecount::num_chars(text.as_bytes()); 475 | let num_spaces = bytecount::count(text.as_bytes(), b' '); 476 | text_count += num_chars - num_spaces; 477 | } 478 | } 479 | } 480 | text_count 481 | } 482 | } 483 | #[derive(Debug, Clone, Deserialize)] 484 | #[serde(rename_all = "camelCase")] 485 | pub struct Candidate { 486 | pub content: Content, 487 | pub finish_reason: Option, 488 | pub index: Option, 489 | #[serde(default)] 490 | pub safety_ratings: Vec, 491 | } 492 | #[derive(Debug, Clone, Deserialize)] 493 | #[serde(rename_all = "camelCase")] 494 | pub struct UsageMetadata { 495 | pub prompt_token_count: u64, 496 | pub candidates_token_count: u64, 497 | } 498 | #[derive(Debug, Clone, Deserialize)] 499 | pub struct PromptFeedback { 500 | #[serde(rename = "safetyRatings")] 501 | pub safety_ratings: Vec, 502 | } 503 | 504 | #[derive(Debug, Clone, Deserialize)] 505 | pub struct SafetyRating { 506 | pub category: HarmCategory, 507 | pub probability: HarmProbability, 508 | #[serde(default)] 509 | pub blocked: bool, 510 | } 511 | 512 | /// The reason why the model stopped generating tokens. If empty, the model has not stopped generating the tokens. 513 | #[derive(Debug, Clone, Deserialize)] 514 | #[serde(rename_all = "SCREAMING_SNAKE_CASE")] 515 | pub enum FinishReason { 516 | FinishReasonUnspecified, // The finish reason is unspecified. 517 | FinishReasonStop, // Natural stop point of the model or provided stop sequence. 518 | FinishReasonMaxTokens, // The maximum number of tokens as specified in the request was reached. 519 | FinishReasonSafety, // The token generation was stopped as the response was flagged for safety reasons. 
Note that [`Candidate`].content is empty if content filters block the output. 520 | FinishReasonRecitation, // The token generation was stopped as the response was flagged for unauthorized citations. 521 | FinishReasonOther, // All other reasons that stopped the token 522 | } 523 | #[cfg(test)] 524 | mod tests {} 525 | } 526 | 527 | /// The safety data for HarmCategory, HarmBlockThreshold and HarmProbability 528 | pub mod safety { 529 | use serde::{Deserialize, Serialize}; 530 | 531 | /// The safety category to configure a threshold for. 532 | #[derive(Debug, Clone, Deserialize, Serialize)] 533 | #[serde(rename_all = "SCREAMING_SNAKE_CASE")] 534 | pub enum HarmCategory { 535 | HarmCategorySexuallyExplicit, 536 | HarmCategoryHateSpeech, 537 | HarmCategoryHarassment, 538 | HarmCategoryDangerousContent, 539 | } 540 | /// For a request: the safety category to configure a threshold for. For a response: the harm probability levels in the content. 541 | #[derive(Debug, Clone, Deserialize, Serialize)] 542 | #[serde(rename_all = "SCREAMING_SNAKE_CASE")] 543 | pub enum HarmProbability { 544 | HarmProbabilityUnspecified, 545 | Negligible, 546 | Low, 547 | Medium, 548 | High, 549 | } 550 | /// The threshold for blocking responses that could belong to the specified safety category based on probability. 
551 | #[derive(Debug, Clone, Deserialize, Serialize)] 552 | #[serde(rename_all = "SCREAMING_SNAKE_CASE")] 553 | pub enum HarmBlockThreshold { 554 | BlockNone, 555 | BlockLowAndAbove, 556 | BlockMedAndAbove, 557 | BlockHighAndAbove, 558 | } 559 | } 560 | -------------------------------------------------------------------------------- /src/v1/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod api; 2 | pub mod errors; 3 | pub mod gemini; 4 | pub mod vertexai; 5 | -------------------------------------------------------------------------------- /src/v1/vertexai.rs: -------------------------------------------------------------------------------- 1 | //! Contains logic and types specific to the Vertex AI endpoint (opposed to the public Gemini API endpoint) 2 | use std::{fmt, sync::Arc}; 3 | 4 | use super::{ 5 | api::{Client, Url}, 6 | gemini::{Model, ResponseType}, 7 | }; 8 | use crate::v1::errors::GoogleAPIError; 9 | 10 | const VERTEX_AI_API_URL_BASE: &str = "https://{region}-aiplatform.googleapis.com/v1"; 11 | 12 | const GCP_API_AUTH_SCOPE: &str = "https://www.googleapis.com/auth/cloud-platform"; 13 | 14 | impl Client { 15 | /// Create a new private API client (Vertex AI) using the default model, `Gemini-pro`. 
16 | /// 17 | /// Parameters: 18 | /// * region - the GCP region to use 19 | /// * project_id - the GCP account project_id to use 20 | pub fn new_from_region_project_id(region: String, project_id: String) -> Self { 21 | Client::new_from_region_project_id_response_type( 22 | region, 23 | project_id, 24 | ResponseType::StreamGenerateContent, 25 | ) 26 | } 27 | pub fn new_from_region_project_id_response_type( 28 | region: String, 29 | project_id: String, 30 | response_type: ResponseType, 31 | ) -> Self { 32 | let url = Url::new_from_region_project_id( 33 | &Model::default(), 34 | region.clone(), 35 | project_id.clone(), 36 | &response_type, 37 | ); 38 | Self { 39 | url: url.url, 40 | model: Model::default(), 41 | region: Some(region), 42 | project_id: Some(project_id), 43 | response_type, 44 | } 45 | } 46 | /// Create a new private API client. 47 | /// Parameters: 48 | /// * model - the Gemini model to use 49 | /// * region - the GCP region to use 50 | /// * project_id - the GCP account project_id to use 51 | pub fn new_from_model_region_project_id( 52 | model: Model, 53 | region: String, 54 | project_id: String, 55 | ) -> Self { 56 | let url = Url::new_from_region_project_id( 57 | &model, 58 | region.clone(), 59 | project_id.clone(), 60 | &ResponseType::StreamGenerateContent, 61 | ); 62 | Self { 63 | url: url.url, 64 | model, 65 | region: Some(region), 66 | project_id: Some(project_id), 67 | response_type: ResponseType::StreamGenerateContent, 68 | } 69 | } 70 | 71 | /// If this is a Vertex AI request, get the token from the GCP authn library, if it is correctly configured, else None. 72 | pub(crate) async fn get_auth_token_option(&self) -> Result, GoogleAPIError> { 73 | let token_option = if self.project_id.is_some() && self.region.is_some() { 74 | let token = self.get_gcp_authn_token().await?.as_str().to_string(); 75 | Some(token) 76 | } else { 77 | None 78 | }; 79 | Ok(token_option) 80 | } 81 | /// Gets a GCP authn token. 
82 | async fn get_gcp_authn_token(&self) -> Result, GoogleAPIError> { 83 | let provider = gcp_auth::provider().await.map_err(|e| GoogleAPIError { 84 | message: format!("Failed to create AuthenticationManager: {}", e), 85 | code: None, 86 | })?; 87 | let scopes = &[GCP_API_AUTH_SCOPE]; 88 | let token = provider.token(scopes).await.map_err(|e| GoogleAPIError { 89 | message: format!("Failed to generate authentication token: {}", e), 90 | code: None, 91 | })?; 92 | Ok(token) 93 | } 94 | } 95 | /// Ensuring there is no leakage of secrets 96 | impl fmt::Display for Client { 97 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 98 | if self.region.is_some() && self.project_id.is_some() { 99 | write!( 100 | f, 101 | "GenerativeAiClient {{ url: {:?}, model: {:?}, region: {:?}, project_id: {:?} }}", 102 | self.url, self.model, self.region, self.project_id 103 | ) 104 | } else { 105 | write!( 106 | f, 107 | "GenerativeAiClient {{ url: {:?}, model: {:?}, region: {:?}, project_id: {:?} }}", 108 | Url::new( 109 | &self.model, 110 | "*************".to_string(), 111 | &self.response_type 112 | ), 113 | self.model, 114 | self.region, 115 | self.project_id 116 | ) 117 | } 118 | } 119 | } 120 | 121 | impl Url { 122 | pub(crate) fn new_from_region_project_id( 123 | model: &Model, 124 | region: String, 125 | project_id: String, 126 | response_type: &ResponseType, 127 | ) -> Self { 128 | let base_url = VERTEX_AI_API_URL_BASE 129 | .to_owned() 130 | .replace("{region}", ®ion); 131 | 132 | let url = format!( 133 | "{}/projects/{}/locations/{}/publishers/google/models/{}:{}", 134 | base_url, project_id, region, model, response_type, 135 | ); 136 | Self { url } 137 | } 138 | } 139 | #[cfg(test)] 140 | mod tests { 141 | use crate::v1::{ 142 | api::{Client, Url}, 143 | gemini::{Model, ResponseType}, 144 | }; 145 | 146 | use super::*; 147 | 148 | #[test] 149 | fn test_new_from_region_project_id() { 150 | let region = String::from("us-central1"); 151 | let project_id = 
String::from("my-project"); 152 | let client = Client::new_from_region_project_id(region.clone(), project_id.clone()); 153 | 154 | assert_eq!(client.region, Some(region)); 155 | assert_eq!(client.project_id, Some(project_id)); 156 | } 157 | 158 | #[test] 159 | fn test_new_from_model_region_project_id() { 160 | let model = Model::default(); 161 | let region = String::from("us-central1"); 162 | let project_id = String::from("my-project"); 163 | let client = Client::new_from_model_region_project_id( 164 | model.clone(), 165 | region.clone(), 166 | project_id.clone(), 167 | ); 168 | 169 | assert_eq!(client.model, model); 170 | assert_eq!(client.region, Some(region)); 171 | assert_eq!(client.project_id, Some(project_id)); 172 | } 173 | 174 | #[test] 175 | fn test_url_new_from_region_project_id() { 176 | let model = Model::default(); 177 | let region = String::from("us-central1"); 178 | let project_id = String::from("my-project"); 179 | let url = Url::new_from_region_project_id( 180 | &model, 181 | region.clone(), 182 | project_id.clone(), 183 | &ResponseType::StreamGenerateContent, 184 | ); 185 | 186 | assert_eq!( 187 | url.url, 188 | format!( 189 | "{}/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent", 190 | VERTEX_AI_API_URL_BASE.replace("{region}", ®ion), 191 | project_id, 192 | region, 193 | model 194 | ) 195 | ); 196 | } 197 | } 198 | --------------------------------------------------------------------------------