├── .editorconfig
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.yml
│   │   └── feature_request.yml
│   ├── dependabot.yml
│   └── workflows
│       ├── publish.yml
│       └── test.yml
├── .gitignore
├── CODE_OF_CONDUCT.md
├── Cargo.toml
├── LICENSE
├── README.md
├── examples
│   ├── chat_cli.rs
│   ├── chat_simple.rs
│   ├── chat_stream_cli.rs
│   └── completions_cli.rs
├── src
│   ├── chat.rs
│   ├── completions.rs
│   ├── edits.rs
│   ├── embeddings.rs
│   ├── files.rs
│   ├── lib.rs
│   ├── models.rs
│   └── moderations.rs
└── test_data
    └── file_upload_test1.jsonl

--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
[*.yml]
indent_size = 2

--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.yml:
--------------------------------------------------------------------------------
name: Bug Report
description: Report a bug you've encountered.
body:
  - type: textarea
    attributes:
      label: Bug Description
      description: A clear and concise description of what the bug is.
    validations:
      required: true
  - type: textarea
    attributes:
      label: Terminal Output
      description: (If applicable.) Automatically formatted as a code block.
      render: text
  - type: textarea
    attributes:
      label: Reproduction Steps
      description: Steps to reproduce the behavior.
      value: |
        1. Do this
        2. Do that
        3. ...
        4. See error
    validations:
      required: true
  - type: textarea
    attributes:
      label: Expected Behavior
    validations:
      required: true
  - type: textarea
    attributes:
      label: Workaround
      description: If you've found a temporary workaround, please describe it here.
  - type: input
    attributes:
      label: Library Version
      placeholder: 1.0.0-alpha
    validations:
      required: true
  - type: input
    attributes:
      label: rustc Version
      description: "Found using `rustc --version`."
      placeholder: 1.66.0 (69f9c33d7 2022-12-12)
    validations:
      required: true
  - type: input
    attributes:
      label: cargo Version
      description: "Found using `cargo --version`."
      placeholder: 1.66.0 (d65d197ad 2022-11-15)
    validations:
      required: true
  - type: dropdown
    attributes:
      label: Platform
      options:
        - Windows
        - macOS
        - Linux
        - Other
    validations:
      required: true
labels: [ "bug" ]

--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.yml:
--------------------------------------------------------------------------------
name: Feature Request
description: Suggest a way to enhance this project.
body:
  - type: textarea
    attributes:
      label: Problem
      description: Describe the problem you intend to fix.
      placeholder: I just wish it was easier to...
    validations:
      required: true
  - type: textarea
    attributes:
      label: Solution
      description: Describe how you think the aforementioned problem should be resolved.
      placeholder: This wouldn't be a problem if...
    validations:
      required: true
  - type: textarea
    attributes:
      label: Alternatives
      description: List any alternatives you've considered as a solution.
      placeholder: We could also [...] or add [...] to fix the problem.
labels: [ "enhancement" ]

--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates

version: 2
updates:
  - package-ecosystem: "cargo" # See documentation for possible values
    directory: "/" # Location of package manifests
    schedule:
      interval: "weekly"

--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
name: Publish
on:
  release:
    types: [created]
env:
  CARGO_TERM_COLOR: always
  OPENAI_KEY: ${{ secrets.OPENAI_KEY }}
jobs:
  publish:
    name: Publish to crates.io
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: actions/cache@v3
        with:
          path: |
            ~/.cargo/bin/
            ~/.cargo/registry/index/
            ~/.cargo/registry/cache/
            ~/.cargo/git/db/
            target/
          key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }}
      - name: Format check
        run: cargo fmt --all -- --check
      - name: Build
        run: cargo build --verbose
      - name: Run tests (native-tls)
        run: cargo test --verbose
      - name: Run tests (rustls)
        run: cargo test --verbose --no-default-features --features=rustls
      - name: Publish
        run: cargo publish --token ${CRATES_TOKEN}
        env:
          CRATES_TOKEN: ${{ secrets.CRATES_TOKEN }}

--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
name: Test
on:
  push:
    branches: [ "master" ]
  pull_request_target:
    branches: [ "master" ]
env:
  CARGO_TERM_COLOR: always
  OPENAI_KEY: ${{ secrets.OPENAI_KEY }}
jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
        with:
          ref: ${{ github.event.pull_request.head.sha }}
      - uses: actions/cache@v3
        with:
          path: |
            ~/.cargo/bin/
            ~/.cargo/registry/index/
            ~/.cargo/registry/cache/
            ~/.cargo/git/db/
            target/
          key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }}
      - name: Format check
        run: cargo fmt --all -- --check
      - name: Build
        run: cargo build --verbose
      - name: Run tests (native-tls)
        run: cargo test --verbose
      - name: Run tests (rustls)
        run: cargo test --verbose --no-default-features --features=rustls

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Generated by Cargo
# will have compiled files and executables
target/

# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
Cargo.lock

# These are backup files generated by rustfmt
**/*.rs.bk

# Environment variables loaded by dotenvy crate
.env
# macOS custom folder attributes
.DS_Store

# Jetbrains
.idea
*.iml

--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
# Contributor Covenant Code of Conduct

## Our Pledge

We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.

## Our Standards

Examples of behavior that contributes to a positive environment for our
community include:

* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
  and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
  overall community

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or
  advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
  address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Enforcement Responsibilities

Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.

Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
valentinegb@icloud.com.
All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the
reporter of any incident.

## Enforcement Guidelines

Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:

### 1. Correction

**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.

### 2. Warning

**Community Impact**: A violation through a single incident or series
of actions.

**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.

### 3. Temporary Ban

**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.

### 4. Permanent Ban

**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within
the community.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.

Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.

--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
[package]
name = "openai"
version = "1.1.1"
authors = ["Lorenzo Fontoura", "valentinegb"]
edition = "2021"
description = "An unofficial Rust library for the OpenAI API."
repository = "https://github.com/rellfy/openai"
license = "MIT"
keywords = ["ai", "machine-learning", "openai", "library"]

[dependencies]
serde_json = "1.0.94"
derive_builder = "0.20.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"], optional = true }
serde = { version = "1.0.157", features = ["derive"] }
reqwest-eventsource = "0.6"
tokio = { version = "1.26.0", features = ["full"] }
anyhow = "1.0.70"
futures-util = "0.3.28"
bytes = "1.4.0"

[dev-dependencies]
dotenvy = "0.15.7"

[features]
default = ["native-tls"]
native-tls = ["reqwest/native-tls"]
rustls = ["reqwest/rustls-tls"]

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Valentine Briese

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# openai

[![crates.io](https://img.shields.io/crates/v/openai.svg)](https://crates.io/crates/openai/)
[![Rust workflow](https://github.com/rellfy/openai/actions/workflows/test.yml/badge.svg)](https://github.com/rellfy/openai/actions/workflows/test.yml)

An unofficial Rust library for the OpenAI API.

## Examples

Examples can be found in the `examples` directory.

Please note that examples are not available for all of the crate's
functionality; PRs to expand the coverage are appreciated.

Currently, there are examples for the `completions` module and the `chat`
module. For other modules, refer to the `tests` submodules for reference.

### Chat Example

```rust
// Relies on OPENAI_KEY and optionally OPENAI_BASE_URL.
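// (Imports are omitted here for brevity; a complete program would include
// something along these lines, as in `examples/chat_simple.rs`.)
// use openai::{
//     chat::{ChatCompletion, ChatCompletionMessage, ChatCompletionMessageRole},
//     Credentials,
// };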
let credentials = Credentials::from_env();
let messages = vec![
    ChatCompletionMessage {
        role: ChatCompletionMessageRole::System,
        content: Some("You are a helpful assistant.".to_string()),
        ..Default::default()
    },
    ChatCompletionMessage {
        role: ChatCompletionMessageRole::User,
        content: Some("Tell me a random crab fact".to_string()),
        ..Default::default()
    },
];
let chat_completion = ChatCompletion::builder("gpt-4o", messages.clone())
    .credentials(credentials.clone())
    .create()
    .await
    .unwrap();
let returned_message = chat_completion.choices.first().unwrap().message.clone();
// Assistant: Sure! Here's a random crab fact: ...
println!(
    "{:#?}: {}",
    returned_message.role,
    returned_message.content.unwrap().trim()
);
```

## Implementation Progress

`██████████` Models

`████████░░` Completions (Function calling is supported)

`████████░░` Chat

`██████████` Edits

`░░░░░░░░░░` Images

`█████████░` Embeddings

`░░░░░░░░░░` Audio

`███████░░░` Files

`░░░░░░░░░░` Fine-tunes

`██████████` Moderations

## Contributing

All contributions are welcome. Unit tests are encouraged.

> **Fork Notice**
>
> This package was initially developed by [Valentine Briese](https://github.com/valentinegb/openai).
> As the original repo was archived, this is a fork and continuation of the project.

--------------------------------------------------------------------------------
/examples/chat_cli.rs:
--------------------------------------------------------------------------------
use dotenvy::dotenv;
use openai::{
    chat::{ChatCompletion, ChatCompletionMessage, ChatCompletionMessageRole},
    Credentials,
};
use std::io::{stdin, stdout, Write};

#[tokio::main]
async fn main() {
    // Make sure you have a file named `.env` with the `OPENAI_KEY` environment variable defined!
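    //
    // For example, a minimal `.env` might contain just one line
    // (with a placeholder value rather than a real key):
    //
    //     OPENAI_KEY=sk-...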
    dotenv().unwrap();
    let credentials = Credentials::from_env();

    let mut messages = vec![ChatCompletionMessage {
        role: ChatCompletionMessageRole::System,
        content: Some("You are a large language model built into a command line interface as an example of what the `openai` Rust library made by Valentine Briese can do.".to_string()),
        ..Default::default()
    }];

    loop {
        print!("User: ");
        stdout().flush().unwrap();

        let mut user_message_content = String::new();

        stdin().read_line(&mut user_message_content).unwrap();
        messages.push(ChatCompletionMessage {
            role: ChatCompletionMessageRole::User,
            content: Some(user_message_content),
            ..Default::default()
        });

        let chat_completion = ChatCompletion::builder("gpt-3.5-turbo", messages.clone())
            .credentials(credentials.clone())
            .create()
            .await
            .unwrap();
        let returned_message = chat_completion.choices.first().unwrap().message.clone();

        println!(
            "{:#?}: {}",
            &returned_message.role,
            &returned_message.content.clone().unwrap().trim()
        );

        messages.push(returned_message);
    }
}

--------------------------------------------------------------------------------
/examples/chat_simple.rs:
--------------------------------------------------------------------------------
use dotenvy::dotenv;
use openai::{
    chat::{ChatCompletion, ChatCompletionMessage, ChatCompletionMessageRole},
    Credentials,
};

#[tokio::main]
async fn main() {
    // Make sure you have a file named `.env` with the `OPENAI_KEY` environment variable defined!
    dotenv().unwrap();
    // Relies on OPENAI_KEY and optionally OPENAI_BASE_URL.
    let credentials = Credentials::from_env();
    let messages = vec![
        ChatCompletionMessage {
            role: ChatCompletionMessageRole::System,
            content: Some("You are a helpful assistant.".to_string()),
            ..Default::default()
        },
        ChatCompletionMessage {
            role: ChatCompletionMessageRole::User,
            content: Some("Tell me a random crab fact".to_string()),
            ..Default::default()
        },
    ];
    let chat_completion = ChatCompletion::builder("gpt-4o", messages.clone())
        .credentials(credentials.clone())
        .create()
        .await
        .unwrap();
    let returned_message = chat_completion.choices.first().unwrap().message.clone();
    // Assistant: Sure! Here's a random crab fact: Crabs communicate with each other by drumming or waving their pincers.
    println!(
        "{:#?}: {}",
        returned_message.role,
        returned_message.content.unwrap().trim()
    );
}

--------------------------------------------------------------------------------
/examples/chat_stream_cli.rs:
--------------------------------------------------------------------------------
use dotenvy::dotenv;
use openai::chat::{ChatCompletion, ChatCompletionDelta};
use openai::{
    chat::{ChatCompletionMessage, ChatCompletionMessageRole},
    Credentials,
};
use std::io::{stdin, stdout, Write};
use tokio::sync::mpsc::{error::TryRecvError, Receiver};

#[tokio::main]
async fn main() {
    // Make sure you have a file named `.env` with the `OPENAI_KEY` environment variable defined!
    dotenv().unwrap();
    let credentials = Credentials::from_env();

    let mut messages = vec![ChatCompletionMessage {
        role: ChatCompletionMessageRole::System,
        content: Some("You're an AI that replies to each message verbosely.".to_string()),
        ..Default::default()
    }];

    loop {
        print!("User: ");
        stdout().flush().unwrap();

        let mut user_message_content = String::new();

        stdin().read_line(&mut user_message_content).unwrap();
        messages.push(ChatCompletionMessage {
            role: ChatCompletionMessageRole::User,
            content: Some(user_message_content),
            ..Default::default()
        });

        let chat_stream = ChatCompletionDelta::builder("gpt-3.5-turbo", messages.clone())
            .credentials(credentials.clone())
            .create_stream()
            .await
            .unwrap();

        let chat_completion: ChatCompletion = listen_for_tokens(chat_stream).await;
        let returned_message = chat_completion.choices.first().unwrap().message.clone();

        messages.push(returned_message);
    }
}

async fn listen_for_tokens(mut chat_stream: Receiver<ChatCompletionDelta>) -> ChatCompletion {
    let mut merged: Option<ChatCompletionDelta> = None;
    loop {
        match chat_stream.try_recv() {
            Ok(delta) => {
                let choice = &delta.choices[0];
                if let Some(role) = &choice.delta.role {
                    print!("{:#?}: ", role);
                }
                if let Some(content) = &choice.delta.content {
                    print!("{}", content);
                }
                stdout().flush().unwrap();
                // Merge token into full completion.
                match merged.as_mut() {
                    Some(c) => {
                        c.merge(delta).unwrap();
                    }
                    None => merged = Some(delta),
                };
            }
            Err(TryRecvError::Empty) => {
                let duration = std::time::Duration::from_millis(50);
                tokio::time::sleep(duration).await;
            }
            Err(TryRecvError::Disconnected) => {
                break;
            }
        };
    }
    println!();
    merged.unwrap().into()
}

--------------------------------------------------------------------------------
/examples/completions_cli.rs:
--------------------------------------------------------------------------------
use dotenvy::dotenv;
use openai::{completions::Completion, Credentials};
use std::io::stdin;

#[tokio::main]
async fn main() {
    // Make sure you have a file named `.env` with the `OPENAI_KEY` environment variable defined!
    dotenv().unwrap();
    let credentials = Credentials::from_env();

    loop {
        println!("Prompt:");

        let mut prompt = String::new();

        stdin().read_line(&mut prompt).unwrap();

        let completion = Completion::builder("gpt-3.5-turbo-instruct")
            .prompt(&prompt)
            .max_tokens(1024)
            .credentials(credentials.clone())
            .create()
            .await
            .unwrap();

        let response = &completion.choices.first().unwrap().text;

        println!("\nResponse: {response}\n");
    }
}

--------------------------------------------------------------------------------
/src/chat.rs:
--------------------------------------------------------------------------------
//! Given a chat conversation, the model will return a chat completion response.
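//!
//! A minimal usage sketch, mirroring `examples/chat_simple.rs` (it assumes the
//! `OPENAI_KEY` environment variable is set):
//!
//! ```no_run
//! use openai::chat::{ChatCompletion, ChatCompletionMessage, ChatCompletionMessageRole};
//! use openai::Credentials;
//!
//! # async fn run() {
//! let credentials = Credentials::from_env();
//! let messages = vec![ChatCompletionMessage {
//!     role: ChatCompletionMessageRole::User,
//!     content: Some("Hello!".to_string()),
//!     ..Default::default()
//! }];
//! let completion = ChatCompletion::builder("gpt-4o", messages)
//!     .credentials(credentials)
//!     .create()
//!     .await
//!     .unwrap();
//! // Print the assistant's reply, if any.
//! println!("{:?}", completion.choices.first().unwrap().message.content);
//! # }
//! ```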

use super::{
    openai_get, openai_get_with_query, openai_post, ApiResponseOrError, Credentials,
    RequestPagination, Usage,
};
use crate::openai_request_stream;
use derive_builder::Builder;
use futures_util::StreamExt;
use reqwest::Method;
use reqwest_eventsource::{CannotCloneRequestError, Event, EventSource};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use tokio::sync::mpsc::{channel, Receiver, Sender};

/// A full chat completion.
pub type ChatCompletion = ChatCompletionGeneric<ChatCompletionChoice>;

/// A delta chat completion, which is streamed token by token.
pub type ChatCompletionDelta = ChatCompletionGeneric<ChatCompletionChoiceDelta>;

#[derive(Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ChatCompletionGeneric<C> {
    #[serde(default)]
    pub id: String,
    #[serde(default)]
    pub object: String,
    #[serde(default)]
    pub created: u64,
    #[serde(default)]
    pub model: String,
    #[serde(default = "default_empty_vec")]
    pub choices: Vec<C>,
    pub usage: Option<Usage>,
}

#[derive(Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ChatCompletionChoice {
    pub index: u64,
    pub finish_reason: String,
    pub message: ChatCompletionMessage,
}

#[derive(Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ChatCompletionChoiceDelta {
    pub index: u64,
    pub finish_reason: Option<String>,
    pub delta: ChatCompletionMessageDelta,
}

fn is_none_or_empty_vec<T>(opt: &Option<Vec<T>>) -> bool {
    opt.as_ref().map(|v| v.is_empty()).unwrap_or(true)
}

#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq, Default)]
pub struct ChatCompletionMessage {
    /// The role of the author of this message.
    pub role: ChatCompletionMessageRole,
    /// The contents of the message.
    ///
    /// This is always required for all messages, except for when ChatGPT calls
    /// a function.
    pub content: Option<String>,
    /// The name of the user in a multi-user chat.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// The function that ChatGPT called. This is usually `None`; it is returned
    /// by ChatGPT rather than provided by the developer.
    ///
    /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call: Option<ChatCompletionFunctionCall>,
    /// Tool call that this message is responding to.
    /// Required if the role is `Tool`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Tool calls that the assistant is requesting to invoke.
    /// Can only be populated if the role is `Assistant`,
    /// otherwise it should be empty.
    #[serde(skip_serializing_if = "is_none_or_empty_vec")]
    pub tool_calls: Option<Vec<ToolCall>>,
}

/// Same as [`ChatCompletionMessage`], but received during a response stream.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ChatCompletionMessageDelta {
    /// The role of the author of this message.
    pub role: Option<ChatCompletionMessageRole>,
    /// The contents of the message.
    pub content: Option<String>,
    /// The name of the user in a multi-user chat.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// The function that ChatGPT called.
    ///
    /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call: Option<ChatCompletionFunctionCallDelta>,
    /// Tool call that this message is responding to.
    /// Required if the role is `Tool`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Tool calls that the assistant is requesting to invoke.
    /// Can only be populated if the role is `Assistant`,
    /// otherwise it should be empty.
    #[serde(skip_serializing_if = "is_none_or_empty_vec")]
    pub tool_calls: Option<Vec<ToolCall>>,
}

#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)]
pub struct ToolCall {
    /// The ID of the tool call.
    pub id: String,
    /// The type of the tool. Currently, only `function` is supported.
    pub r#type: String,
    /// The function that the model called.
    pub function: ToolCallFunction,
}

#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)]
pub struct ToolCallFunction {
    /// The name of the function to call.
    pub name: String,
    /// The arguments to call the function with, as generated by the model in
    /// JSON format.
    /// Note that the model does not always generate valid JSON, and may
    /// hallucinate parameters not defined by your function schema.
    /// Validate the arguments in your code before calling your function.
    pub arguments: String,
}

#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)]
pub struct ChatCompletionFunctionDefinition {
    /// The name of the function.
    pub name: String,
    /// The description of the function.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// The parameters of the function, formatted in JSON Schema.
    ///
    /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-parameters)
    /// [See more information about JSON Schema.](https://json-schema.org/understanding-json-schema/)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<Value>,
}

#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)]
pub struct ChatCompletionFunctionCall {
    /// The name of the function ChatGPT called.
    pub name: String,
    /// The arguments that ChatGPT called the function with (formatted in JSON).
    ///
    /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call)
    pub arguments: String,
}

/// Same as [`ChatCompletionFunctionCall`], but received during a response stream.
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)]
pub struct ChatCompletionFunctionCallDelta {
    /// The name of the function ChatGPT called.
    pub name: Option<String>,
    /// The arguments that ChatGPT called the function with (formatted in JSON).
    ///
    /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call)
    pub arguments: Option<String>,
}

#[derive(Deserialize, Serialize, Debug, Clone, Copy, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ChatCompletionMessageRole {
    System,
    User,
    Assistant,
    Function,
    Tool,
    Developer,
}

#[derive(Serialize, Builder, Debug, Clone)]
#[builder(derive(Clone, Debug, PartialEq))]
#[builder(pattern = "owned")]
#[builder(name = "ChatCompletionBuilder")]
#[builder(setter(strip_option, into))]
pub struct ChatCompletionRequest {
    /// ID of the model to use, e.g. `gpt-4o` or `gpt-3.5-turbo`.
    model: String,
    /// The messages to generate chat completions for, in the [chat format](https://platform.openai.com/docs/guides/chat/introduction).
    messages: Vec<ChatCompletionMessage>,
    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
    ///
    /// We generally recommend altering this or `temperature` but not both.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    /// How many chat completion choices to generate for each input message.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u8>,
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    /// Up to 4 sequences where the API will stop generating further tokens.
    #[builder(default)]
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    /// This feature is in Beta. If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    seed: Option<u64>,
    /// The maximum number of tokens allowed for the generated answer. By default, the number of tokens the model can return will be (4096 - prompt tokens).
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    /// The maximum number of tokens allowed for the generated answer.
    /// For reasoning models such as o1 and o3-mini, this does not include reasoning tokens.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u64>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f32>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f32>,
    /// Modify the likelihood of specified tokens appearing in the completion.
    ///
    /// Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    logit_bias: Option<HashMap<String, f32>>,
    /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
    #[builder(default)]
    #[serde(skip_serializing_if = "String::is_empty")]
    user: String,
    /// Describes functions that ChatGPT can call.
    /// The latest models of ChatGPT support function calling, which allows you to define functions that can be called from the prompt.
    /// For example, you can define a function called "get_weather" that returns the weather in a given city.
    ///
    /// [Function calling API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-functions)
    /// [See more information about function calling in ChatGPT.](https://platform.openai.com/docs/guides/gpt/function-calling)
    #[builder(default)]
    #[serde(skip_serializing_if = "Vec::is_empty")]
    functions: Vec<ChatCompletionFunctionDefinition>,
    /// A string or object of the function to call.
    ///
    /// Controls how the model responds to function calls:
    ///
    /// - "none" means the model does not call a function, and responds to the end-user.
    /// - "auto" means the model can pick between an end-user or calling a function.
    /// - Specifying a particular function via `{"name": "my_function"}` forces the model to call that function.
    ///
    /// "none" is the default when no functions are present. "auto" is the default if functions are present.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    function_call: Option<Value>,
    /// An object specifying the format that the model must output. Compatible with GPT-4 Turbo and all GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106.
    /// Setting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.
    /// Important: when using JSON mode, you must also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if finish_reason="length", which indicates the generation exceeded max_tokens or the conversation exceeded the max context length.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<ChatCompletionResponseFormat>,
    /// The credentials to use for this request.
    #[serde(skip_serializing)]
    #[builder(default)]
    credentials: Option<Credentials>,
    /// Parameters unique to the Venice API.
    /// <https://docs.venice.ai/api-reference/api-spec>
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    venice_parameters: Option<VeniceParameters>,
    /// Whether to store the completion for use in distillation or evals.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub store: Option<bool>,
}
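// A sketch of how these request options compose through the generated
// `ChatCompletionBuilder` (based on the tests further down in this module;
// `messages` and `credentials` are assumed to already be in scope):
//
//     let completion = ChatCompletion::builder("gpt-4o", messages)
//         .temperature(0.0)
//         .seed(1337u64)
//         .response_format(ChatCompletionResponseFormat::json_object())
//         .credentials(credentials)
//         .create()
//         .await
//         .unwrap();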
#[derive(Serialize, Debug, Clone, Eq, PartialEq)]
pub struct VeniceParameters {
    pub include_venice_system_prompt: bool,
}

#[derive(Serialize, Debug, Clone, Eq, PartialEq)]
pub struct ChatCompletionResponseFormat {
    /// Must be one of `text` or `json_object` (defaults to `text`).
    #[serde(rename = "type")]
    typ: String,
}

impl ChatCompletionResponseFormat {
    pub fn json_object() -> Self {
        ChatCompletionResponseFormat {
            typ: "json_object".to_string(),
        }
    }

    pub fn text() -> Self {
        ChatCompletionResponseFormat {
            typ: "text".to_string(),
        }
    }
}

impl<C> ChatCompletionGeneric<C> {
    pub fn builder(
        model: &str,
        messages: impl Into<Vec<ChatCompletionMessage>>,
    ) -> ChatCompletionBuilder {
        ChatCompletionBuilder::create_empty()
            .model(model)
            .messages(messages)
    }
}

#[derive(Serialize, Builder, Debug, Clone, Default)]
#[builder(derive(Clone, Debug, PartialEq))]
#[builder(pattern = "owned")]
#[builder(name = "ChatCompletionMessagesRequestBuilder")]
#[builder(setter(strip_option, into))]
pub struct ChatCompletionMessagesRequest {
    #[serde(skip_serializing)]
    pub completion_id: String,

    #[builder(default)]
    #[serde(skip_serializing)]
    pub credentials: Option<Credentials>,

    #[builder(default)]
    #[serde(flatten)]
    pub pagination: RequestPagination,
}

/// A list of messages for a chat completion.
#[derive(Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ChatCompletionMessages {
    pub data: Vec<ChatCompletionMessage>,
    pub object: String,
    pub first_id: Option<String>,
    pub last_id: Option<String>,
    pub has_more: bool,
}

impl ChatCompletion {
    pub async fn create(request: ChatCompletionRequest) -> ApiResponseOrError<Self> {
        let credentials_opt = request.credentials.clone();
        openai_post("chat/completions", &request, credentials_opt).await
    }

    /// Get a stored completion.
    pub async fn get(id: &str, credentials: Credentials) -> ApiResponseOrError<Self> {
        let route = format!("chat/completions/{}", id);
        openai_get(route.as_str(), Some(credentials)).await
    }
}

impl ChatCompletionDelta {
    pub async fn create(
        request: ChatCompletionRequest,
    ) -> Result<Receiver<Self>, CannotCloneRequestError> {
        let credentials_opt = request.credentials.clone();
        let stream = openai_request_stream(
            Method::POST,
            "chat/completions",
            |r| r.json(&request),
            credentials_opt,
        )
        .await?;
        let (tx, rx) = channel::<Self>(32);
        tokio::spawn(forward_deserialized_chat_response_stream(stream, tx));
        Ok(rx)
    }

    /// Merges the input delta completion into `self`.
    pub fn merge(
        &mut self,
        other: ChatCompletionDelta,
    ) -> Result<(), ChatCompletionDeltaMergeError> {
        if other.id.ne(&self.id) {
            return Err(ChatCompletionDeltaMergeError::DifferentCompletionIds);
        }
        for other_choice in other.choices.iter() {
            for choice in self.choices.iter_mut() {
                if choice.index != other_choice.index {
                    continue;
                }
                choice.merge(other_choice)?;
            }
        }
        Ok(())
    }
}

impl ChatCompletionChoiceDelta {
    pub fn merge(
        &mut self,
        other: &ChatCompletionChoiceDelta,
    ) -> Result<(), ChatCompletionDeltaMergeError> {
        if self.index != other.index {
            return Err(ChatCompletionDeltaMergeError::DifferentCompletionChoiceIndices);
        }
        if self.delta.role.is_none() {
            if let Some(other_role) = other.delta.role {
                // Set role to other_role.
                self.delta.role = Some(other_role);
            }
        }
        if self.delta.name.is_none() {
            if let Some(other_name) = &other.delta.name {
                // Set name to other_name.
                self.delta.name = Some(other_name.clone());
            }
        }
        // Merge contents.
        match self.delta.content.as_mut() {
            Some(content) => {
                match &other.delta.content {
                    Some(other_content) => {
                        // Push other content into this one.
                        content.push_str(other_content)
                    }
                    None => {}
                }
            }
            None => {
                match &other.delta.content {
                    Some(other_content) => {
                        // Set this content to other content.
                        self.delta.content = Some(other_content.clone());
                    }
                    None => {}
                }
            }
        };

        // Merge function calls.
        // Arguments are merged by concatenating the partial JSON strings.
        match self.delta.function_call.as_mut() {
            Some(function_call) => {
                match &other.delta.function_call {
                    Some(other_function_call) => {
                        // Push the arguments string of the other function call into this one.
                        match (&mut function_call.arguments, &other_function_call.arguments) {
                            (Some(arguments), Some(other_arguments)) => {
                                arguments.push_str(other_arguments);
                            }
                            (None, Some(other_arguments)) => {
                                function_call.arguments = Some(other_arguments.clone());
                            }
                            _ => {}
                        }
                    }
                    None => {}
                }
            }
            None => {
                match &other.delta.function_call {
                    Some(other_function_call) => {
                        // Set this function call to the other function call.
                        self.delta.function_call = Some(other_function_call.clone());
                    }
                    None => {}
                }
            }
        };
        Ok(())
    }
}

impl From<ChatCompletionDelta> for ChatCompletion {
    fn from(delta: ChatCompletionDelta) -> Self {
        ChatCompletion {
            id: delta.id,
            object: delta.object,
            created: delta.created,
            model: delta.model,
            usage: delta.usage,
            choices: delta
                .choices
                .iter()
                .map(|choice| ChatCompletionChoice {
                    index: choice.index,
                    finish_reason: clone_default_unwrapped_option_string(&choice.finish_reason),
                    message: ChatCompletionMessage {
                        role: choice
                            .delta
                            .role
                            .unwrap_or(ChatCompletionMessageRole::System),
                        content: choice.delta.content.clone(),
                        name: choice.delta.name.clone(),
                        function_call: choice.delta.function_call.clone().map(|f| f.into()),
                        tool_call_id: None,
                        tool_calls: Some(Vec::new()),
                    },
                })
                .collect(),
        }
    }
}

impl From<ChatCompletionFunctionCallDelta> for ChatCompletionFunctionCall {
    fn from(delta: ChatCompletionFunctionCallDelta) -> Self {
        ChatCompletionFunctionCall {
            name: delta.name.unwrap_or_default(),
            arguments: delta.arguments.unwrap_or_default(),
        }
    }
}

impl ChatCompletionMessages {
    /// Create a builder for fetching messages for a stored completion.
    pub fn builder(completion_id: String) -> ChatCompletionMessagesRequestBuilder {
        ChatCompletionMessagesRequestBuilder::create_empty().completion_id(completion_id)
    }

    /// Fetch messages for a stored completion.
    pub async fn fetch(
        request: ChatCompletionMessagesRequest,
    ) -> ApiResponseOrError<Self> {
        let route = format!("chat/completions/{}/messages", request.completion_id);
        let credentials = request.credentials.clone();
        openai_get_with_query(route.as_str(), &request, credentials).await
    }
}

#[derive(Debug)]
pub enum ChatCompletionDeltaMergeError {
    DifferentCompletionIds,
    DifferentCompletionChoiceIndices,
    FunctionCallArgumentTypeMismatch,
}

impl std::fmt::Display for ChatCompletionDeltaMergeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ChatCompletionDeltaMergeError::DifferentCompletionIds => {
                f.write_str("Different completion IDs")
            }
            ChatCompletionDeltaMergeError::DifferentCompletionChoiceIndices => {
                f.write_str("Different completion choice indices")
            }
            ChatCompletionDeltaMergeError::FunctionCallArgumentTypeMismatch => {
                f.write_str("Function call argument type mismatch")
            }
        }
    }
}

impl std::error::Error for ChatCompletionDeltaMergeError {}

async fn forward_deserialized_chat_response_stream(
    mut stream: EventSource,
    tx: Sender<ChatCompletionDelta>,
) -> anyhow::Result<()> {
    while let Some(event) = stream.next().await {
        let event = event?;
        match event {
            Event::Message(event) => {
                let completion = serde_json::from_str::<ChatCompletionDelta>(&event.data)?;
                tx.send(completion).await?;
            }
            _ => {}
        }
    }
    Ok(())
}

impl ChatCompletionBuilder {
    pub async fn create(self) -> ApiResponseOrError<ChatCompletion> {
        ChatCompletion::create(self.build().unwrap()).await
    }

    pub async fn create_stream(
        mut self,
    ) -> Result<Receiver<ChatCompletionDelta>, CannotCloneRequestError> {
        self.stream = Some(Some(true));
        ChatCompletionDelta::create(self.build().unwrap()).await
    }
}

impl ChatCompletionMessagesRequestBuilder {
    /// Fetch messages for the specified completion.
    pub async fn fetch(self) -> ApiResponseOrError<ChatCompletionMessages> {
        ChatCompletionMessages::fetch(self.build().unwrap()).await
    }
}

fn clone_default_unwrapped_option_string(string: &Option<String>) -> String {
    match string {
        Some(value) => value.clone(),
        None => "".to_string(),
    }
}

impl Default for ChatCompletionMessageRole {
    fn default() -> Self {
        Self::User
    }
}

fn default_empty_vec<T>() -> Vec<T> {
    Vec::new()
}

#[cfg(test)]
mod tests {
    use super::*;
    use dotenvy::dotenv;
    use std::time::Duration;
    use tokio::time::sleep;

    #[tokio::test]
    async fn chat() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        let chat_completion = ChatCompletion::builder(
            "gpt-3.5-turbo",
            [ChatCompletionMessage {
                role: ChatCompletionMessageRole::User,
                content: Some("Hello!".to_string()),
                name: None,
                function_call: None,
                tool_call_id: None,
                tool_calls: Some(Vec::new()),
            }],
        )
        .temperature(0.0)
        .response_format(ChatCompletionResponseFormat::text())
        .credentials(credentials)
        .create()
        .await
        .unwrap();

        assert_eq!(
            chat_completion
                .choices
                .first()
                .unwrap()
                .message
                .content
                .as_ref()
                .unwrap(),
            "Hello! How can I assist you today?"
        );
    }

    // Seeds are not deterministic, so the only point of this test is to
    // ensure that passing a seed still results in a valid response.
    #[tokio::test]
    async fn chat_seed() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        let chat_completion = ChatCompletion::builder(
            "gpt-3.5-turbo",
            [ChatCompletionMessage {
                role: ChatCompletionMessageRole::User,
                content: Some(
                    "What type of seed does Mr. England sow in the song? Reply with 1 word."
                        .to_string(),
                ),
                name: None,
                function_call: None,
                tool_call_id: None,
                tool_calls: Some(Vec::new()),
            }],
        )
        // Determinism currently comes from temperature 0, not seed.
        .temperature(0.0)
        .seed(1337u64)
        .credentials(credentials)
        .create()
        .await
        .unwrap();

        assert_eq!(
            chat_completion
                .choices
                .first()
                .unwrap()
                .message
                .content
                .as_ref()
                .unwrap(),
            "Love"
        );
    }

    #[tokio::test]
    async fn chat_stream() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        let chat_stream = ChatCompletion::builder(
            "gpt-3.5-turbo",
            [ChatCompletionMessage {
                role: ChatCompletionMessageRole::User,
                content: Some("Hello!".to_string()),
                name: None,
                function_call: None,
                tool_call_id: None,
                tool_calls: Some(Vec::new()),
            }],
        )
        .temperature(0.0)
        .credentials(credentials)
        .create_stream()
        .await
        .unwrap();

        let chat_completion = stream_to_completion(chat_stream).await;

        assert_eq!(
            chat_completion
                .choices
                .first()
                .unwrap()
                .message
                .content
                .as_ref()
                .unwrap(),
            "Hello! How can I assist you today?"
        );
    }

    #[tokio::test]
    async fn chat_function() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        let chat_stream = ChatCompletion::builder(
            "gpt-4o",
            [ChatCompletionMessage {
                role: ChatCompletionMessageRole::User,
                content: Some("What is the weather in Boston?".to_string()),
                name: None,
                function_call: None,
                tool_call_id: None,
                tool_calls: Some(Vec::new()),
            }],
        )
        .functions([ChatCompletionFunctionDefinition {
            description: Some("Get the current weather in a given location.".to_string()),
            name: "get_current_weather".to_string(),
            parameters: Some(serde_json::json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state to get the weather for. (eg: San Francisco, CA)"
                    }
                },
                "required": ["location"]
            })),
        }])
        .temperature(0.2)
        .credentials(credentials)
        .create_stream()
        .await
        .unwrap();

        let chat_completion = stream_to_completion(chat_stream).await;

        assert_eq!(
            chat_completion
                .choices
                .first()
                .unwrap()
                .message
                .function_call
                .as_ref()
                .unwrap()
                .name,
            "get_current_weather".to_string(),
        );

        assert_eq!(
            serde_json::from_str::<Value>(
                &chat_completion
                    .choices
                    .first()
                    .unwrap()
                    .message
                    .function_call
                    .as_ref()
                    .unwrap()
                    .arguments
            )
            .unwrap(),
            serde_json::json!({
                "location": "Boston, MA"
            }),
        );
    }

    #[tokio::test]
    async fn chat_response_format_json() {
        dotenv().ok();
        let credentials = Credentials::from_env();
        let chat_completion = ChatCompletion::builder(
            "gpt-3.5-turbo",
            [ChatCompletionMessage {
                role: ChatCompletionMessageRole::User,
                content: Some("Write an example JSON for a JWT header using RS256".to_string()),
                name: None,
                function_call: None,
                tool_call_id: None,
                tool_calls: Some(Vec::new()),
            }],
        )
        .temperature(0.0)
        .seed(1337u64)
        .response_format(ChatCompletionResponseFormat::json_object())
        .credentials(credentials)
        .create()
        .await
        .unwrap();
        let response_string = chat_completion
            .choices
            .first()
            .unwrap()
            .message
            .content
            .as_ref()
            .unwrap();
        #[derive(Deserialize, Eq, PartialEq, Debug)]
        struct Response {
            alg: String,
            typ: String,
        }
        let response = serde_json::from_str::<Response>(response_string).unwrap();
        assert_eq!(
            response,
            Response {
                alg: "RS256".to_owned(),
                typ: "JWT".to_owned()
            }
        );
    }

    #[test]
    fn builder_clone_and_eq() {
        let builder_a = ChatCompletion::builder("gpt-4", [])
            .temperature(0.0)
            .seed(65u64);
        let builder_b = builder_a.clone();
        let builder_c = builder_b.clone().temperature(1.0);
        let builder_d = ChatCompletionBuilder::default();
        assert_eq!(builder_a, builder_b);
        assert_ne!(builder_a, builder_c);
        assert_ne!(builder_b, builder_c);
        assert_ne!(builder_a, builder_d);
        assert_ne!(builder_c, builder_d);
    }

    async fn stream_to_completion(
        mut chat_stream: Receiver<ChatCompletionDelta>,
    ) -> ChatCompletion {
        let mut merged: Option<ChatCompletionDelta> = None;
        while let Some(delta) = chat_stream.recv().await {
            match merged.as_mut() {
                Some(c) => {
                    c.merge(delta).unwrap();
                }
                None => merged = Some(delta),
            };
        }
        merged.unwrap().into()
    }

    #[tokio::test]
    async fn chat_tool_response_completion() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        let chat_completion = ChatCompletion::builder(
            "gpt-4o-mini",
            [
                ChatCompletionMessage {
                    role: ChatCompletionMessageRole::User,
                    content: Some(
                        "What's 0.9102847*28456? \
                        reply in plain text, \
                        round the number to 2 decimals \
                        and reply with the result number only, \
                        with no full stop at the end"
                            .to_string(),
                    ),
                    name: None,
                    function_call: None,
                    tool_call_id: None,
                    tool_calls: Some(Vec::new()),
                },
                ChatCompletionMessage {
                    role: ChatCompletionMessageRole::Assistant,
                    content: Some("Let me calculate that for you.".to_string()),
                    name: None,
                    function_call: None,
                    tool_call_id: None,
                    tool_calls: Some(vec![ToolCall {
                        id: "the_tool_call".to_string(),
                        r#type: "function".to_string(),
                        function: ToolCallFunction {
                            name: "mul".to_string(),
                            arguments: "not_required_to_be_valid_here".to_string(),
                        },
                    }]),
                },
                ChatCompletionMessage {
                    role: ChatCompletionMessageRole::Tool,
                    content: Some("the result is 25903.061423199997".to_string()),
                    name: None,
                    function_call: None,
                    tool_call_id: Some("the_tool_call".to_owned()),
                    tool_calls: Some(Vec::new()),
                },
            ],
        )
        // Determinism currently comes from temperature 0, not seed.
        .temperature(0.0)
        .seed(1337u64)
        .credentials(credentials)
        .create()
        .await
        .unwrap();

        assert_eq!(
            chat_completion
                .choices
                .first()
                .unwrap()
                .message
                .content
                .as_ref()
                .unwrap(),
            "25903.06"
        );
    }

    #[tokio::test]
    async fn get_completion() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        let chat_completion = ChatCompletion::builder(
            "gpt-3.5-turbo",
            [ChatCompletionMessage {
                role: ChatCompletionMessageRole::User,
                content: Some("Hello!".to_string()),
                ..Default::default()
            }],
        )
        .credentials(credentials.clone())
        .store(true)
        .create()
        .await
        .unwrap();

        // Unfortunately, completions are not available immediately, so we need to wait a bit.
        sleep(Duration::from_secs(7)).await;

        let retrieved_completion = ChatCompletion::get(&chat_completion.id, credentials.clone())
            .await
            .unwrap();

        assert_eq!(retrieved_completion, chat_completion);
    }

    #[tokio::test]
    async fn get_completion_non_existent() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        match ChatCompletion::get("non_existent_id", credentials.clone()).await {
            Ok(_) => panic!("Expected error"),
            Err(e) => assert_eq!(e.code, Some("not_found".to_string())),
        }
    }

    #[tokio::test]
    async fn get_completion_messages() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        let user_message = ChatCompletionMessage {
            role: ChatCompletionMessageRole::User,
            content: Some("Tell me a short joke".to_string()),
            ..Default::default()
        };

        let chat_completion = ChatCompletion::builder("gpt-3.5-turbo", [user_message.clone()])
            .credentials(credentials.clone())
            .store(true)
            .create()
            .await
            .unwrap();

        // Unfortunately, completions are not available immediately, so we need to wait a bit.
        sleep(Duration::from_secs(7)).await;

        let retrieved_messages = ChatCompletionMessages::builder(chat_completion.id)
            .credentials(credentials.clone())
            .fetch()
            .await
            .unwrap();

    assert_eq!(retrieved_messages.data, vec![user_message]);
    assert_eq!(retrieved_messages.has_more, false);
}

#[tokio::test]
async fn get_completion_messages_with_pagination() {
    dotenv().ok();
    let credentials = Credentials::from_env();

    let user_message = ChatCompletionMessage {
        role: ChatCompletionMessageRole::User,
        content: Some("Tell me a short joke".to_string()),
        ..Default::default()
    };

    let chat_completion = ChatCompletion::builder("gpt-3.5-turbo", [user_message.clone()])
        .credentials(credentials.clone())
        .store(true)
        .create()
        .await
        .unwrap();

    dbg!(&chat_completion);

    // Unfortunately, completions are not available immediately, so we need to wait a bit.
    sleep(Duration::from_secs(7)).await;

    // Fetch the first page.
    let retrieved_messages1 = ChatCompletionMessages::builder(chat_completion.id.clone())
        .credentials(credentials.clone())
        .pagination(RequestPagination {
            limit: Some(1),
            ..Default::default()
        })
        .fetch()
        .await
        .unwrap();

    assert_eq!(retrieved_messages1.data, vec![user_message]);
    assert_eq!(retrieved_messages1.has_more, false);
    assert!(retrieved_messages1.first_id.is_some());
    assert!(retrieved_messages1.last_id.is_some());

    // Fetch the second page, which should be empty.
    let retrieved_messages2 = ChatCompletionMessages::builder(chat_completion.id.clone())
        .credentials(credentials.clone())
        .pagination(RequestPagination {
            limit: Some(1),
            after: Some(retrieved_messages1.first_id.unwrap()),
            ..Default::default()
        })
        .fetch()
        .await
        .unwrap();

    assert_eq!(retrieved_messages2.data, vec![]);
    assert_eq!(retrieved_messages2.has_more, false);
    assert!(retrieved_messages2.first_id.is_none());
    assert!(retrieved_messages2.last_id.is_none());
}
}
--------------------------------------------------------------------------------
/src/completions.rs:
--------------------------------------------------------------------------------
//! Given a prompt, the model will return one or more predicted completions,
//! and can also return the probabilities of alternative tokens at each position.
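//!
//! A minimal usage sketch, assuming a valid `OPENAI_KEY` in the environment
//! and access to the legacy instruct model used by this crate's tests:
//!
//! ```no_run
//! use openai::completions::Completion;
//! use openai::Credentials;
//!
//! #[tokio::main]
//! async fn main() {
//!     let credentials = Credentials::from_env();
//!     let completion = Completion::builder("gpt-3.5-turbo-instruct")
//!         .prompt("Say this is a test")
//!         .max_tokens(7)
//!         .credentials(credentials)
//!         .create()
//!         .await
//!         .unwrap();
//!     println!("{}", completion.choices.first().unwrap().text);
//! }
//! ```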
use super::{openai_post, ApiResponseOrError, Credentials, Usage};
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Deserialize, Clone)]
pub struct Completion {
    pub id: String,
    pub created: u32,
    pub model: String,
    pub choices: Vec<CompletionChoice>,
    pub usage: Usage,
}

#[derive(Deserialize, Clone)]
pub struct CompletionChoice {
    pub text: String,
    pub index: u16,
    // The log-probability payload is kept as raw JSON here.
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}

#[derive(Serialize, Builder, Debug, Clone)]
#[builder(pattern = "owned")]
#[builder(name = "CompletionBuilder")]
#[builder(setter(strip_option, into))]
pub struct CompletionRequest {
    /// ID of the model to use.
    /// You can use the [List models](https://beta.openai.com/docs/api-reference/models/list)
    /// API to see all of your available models,
    /// or see our [Model overview](https://beta.openai.com/docs/models/overview)
    /// for descriptions of them.
    pub model: String,
    /// The prompt(s) to generate completions for, encoded as a string,
    /// array of strings, array of tokens, or array of token arrays.
    ///
    /// Note that <|endoftext|> is the document separator that the model sees during training,
    /// so if a prompt is not specified the model will generate as if from the beginning of a new document.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub prompt: Option<String>,
    /// The suffix that comes after a completion of inserted text.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub suffix: Option<String>,
    /// The maximum number of [tokens](https://beta.openai.com/tokenizer) to generate in the completion.
    ///
    /// The token count of your prompt plus `max_tokens` cannot exceed the model's context length.
    /// Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(setter(into = false), default)]
    pub max_tokens: Option<u16>,
    /// What [sampling temperature](https://towardsdatascience.com/how-to-sample-from-language-models-682bceb97277) to use.
    /// Higher values mean the model will take more risks.
    /// Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub temperature: Option<f32>,
    /// An alternative to sampling with temperature, called nucleus sampling,
    /// where the model considers the results of the tokens with top_p probability mass.
    /// So 0.1 means only the tokens comprising the top 10% probability mass are considered.
    ///
    /// We generally recommend altering this or `temperature` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub top_p: Option<f32>,
    /// How many completions to generate for each prompt.
    ///
    /// **Note:** Because this parameter generates many completions, it can quickly consume your token quota.
    /// Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub n: Option<u16>,
    /// Whether to stream back partial progress. If set, tokens will be sent as data-only
    /// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
    /// as they become available, with the stream terminated by a `data: [DONE]` message.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(setter(skip), default)] // skipped until properly implemented
    pub stream: Option<bool>,
    /// Include the log probabilities on the `logprobs` most likely tokens, as well as the chosen tokens.
    /// For example, if `logprobs` is 5, the API will return a list of the 5 most likely tokens.
    /// The API will always return the `logprob` of the sampled token, so there may be up to `logprobs+1` elements in the response.
    ///
    /// The maximum value for `logprobs` is 5.
    /// If you need more than this, please contact us through our Help center and describe your use case.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub logprobs: Option<u16>,
    /// Echo back the prompt in addition to the completion.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub echo: Option<bool>,
    /// Up to 4 sequences where the API will stop generating further tokens.
    /// The returned text will not contain the stop sequence.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    #[builder(default)]
    pub stop: Vec<String>,
    /// Number between -2.0 and 2.0.
    /// Positive values penalize new tokens based on whether they appear in the text so far,
    /// increasing the model's likelihood to talk about new topics.
    ///
    /// [See more information about frequency and presence penalties](https://beta.openai.com/docs/api-reference/parameter-details).
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub presence_penalty: Option<f32>,
    /// Number between -2.0 and 2.0.
    /// Positive values penalize new tokens based on their existing frequency in the text so far,
    /// decreasing the model's likelihood to repeat the same line verbatim.
    ///
    /// [See more information about frequency and presence penalties](https://beta.openai.com/docs/api-reference/parameter-details).
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub frequency_penalty: Option<f32>,
    /// Generates `best_of` completions server-side and returns the "best" (the one with the highest log probability per token).
    /// Results cannot be streamed.
    ///
    /// When used with `n`, `best_of` controls the number of candidate completions and `n` specifies how many to return;
    /// `best_of` must be greater than `n`.
    ///
    /// **Note:** Because this parameter generates many completions, it can quickly consume your token quota.
    /// Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub best_of: Option<u16>,
    /// Modify the likelihood of specified tokens appearing in the completion.
    ///
    /// Accepts a json object that maps tokens (specified by their token ID in the GPT tokenizer) to an associated bias value from -100 to 100.
    /// You can use this [tokenizer tool](https://beta.openai.com/tokenizer?view=bpe) (which works for both GPT-2 and GPT-3) to convert text to token IDs.
    /// Mathematically, the bias is added to the logits generated by the model prior to sampling.
    /// The exact effect will vary per model, but values between -1 and 1 should decrease or increase the likelihood of selection;
    /// values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
    ///
    /// As an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token from being generated.
    #[serde(skip_serializing_if = "HashMap::is_empty")]
    #[builder(default)]
    pub logit_bias: HashMap<String, f32>,
    /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
    /// [Learn more](https://beta.openai.com/docs/guides/safety-best-practices/end-user-ids).
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub user: Option<String>,
    /// The credentials to use for this request.
    #[serde(skip_serializing)]
    #[builder(default)]
    pub credentials: Option<Credentials>,
}

impl Completion {
    /// Creates a completion for the provided prompt and parameters.
    async fn create(request: CompletionRequest) -> ApiResponseOrError<Self> {
        let credentials_opt = request.credentials.clone();
        openai_post("completions", &request, credentials_opt).await
    }

    pub fn builder(model: &str) -> CompletionBuilder {
        CompletionBuilder::create_empty().model(model)
    }
}

impl CompletionBuilder {
    pub async fn create(self) -> ApiResponseOrError<Completion> {
        Completion::create(self.build().unwrap()).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::DEFAULT_LEGACY_MODEL;
    use dotenvy::dotenv;

    #[tokio::test]
    async fn completion() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        let completion = Completion::builder(DEFAULT_LEGACY_MODEL)
            .prompt("Say this is a test")
            .max_tokens(7)
            .temperature(0.0)
            .credentials(credentials)
            .create()
            .await
            .unwrap();

        assert_eq!(
            completion.choices.first().unwrap().text,
            "\n\nThis is a test."
        );
    }
}
--------------------------------------------------------------------------------
/src/edits.rs:
--------------------------------------------------------------------------------
//! Given a prompt and an instruction, the model will return an edited version of the prompt.
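//!
//! A minimal usage sketch, assuming `OPENAI_KEY` is set and the account still has
//! access to the legacy edit models (OpenAI has deprecated this endpoint):
//!
//! ```no_run
//! use openai::edits::Edit;
//! use openai::Credentials;
//!
//! #[tokio::main]
//! async fn main() {
//!     let credentials = Credentials::from_env();
//!     let edit = Edit::builder("text-davinci-edit-001", "Fix the spelling mistakes")
//!         .input("What day of the wek is it?")
//!         .credentials(credentials)
//!         .create()
//!         .await
//!         .unwrap();
//!     println!("{}", edit.choices.first().unwrap());
//! }
//! ```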
use super::{openai_post, ApiResponseOrError, Credentials, OpenAiError, Usage};
use derive_builder::Builder;
use serde::{Deserialize, Serialize};

#[derive(Deserialize, Clone)]
pub struct Edit {
    pub created: u32,
    #[serde(skip_deserializing)]
    pub choices: Vec<String>,
    pub usage: Usage,
    #[serde(rename = "choices")]
    choices_bad: Vec<EditChoice>,
}

#[derive(Deserialize, Clone)]
struct EditChoice {
    text: String,
}

#[derive(Serialize, Builder, Debug, Clone)]
#[builder(pattern = "owned")]
#[builder(name = "EditBuilder")]
#[builder(setter(strip_option, into))]
pub struct EditRequest {
    /// ID of the model to use.
    /// You can use the `text-davinci-edit-001` or `code-davinci-edit-001` model with this endpoint.
    pub model: String,
    /// The input text to use as a starting point for the edit.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub input: Option<String>,
    /// The instruction that tells the model how to edit the prompt.
    pub instruction: String,
    /// How many edits to generate for the input and instruction.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(setter(into = false), default)]
    pub n: Option<u16>,
    /// What [sampling temperature](https://towardsdatascience.com/how-to-sample-from-language-models-682bceb97277) to use.
    /// Higher values mean the model will take more risks.
    /// Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub temperature: Option<f32>,
    /// An alternative to sampling with temperature, called nucleus sampling,
    /// where the model considers the results of the tokens with top_p probability mass.
    /// So 0.1 means only the tokens comprising the top 10% probability mass are considered.
    ///
    /// We generally recommend altering this or `temperature` but not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub top_p: Option<f32>,
    /// The credentials to use for this request.
    #[serde(skip_serializing)]
    #[builder(default)]
    pub credentials: Option<Credentials>,
}

impl Edit {
    async fn create(request: EditRequest) -> ApiResponseOrError<Self> {
        let credentials_opt = request.credentials.clone();
        let response: Result<Edit, OpenAiError> =
            openai_post("edits", &request, credentials_opt).await;

        match response {
            Ok(mut edit) => {
                // The API calls this field `choices`; it is deserialized into
                // `choices_bad` and flattened into plain strings here.
                for choice in &edit.choices_bad {
                    edit.choices.push(choice.text.clone());
                }

                Ok(edit)
            }
            Err(_) => response,
        }
    }

    pub fn builder(model: &str, instruction: impl Into<String>) -> EditBuilder {
        EditBuilder::create_empty()
            .model(model)
            .instruction(instruction)
    }
}

impl EditBuilder {
    pub async fn create(self) -> ApiResponseOrError<Edit> {
        Edit::create(self.build().unwrap()).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use dotenvy::dotenv;

    #[tokio::test]
    #[ignore]
    async fn edit() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        let edit = Edit::builder("text-davinci-edit-001", "Fix the spelling mistakes")
            .input("What day of the wek is it?")
            .temperature(0.0)
            .credentials(credentials)
            .create()
            .await
            .unwrap();

        assert_eq!(
            edit.choices.first().unwrap(),
            "What day of the week is it?\n"
        );
    }
}
--------------------------------------------------------------------------------
/src/embeddings.rs:
--------------------------------------------------------------------------------
//! Get a vector representation of a given input that can be easily consumed by machine learning models and algorithms.
//!
//! Related guide: [Embeddings](https://beta.openai.com/docs/guides/embeddings)
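//!
//! A minimal usage sketch, assuming `OPENAI_KEY` is set in the environment:
//!
//! ```no_run
//! use openai::embeddings::Embedding;
//! use openai::Credentials;
//!
//! #[tokio::main]
//! async fn main() {
//!     let credentials = Credentials::from_env();
//!     let embedding = Embedding::create(
//!         "text-embedding-ada-002",
//!         "The food was delicious and the waiter...",
//!         "",
//!         credentials,
//!     )
//!     .await
//!     .unwrap();
//!     println!("dimensions: {}", embedding.vec.len());
//! }
//! ```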

use super::{openai_post, ApiResponseOrError, Credentials};
use serde::{Deserialize, Serialize};

#[derive(Serialize, Clone)]
struct CreateEmbeddingsRequestBody<'a> {
    model: &'a str,
    input: Vec<&'a str>,
    #[serde(skip_serializing_if = "str::is_empty")]
    user: &'a str,
}

#[derive(Deserialize, Clone)]
pub struct Embeddings {
    pub data: Vec<Embedding>,
    pub model: String,
    pub usage: EmbeddingsUsage,
}

#[derive(Deserialize, Clone, Copy)]
pub struct EmbeddingsUsage {
    pub prompt_tokens: u32,
    pub total_tokens: u32,
}

#[derive(Deserialize, Clone)]
pub struct Embedding {
    #[serde(rename = "embedding")]
    pub vec: Vec<f64>,
}

impl Embeddings {
    /// Creates an embedding vector representing the input text.
    ///
    /// # Arguments
    ///
    /// * `model` - ID of the model to use.
    ///   You can use the [List models](https://beta.openai.com/docs/api-reference/models/list)
    ///   API to see all of your available models, or see our [Model overview](https://beta.openai.com/docs/models/overview)
    ///   for descriptions of them.
    /// * `input` - Input text to get embeddings for, encoded as a string or array of tokens.
    ///   To get embeddings for multiple inputs in a single request, pass an array of strings or array of token arrays.
    ///   Each input must not exceed 8192 tokens in length.
    /// * `user` - A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
    ///   [Learn more](https://beta.openai.com/docs/guides/safety-best-practices/end-user-ids).
    /// * `credentials` - The OpenAI credentials.
    pub async fn create(
        model: &str,
        input: Vec<&str>,
        user: &str,
        credentials: Credentials,
    ) -> ApiResponseOrError<Self> {
        openai_post(
            "embeddings",
            &CreateEmbeddingsRequestBody { model, input, user },
            Some(credentials),
        )
        .await
    }

    /// Returns the cosine distance between each consecutive pair of embeddings in `data`.
    pub fn distances(&self) -> Vec<f64> {
        let mut distances = Vec::new();
        let mut last_embedding: Option<&Embedding> = None;

        for embedding in &self.data {
            if let Some(other) = last_embedding {
                distances.push(embedding.distance(other));
            }

            last_embedding = Some(embedding);
        }

        distances
    }
}

impl Embedding {
    pub async fn create(
        model: &str,
        input: &str,
        user: &str,
        credentials: Credentials,
    ) -> ApiResponseOrError<Self> {
        let mut embeddings = Embeddings::create(model, vec![input], user, credentials).await?;
        Ok(embeddings.data.swap_remove(0))
    }

    /// Returns the Euclidean (L2) norm of the embedding vector.
    pub fn magnitude(&self) -> f64 {
        self.vec.iter().map(|x| x * x).sum::<f64>().sqrt()
    }

    /// Returns the cosine distance `1 - cos(θ)` between the two embedding vectors:
    /// 0.0 for parallel vectors, 1.0 for orthogonal ones, and 2.0 for opposite ones.
    pub fn distance(&self, other: &Self) -> f64 {
        let dot_product: f64 = self
            .vec
            .iter()
            .zip(other.vec.iter())
            .map(|(x, y)| x * y)
            .sum();
        let product_of_magnitudes = self.magnitude() * other.magnitude();

        1.0 - dot_product / product_of_magnitudes
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use dotenvy::dotenv;

    #[tokio::test]
    async fn embeddings() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        let embeddings = Embeddings::create(
            "text-embedding-ada-002",
            vec!["The food was delicious and the waiter..."],
            "",
            credentials,
        )
        .await
        .unwrap();

        assert!(!embeddings.data.first().unwrap().vec.is_empty());
    }

    #[tokio::test]
    async fn embedding() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        let embedding = Embedding::create(
            "text-embedding-ada-002",
            "The food was delicious and the waiter...",
            "",
            credentials,
        )
        .await
        .unwrap();

        assert!(!embedding.vec.is_empty());
    }

    #[test]
    fn right_angle() {
        let embeddings = Embeddings {
            data: vec![
                Embedding {
                    vec: vec![1.0, 0.0, 0.0],
                },
                Embedding {
                    vec: vec![0.0, 1.0, 0.0],
                },
            ],
            model: "text-embedding-ada-002".to_string(),
            usage: EmbeddingsUsage {
                prompt_tokens: 0,
                total_tokens: 0,
            },
        };
        assert_eq!(embeddings.distances()[0], 1.0);
    }
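
    // Illustrative sketch: parallel vectors point in the same direction,
    // so their cosine distance should be ~0.0 up to floating-point rounding.
    #[test]
    fn parallel_vectors() {
        let a = Embedding {
            vec: vec![1.0, 2.0, 3.0],
        };
        // Same direction as `a`, twice the magnitude.
        let b = Embedding {
            vec: vec![2.0, 4.0, 6.0],
        };
        assert!(a.distance(&b).abs() < 1e-12);
    }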

    #[test]
    fn non_right_angle() {
        let embeddings = Embeddings {
            data: vec![
                Embedding {
                    vec: vec![1.0, 1.0, 0.0],
                },
                Embedding {
                    vec: vec![0.0, 1.0, 0.0],
                },
            ],
            model: "text-embedding-ada-002".to_string(),
            usage: EmbeddingsUsage {
                prompt_tokens: 0,
                total_tokens: 0,
            },
        };

        // 45° between the vectors: 1 - cos(45°) = 1 - 1/√2.
        assert_eq!(embeddings.distances()[0], 0.29289321881345254);
    }
}
--------------------------------------------------------------------------------
/src/files.rs:
--------------------------------------------------------------------------------
//! Upload, download, list, and delete files on the OpenAI platform. Usually used for fine-tuning files.
//!
//! See the [Files API for OpenAI](https://platform.openai.com/docs/api-reference/files) for
//! more information.
//!
//! # Examples
//!
//! All examples and tests require the `OPENAI_KEY` environment variable to
//! be set with your personal OpenAI platform API key.
//!
//! Upload a new file. [Reference API](https://platform.openai.com/docs/api-reference/files/upload)
//! ```
//! use openai::files::File;
//! use openai::ApiResponseOrError;
//! use dotenvy::dotenv;
//! use std::env;
//! use openai::Credentials;
//!
//! #[tokio::main]
//! async fn main() -> ApiResponseOrError<()> {
//!     dotenv().ok();
//!     let credentials = Credentials::from_env();
//!     let uploaded_file = File::builder()
//!         .file_name("test_data/file_upload_test1.jsonl") // Local file path to upload.
//!         .purpose("fine-tune")
//!         .create()
//!         .await?;
//!     assert_eq!(uploaded_file.filename, "file_upload_test1.jsonl");
//!     Ok(())
//! }
//! ```
//!
//! List files. [Reference API](https://platform.openai.com/docs/api-reference/files/list)
//! ```
//! use openai::files::Files;
//! use openai::ApiResponseOrError;
//! use dotenvy::dotenv;
//! use std::env;
//! use openai::Credentials;
//!
//! #[tokio::main]
//! async fn main() -> ApiResponseOrError<()> {
//!     dotenv().ok();
//!     let credentials = Credentials::from_env();
//!     let openai_files = Files::list(credentials).await?;
//!     let file_count = openai_files.len();
//!     println!("Listing {} files", file_count);
//!     for openai_file in openai_files.into_iter() {
//!         println!("  id: {}, file: {}, size: {}", openai_file.id, openai_file.filename, openai_file.bytes)
//!     }
//!     Ok(())
//! }
//! ```
//!
//! Retrieve a file (JSON metadata only). [Reference API](https://platform.openai.com/docs/api-reference/files/retrieve)
//!
//! ```no_run
//! use openai::files::File;
//! use openai::ApiResponseOrError;
//! use dotenvy::dotenv;
//! use std::env;
//! use openai::Credentials;
//!
//! #[tokio::main]
//! async fn main() -> ApiResponseOrError<()> {
//!     dotenv().ok();
//!     let credentials = Credentials::from_env();
//!     let file_id = "file-XjGxS3KTG0uNmNOK362iJua3"; // Use a real file id.
//!     let file = File::fetch(file_id, credentials).await?;
//!     println!("id: {}, file: {}, size: {}", file.id, file.filename, file.bytes);
//!     Ok(())
//! }
//! ```
//!
//! Download to a local file. [Reference API](https://platform.openai.com/docs/api-reference/files/retrieve-content)
//!
//! ```no_run
//! use openai::files::File;
//! use openai::ApiResponseOrError;
//! use dotenvy::dotenv;
//! use std::env;
//! use openai::Credentials;
//!
//! #[tokio::main]
//! async fn main() -> ApiResponseOrError<()> {
//!     dotenv().ok();
//!     let credentials = Credentials::from_env();
//!     let test_file = "test_file.jsonl";
//!     let file_id = "file-XjGxS3KTG0uNmNOK362iJua3"; // Use a real file id.
//!     File::download_content_to_file(file_id, test_file, credentials).await?;
//!     Ok(())
//! }
//! ```
//!
//! Delete a file. [Reference API](https://platform.openai.com/docs/api-reference/files/delete)
//!
//! ```no_run
//! use openai::files::File;
//! use openai::ApiResponseOrError;
//! use dotenvy::dotenv;
//! use std::env;
//! use openai::Credentials;
//!
//! #[tokio::main]
//! async fn main() -> ApiResponseOrError<()> {
//!     dotenv().ok();
//!     let credentials = Credentials::from_env();
//!     let file_id = "file-XjGxS3KTG0uNmNOK362iJua3"; // Use a real file id.
//!     File::delete(file_id, credentials).await?;
//!     Ok(())
//! }
//! ```
//!
//! For more examples, see the files tests.

use std::io::Write;
use std::path::Path;

use bytes::{BufMut, BytesMut};
use derive_builder::Builder;
use futures_util::StreamExt;
use reqwest::multipart::{Form, Part};
use reqwest::Method;
use serde::{Deserialize, Serialize};

use crate::{openai_delete, openai_get, openai_post_multipart, openai_request, Credentials};

use super::ApiResponseOrError;

/// Upload, download and delete a file on the OpenAI platform.
#[derive(Deserialize, Serialize, Clone)]
pub struct File {
    /// The unique id for this uploaded file on the OpenAI platform.
    /// This id is generated by OpenAI for each uploaded file.
    pub id: String,
    /// The type of the uploaded object, i.e. "file".
    pub object: String,
    /// The size in bytes of the uploaded file.
    pub bytes: usize,
    /// Unix timestamp, seconds since epoch, of when the file was uploaded.
    pub created_at: usize,
    /// The name of the file uploaded.
    pub filename: String,
    /// The purpose of the file, i.e. "fine-tune".
    pub purpose: String,
}

#[derive(Deserialize, Serialize, Clone)]
pub struct DeletedFile {
    pub id: String,
    pub object: String,
    pub deleted: bool,
}

/// List files on the OpenAI platform.
#[derive(Deserialize, Serialize, Clone)]
pub struct Files {
    data: Vec<File>,
    pub object: String,
}

#[derive(Serialize, Builder, Debug, Clone)]
#[builder(pattern = "owned")]
#[builder(name = "FileUploadBuilder")]
#[builder(setter(strip_option, into))]
pub struct FileUploadRequest {
    file_name: String,
    purpose: String,
    /// The credentials to use for this request.
    #[serde(skip_serializing)]
    #[builder(default)]
    pub credentials: Option<Credentials>,
}

impl File {
    async fn create(request: FileUploadRequest) -> ApiResponseOrError<Self> {
        let upload_file_path = Path::new(request.file_name.as_str());
        let upload_file_path = upload_file_path.canonicalize()?;
        // Extract the bare file name (without the directory path) for the upload form.
        let simple_name = upload_file_path
            .file_name()
            .unwrap()
            .to_str()
            .unwrap()
            .to_string();
        let async_file = tokio::fs::File::open(upload_file_path).await?;
        let file_part = Part::stream(async_file)
            .file_name(simple_name)
            .mime_str("application/jsonl")?;
        let form = Form::new()
            .part("file", file_part)
            .text("purpose", request.purpose);
        openai_post_multipart("files", form, request.credentials).await
    }

    /// Creates a new [`FileUploadBuilder`].
    pub fn builder() -> FileUploadBuilder {
        FileUploadBuilder::create_empty()
    }

    /// Deletes a file from the OpenAI platform by id.
    pub async fn delete(id: &str, credentials: Credentials) -> ApiResponseOrError<DeletedFile> {
        openai_delete(format!("files/{}", id).as_str(), Some(credentials)).await
    }

    /// Gets a file from the OpenAI platform by id.
    #[deprecated(since = "1.0.0-alpha.16", note = "use `fetch` instead")]
    pub async fn get(id: &str) -> ApiResponseOrError<Self> {
        openai_get(format!("files/{}", id).as_str(), None).await
    }

    /// Gets a file from the OpenAI platform by id.
    pub async fn fetch(id: &str, credentials: Credentials) -> ApiResponseOrError<Self> {
        openai_get(format!("files/{}", id).as_str(), Some(credentials)).await
    }

    /// Downloads a file as bytes into memory by id.
    #[deprecated(since = "1.0.0-alpha.16", note = "use `fetch_content_bytes` instead")]
    pub async fn get_content_bytes(id: &str) -> ApiResponseOrError<Vec<u8>> {
        Self::fetch_content_bytes_with_credentials_opt(id, None).await
    }

    /// Downloads a file as bytes into memory by id.
    pub async fn fetch_content_bytes(
        id: &str,
        credentials: Credentials,
    ) -> ApiResponseOrError<Vec<u8>> {
        Self::fetch_content_bytes_with_credentials_opt(id, Some(credentials)).await
    }

    async fn fetch_content_bytes_with_credentials_opt(
        id: &str,
        credentials_opt: Option<Credentials>,
    ) -> ApiResponseOrError<Vec<u8>> {
        let route = format!("files/{}/content", id);
        let response = openai_request(
            Method::GET,
            route.as_str(),
            |request| request,
            credentials_opt,
        )
        .await?;
        let content_len = response.content_length().unwrap_or(1024) as usize;
        let mut file_bytes = BytesMut::with_capacity(content_len);
        let mut bytes_stream = response.bytes_stream();
        while let Some(Ok(bytes)) = bytes_stream.next().await {
            file_bytes.put(bytes);
        }
        Ok(file_bytes.to_vec())
    }

    /// Downloads a file to a new local file by id.
    pub async fn download_content_to_file(
        id: &str,
        file_path: &str,
        credentials: Credentials,
    ) -> ApiResponseOrError<()> {
        let mut output_file = std::fs::File::create(file_path)?;
        let route = format!("files/{}/content", id);
        let response = openai_request(
            Method::GET,
            route.as_str(),
            |request| request,
            Some(credentials),
        )
        .await?;
        let mut bytes_stream = response.bytes_stream();
        while let Some(Ok(bytes)) = bytes_stream.next().await {
            output_file.write_all(bytes.as_ref())?;
        }
        Ok(())
    }
}

impl FileUploadBuilder {
    /// Uploads the file to the OpenAI platform.
    pub async fn create(self) -> ApiResponseOrError<File> {
        File::create(self.build().unwrap()).await
    }
}

impl Files {
    /// Gets a list of all uploaded files on the OpenAI platform.
    pub async fn list(credentials: Credentials) -> ApiResponseOrError<Self> {
        openai_get("files", Some(credentials)).await
    }

    pub fn len(&self) -> usize {
        self.data.len()
    }
}

impl<'a> IntoIterator for &'a Files {
    type Item = &'a File;
    type IntoIter = core::slice::Iter<'a, File>;

    fn into_iter(self) -> Self::IntoIter {
        self.data.as_slice().iter()
    }
}

#[cfg(test)]
mod tests {
    use std::env;
    use std::io::Read;
    use std::time::Duration;

    use dotenvy::dotenv;

    use crate::DEFAULT_CREDENTIALS;

    use super::*;

    fn test_upload_builder() -> FileUploadBuilder {
        File::builder()
            .file_name("test_data/file_upload_test1.jsonl")
            .purpose("fine-tune")
    }

    fn test_upload_request() -> FileUploadRequest {
        test_upload_builder().build().unwrap()
    }

    #[tokio::test]
    async fn upload_file() {
        dotenv().ok();
        let credentials = Credentials::from_env();
        let file_upload = test_upload_builder()
            .credentials(credentials)
            .create()
            .await
            .unwrap();
        println!(
            "upload: {}",
            serde_json::to_string_pretty(&file_upload).unwrap()
        );
        assert_eq!(file_upload.id.as_bytes()[..5], *"file-".as_bytes())
    }

    #[tokio::test]
    async fn missing_file() {
        dotenv().ok();
        let credentials = Credentials::from_env();
        let test_builder = File::builder()
            .file_name("test_data/missing_file.jsonl")
            .credentials(credentials)
            .purpose("fine-tune");
        let response = test_builder.create().await;
        assert!(response.is_err());
        let openapi_err = response.err().unwrap();
        assert_eq!(openapi_err.error_type, "io");
        assert_eq!(
            openapi_err.message,
            "No such file or directory (os error 2)"
        )
    }

    #[tokio::test]
    async fn list_files() {
        dotenv().ok();
        let credentials = Credentials::from_env();
        // Ensure at least one file exists.
        test_upload_builder().create().await.unwrap();
        let openai_files = Files::list(credentials).await.unwrap();
        let file_count = openai_files.len();
        assert!(file_count > 0);
        for openai_file in openai_files.into_iter() {
            assert_eq!(openai_file.id.as_bytes()[..5], *"file-".as_bytes())
        }
        println!(
            "files [{}]: {}",
            file_count,
            serde_json::to_string_pretty(&openai_files).unwrap()
        );
    }

    #[tokio::test]
    async fn delete_files() {
        dotenv().ok();
        let credentials = Credentials::from_env();
        // Ensure at least one file exists.
        test_upload_builder().create().await.unwrap();
        // Wait to avoid a "recent upload still processing" error.
        tokio::time::sleep(Duration::from_secs(7)).await;
        let openai_files = Files::list(credentials).await.unwrap();
        assert!(openai_files.data.len() > 0);
        let mut files = openai_files.data;
        files.sort_by(|a, b| a.created_at.cmp(&b.created_at));
        for file in files {
            let deleted_file = File::delete(
                file.id.as_str(),
                DEFAULT_CREDENTIALS.read().unwrap().clone(),
            )
            .await
            .unwrap();
            assert!(deleted_file.deleted);
            println!("deleted: {} {}", deleted_file.id, deleted_file.deleted)
        }
    }

    #[tokio::test]
    async fn get_file_and_contents() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        let file = test_upload_builder()
            .credentials(credentials.clone())
            .create()
            .await
            .unwrap();
        let file_get = File::fetch(file.id.as_str(), credentials.clone())
            .await
            .unwrap();
        assert_eq!(file.id, file_get.id);

        // Get the file contents as bytes.
        let body_bytes = File::fetch_content_bytes(file.id.as_str(), credentials.clone())
            .await
            .unwrap();
        assert_eq!(body_bytes.len(), file.bytes);

        // Download the contents to a local file.
        let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
        let test_dir = format!("{}/{}", manifest_dir, "target/files-test");
        std::fs::create_dir_all(test_dir.as_str()).unwrap();
        let test_file_save_path = format!("{}/{}", test_dir.as_str(), file.filename);
        File::download_content_to_file(file.id.as_str(), test_file_save_path.as_str(), credentials)
            .await
            .unwrap();
        let mut local_file = std::fs::File::open(test_file_save_path.as_str()).unwrap();
        let mut local_bytes: Vec<u8> = Vec::new();
        local_file.read_to_end(&mut local_bytes).unwrap();
        assert_eq!(body_bytes, local_bytes)
    }

    #[test]
    fn file_name_path_test() {
        let request = test_upload_request();
        let file_upload_path = Path::new(request.file_name.as_str());
        let file_name = file_upload_path.file_name().unwrap().to_str().unwrap();
        assert_eq!(file_name, "file_upload_test1.jsonl");
        let file_upload_path = file_upload_path.canonicalize().unwrap();
        let file_exists = file_upload_path.exists();
        assert!(file_exists)
    }
}
--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
use reqwest::multipart::Form;
use reqwest::{header::AUTHORIZATION, Client, Method, RequestBuilder, Response};
use reqwest_eventsource::{CannotCloneRequestError, EventSource, RequestBuilderExt};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::env;
use std::env::VarError;
use std::sync::{LazyLock, RwLock};

pub mod chat;
pub mod completions;
pub mod edits;
pub mod embeddings;
pub mod files;
pub mod models;
pub mod moderations;

pub static DEFAULT_BASE_URL: LazyLock<String> =
    LazyLock::new(|| String::from("https://api.openai.com/v1/"));
static DEFAULT_CREDENTIALS: LazyLock<RwLock<Credentials>> =
    LazyLock::new(|| RwLock::new(Credentials::from_env()));

/// Holds the API key and base URL for an OpenAI-compatible API.
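///
/// # Example
///
/// A minimal sketch, assuming `OPENAI_KEY` (and optionally `OPENAI_BASE_URL`)
/// are set in the environment:
///
/// ```no_run
/// use openai::Credentials;
///
/// // Build credentials explicitly; an empty base URL falls back to the default.
/// let credentials = Credentials::new("sk-...", "");
/// // Or read OPENAI_KEY and OPENAI_BASE_URL from the environment.
/// let from_env = Credentials::from_env();
/// ```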
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Credentials {
    api_key: String,
    base_url: String,
}

impl Credentials {
    /// Creates credentials with the given API key and base URL.
    ///
    /// If the base URL is empty, it will use the default.
    pub fn new(api_key: impl Into<String>, base_url: impl Into<String>) -> Self {
        let base_url = base_url.into();
        let base_url = if base_url.is_empty() {
            DEFAULT_BASE_URL.clone()
        } else {
            parse_base_url(base_url)
        };
        Self {
            api_key: api_key.into(),
            base_url,
        }
    }

    /// Fetches the credentials from the environment variables
    /// `OPENAI_KEY` and `OPENAI_BASE_URL`.
    ///
    /// # Panics
    /// This function will panic if the key variable is missing from the env.
    /// If only the base URL variable is missing, it will use the default.
    pub fn from_env() -> Credentials {
        let api_key = env::var("OPENAI_KEY").unwrap();
        let base_url_unparsed = env::var("OPENAI_BASE_URL").unwrap_or_else(|e| match e {
            VarError::NotPresent => DEFAULT_BASE_URL.clone(),
            VarError::NotUnicode(v) => panic!("OPENAI_BASE_URL is not unicode: {v:#?}"),
        });
        let base_url = parse_base_url(base_url_unparsed);
        Credentials { api_key, base_url }
    }

    pub fn api_key(&self) -> &str {
        &self.api_key
    }

    pub fn base_url(&self) -> &str {
        &self.base_url
    }
}

#[derive(Deserialize, Debug, Clone, Eq, PartialEq)]
pub struct OpenAiError {
    pub message: String,
    #[serde(rename = "type")]
    pub error_type: String,
    pub param: Option<String>,
    pub code: Option<String>,
}

/// Pagination options for API requests fetching lists of items.
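///
/// A minimal sketch of requesting one item per page:
///
/// ```no_run
/// use openai::{RequestOrder, RequestPagination};
///
/// let page = RequestPagination {
///     limit: Some(1),
///     after: None,
///     order: Some(RequestOrder::Ascending),
/// };
/// ```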
#[derive(Serialize, Debug, Clone, Eq, PartialEq, Default)]
pub struct RequestPagination {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub limit: Option<u16>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub after: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub order: Option<RequestOrder>,
}

/// Order of items in a list.
#[derive(Serialize, Debug, Clone, Eq, PartialEq)]
pub enum RequestOrder {
    /// Ascending order.
    #[serde(rename = "asc")]
    Ascending,

    /// Descending order.
    #[serde(rename = "desc")]
    Descending,
}

impl OpenAiError {
    fn new(message: String, error_type: String) -> OpenAiError {
        OpenAiError {
            message,
            error_type,
            param: None,
            code: None,
        }
    }
}

impl std::fmt::Display for OpenAiError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for OpenAiError {}

#[derive(Deserialize, Clone)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    Err { error: OpenAiError },
    Ok(T),
}

#[derive(Deserialize, Clone, Copy, Debug, Eq, PartialEq)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

pub type ApiResponseOrError<T> = Result<T, OpenAiError>;

impl From<reqwest::Error> for OpenAiError {
    fn from(value: reqwest::Error) -> Self {
        OpenAiError::new(value.to_string(), "reqwest".to_string())
    }
}

impl From<std::io::Error> for OpenAiError {
    fn from(value: std::io::Error) -> Self {
        OpenAiError::new(value.to_string(), "io".to_string())
    }
}

async fn openai_request_json<F, T>(
    method: Method,
    route: &str,
    builder: F,
    credentials_opt: Option<Credentials>,
) -> ApiResponseOrError<T>
where
    F: FnOnce(RequestBuilder) -> RequestBuilder,
    T: DeserializeOwned,
{
    let api_response = openai_request(method, route, builder, credentials_opt)
        .await?
        .json()
        .await?;
    match api_response {
        ApiResponse::Ok(t) => Ok(t),
        ApiResponse::Err { error } => Err(error),
    }
}

async fn openai_request<F>(
    method: Method,
    route: &str,
    builder: F,
    credentials_opt: Option<Credentials>,
) -> ApiResponseOrError<Response>
where
    F: FnOnce(RequestBuilder) -> RequestBuilder,
{
    let client = Client::new();
    let credentials =
        credentials_opt.unwrap_or_else(|| DEFAULT_CREDENTIALS.read().unwrap().clone());
    let mut request = client.request(method, format!("{}{route}", credentials.base_url));
    request = builder(request);
    let response = request
        .header(AUTHORIZATION, format!("Bearer {}", credentials.api_key))
        .send()
        .await?;
    Ok(response)
}

async fn openai_request_stream<F>(
    method: Method,
    route: &str,
    builder: F,
    credentials_opt: Option<Credentials>,
) -> Result<EventSource, CannotCloneRequestError>
where
    F: FnOnce(RequestBuilder) -> RequestBuilder,
{
    let client = Client::new();
    let credentials =
        credentials_opt.unwrap_or_else(|| DEFAULT_CREDENTIALS.read().unwrap().clone());
    let mut request = client.request(method, format!("{}{route}", credentials.base_url));
    request = builder(request);
    let stream = request
        .header(AUTHORIZATION, format!("Bearer {}", credentials.api_key))
        .eventsource()?;
    Ok(stream)
}

async fn openai_get<T>(route: &str, credentials_opt: Option<Credentials>) -> ApiResponseOrError<T>
where
    T: DeserializeOwned,
{
    openai_request_json(Method::GET, route, |request| request, credentials_opt).await
}

async fn openai_get_with_query<T, Query>(
    route: &str,
    query: &Query,
    credentials_opt: Option<Credentials>,
) -> ApiResponseOrError<T>
where
    T: DeserializeOwned,
    Query: Serialize + ?Sized,
{
    openai_request_json(
        Method::GET,
        route,
        |request| request.query(query),
        credentials_opt,
    )
    .await
}

async fn openai_delete<T>(
    route: &str,
    credentials_opt: Option<Credentials>,
) -> ApiResponseOrError<T>
where
    T: DeserializeOwned,
{
    openai_request_json(Method::DELETE, route, |request| request, credentials_opt).await
}

async fn openai_post<J, T>(
    route: &str,
    json: &J,
    credentials_opt: Option<Credentials>,
) -> ApiResponseOrError<T>
where
    J: Serialize + ?Sized,
    T: DeserializeOwned,
{
    openai_request_json(
        Method::POST,
        route,
        |request| request.json(json),
        credentials_opt,
    )
    .await
}

async fn openai_post_multipart<T>(
    route: &str,
    form: Form,
    credentials_opt: Option<Credentials>,
) -> ApiResponseOrError<T>
where
    T: DeserializeOwned,
{
    openai_request_json(
        Method::POST,
        route,
        |request| request.multipart(form),
        credentials_opt,
    )
    .await
}

/// Sets the key for all OpenAI API functions.
///
/// ## Examples
///
/// Use the environment variable `OPENAI_KEY` defined in a `.env` file:
///
/// ```rust
/// use openai::set_key;
/// use dotenvy::dotenv;
/// use std::env;
///
/// dotenv().ok();
/// set_key(env::var("OPENAI_KEY").unwrap());
/// ```
#[deprecated(
    since = "1.0.0-alpha.16",
    note = "use the `Credentials` struct instead"
)]
pub fn set_key(value: String) {
    let mut credentials = DEFAULT_CREDENTIALS.write().unwrap();
    credentials.api_key = value;
}

/// Sets the base url for all OpenAI API functions.
///
/// ## Examples
///
/// Use the environment variable `OPENAI_BASE_URL` defined in a `.env` file:
///
/// ```rust
/// use openai::set_base_url;
/// use dotenvy::dotenv;
/// use std::env;
///
/// dotenv().ok();
/// set_base_url(env::var("OPENAI_BASE_URL").unwrap_or_default());
/// ```
#[deprecated(
    since = "1.0.0-alpha.16",
    note = "use the `Credentials` struct instead"
)]
pub fn set_base_url(mut value: String) {
    if value.is_empty() {
        return;
    }
    value = parse_base_url(value);
    let mut credentials = DEFAULT_CREDENTIALS.write().unwrap();
    credentials.base_url = value;
}

/// Ensures the base URL ends with a trailing slash so routes can be appended directly.
fn parse_base_url(mut value: String) -> String {
    if !value.ends_with('/') {
        value += "/";
    }
    value
}

#[cfg(test)]
pub mod tests {
    pub const DEFAULT_LEGACY_MODEL: &str = "gpt-3.5-turbo-instruct";
}
--------------------------------------------------------------------------------
/src/models.rs:
--------------------------------------------------------------------------------
//! List and describe the various models available in the API.
//! You can refer to the [Models](https://beta.openai.com/docs/models)
//! documentation to understand what models are available and the differences between them.
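//!
//! A minimal usage sketch, assuming `OPENAI_KEY` is set in the environment:
//!
//! ```no_run
//! use openai::models::Model;
//! use openai::Credentials;
//!
//! #[tokio::main]
//! async fn main() {
//!     let credentials = Credentials::from_env();
//!     let model = Model::fetch("gpt-3.5-turbo-instruct", credentials).await.unwrap();
//!     println!("{} is owned by {}", model.id, model.owned_by);
//! }
//! ```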

use super::{openai_get, ApiResponseOrError, Credentials};
use serde::Deserialize;

#[derive(Deserialize, Clone)]
pub struct Model {
    pub id: String,
    pub object: String,
    pub created: u32,
    pub owned_by: String,
}

#[derive(Deserialize, Clone)]
pub struct ModelPermission {
    pub id: String,
    pub created: u32,
    pub allow_create_engine: bool,
    pub allow_sampling: bool,
    pub allow_logprobs: bool,
    pub allow_search_indices: bool,
    pub allow_view: bool,
    pub allow_fine_tuning: bool,
    pub organization: String,
    pub group: Option<String>,
    pub is_blocking: bool,
}

impl Model {
    /// Retrieves a model instance,
    /// providing basic information about the model such as the owner and permissioning.
    #[deprecated(since = "1.0.0-alpha.16", note = "use `fetch` instead")]
    pub async fn from(id: &str) -> ApiResponseOrError<Self> {
        openai_get(&format!("models/{id}"), None).await
    }

    /// Retrieves a model instance,
    /// providing basic information about the model such as the owner and permissioning.
    pub async fn fetch(id: &str, credentials: Credentials) -> ApiResponseOrError<Self> {
        openai_get(&format!("models/{id}"), Some(credentials)).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::DEFAULT_LEGACY_MODEL;
    use dotenvy::dotenv;

    #[tokio::test]
    async fn model() {
        dotenv().ok();
        let credentials = Credentials::from_env();
        let model = Model::fetch(DEFAULT_LEGACY_MODEL, credentials)
            .await
            .unwrap();
        assert_eq!(model.id, DEFAULT_LEGACY_MODEL);
    }
}
--------------------------------------------------------------------------------
/src/moderations.rs:
--------------------------------------------------------------------------------
//! Given an input text, outputs whether the model classifies it as violating OpenAI's content policy.
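//!
//! A minimal usage sketch, assuming `OPENAI_KEY` is set in the environment:
//!
//! ```no_run
//! use openai::moderations::Moderation;
//! use openai::Credentials;
//!
//! #[tokio::main]
//! async fn main() {
//!     let credentials = Credentials::from_env();
//!     let moderation = Moderation::builder("I want to kill them.")
//!         .credentials(credentials)
//!         .create()
//!         .await
//!         .unwrap();
//!     println!("flagged: {}", moderation.results[0].flagged);
//! }
//! ```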
use super::{openai_post, ApiResponseOrError, Credentials};
use derive_builder::Builder;
use serde::{Deserialize, Serialize};

#[derive(Deserialize, Clone, Debug)]
pub struct Moderation {
    pub id: String,
    pub model: String,
    pub results: Vec<ModerationResult>,
}

#[derive(Deserialize, Clone, Debug)]
pub struct ModerationResult {
    pub flagged: bool,
    pub categories: Categories,
    pub category_scores: CategoryScores,
}

#[derive(Deserialize, Clone, Copy, Debug)]
pub struct Categories {
    pub hate: bool,
    #[serde(rename = "hate/threatening")]
    pub hate_threatening: bool,
    #[serde(rename = "self-harm")]
    pub self_harm: bool,
    pub sexual: bool,
    #[serde(rename = "sexual/minors")]
    pub sexual_minors: bool,
    pub violence: bool,
    #[serde(rename = "violence/graphic")]
    pub violence_graphic: bool,
}

#[derive(Deserialize, Clone, Debug)]
pub struct CategoryScores {
    pub hate: f64,
    #[serde(rename = "hate/threatening")]
    pub hate_threatening: f64,
    #[serde(rename = "self-harm")]
    pub self_harm: f64,
    pub sexual: f64,
    #[serde(rename = "sexual/minors")]
    pub sexual_minors: f64,
    pub violence: f64,
    #[serde(rename = "violence/graphic")]
    pub violence_graphic: f64,
}

#[derive(Serialize, Builder, Debug, Clone)]
#[builder(pattern = "owned")]
#[builder(name = "ModerationBuilder")]
#[builder(setter(strip_option, into))]
pub struct ModerationRequest {
    /// The input text to classify.
    pub input: String,
    /// ID of the model to use.
    /// Two content moderation models are available: `text-moderation-stable` and `text-moderation-latest`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub model: Option<String>,
    /// The credentials to use for this request.
    #[serde(skip_serializing)]
    #[builder(default)]
    pub credentials: Option<Credentials>,
}

impl Moderation {
    async fn create(request: ModerationRequest) -> ApiResponseOrError<Self> {
        openai_post("moderations", &request, request.credentials.clone()).await
    }

    pub fn builder(input: impl Into<String>) -> ModerationBuilder {
        ModerationBuilder::create_empty().input(input)
    }
}

impl ModerationBuilder {
    pub async fn create(self) -> ApiResponseOrError<Moderation> {
        Moderation::create(self.build().unwrap()).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use dotenvy::dotenv;

    #[tokio::test]
    async fn moderations() {
        dotenv().ok();
        let credentials = Credentials::from_env();

        let moderation = Moderation::builder("I want to kill them.")
            .model("text-moderation-latest")
            .credentials(credentials)
            .create()
            .await
            .unwrap();

        assert_eq!(
            moderation.results.first().unwrap().categories.violence,
            true
        );
        assert_eq!(moderation.results.first().unwrap().flagged, true);
    }
}
--------------------------------------------------------------------------------
/test_data/file_upload_test1.jsonl:
--------------------------------------------------------------------------------
{"prompt": "example data: the most correct data\n###\n", "completion": "yes"}
{"prompt": "example data: totally wrong data\n###\n", "completion": "no"}
{"prompt": "example data: very correct data\n###\n", "completion": "yes"}
--------------------------------------------------------------------------------