├── .github ├── renovate.json ├── pull_request_template.md ├── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md └── workflows │ └── test.yml ├── tests ├── setup.rs ├── v1_tool_test.rs ├── v1_client_list_models_test.rs ├── v1_client_list_models_async_test.rs ├── v1_client_embeddings_test.rs ├── v1_client_embeddings_async_test.rs ├── v1_constants_test.rs ├── v1_client_chat_stream_test.rs ├── v1_client_chat_test.rs ├── v1_client_chat_async_test.rs └── v1_client_new_test.rs ├── .env.example ├── src ├── lib.rs └── v1 │ ├── mod.rs │ ├── common.rs │ ├── error.rs │ ├── utils.rs │ ├── model_list.rs │ ├── constants.rs │ ├── chat_stream.rs │ ├── embedding.rs │ ├── tool.rs │ ├── chat.rs │ └── client.rs ├── SECURITY.md ├── release.toml ├── examples ├── list_models.rs ├── list_models_async.rs ├── embeddings.rs ├── embeddings_async.rs ├── chat.rs ├── chat_async.rs ├── chat_with_streaming.rs ├── chat_with_function_calling.rs └── chat_with_function_calling_async.rs ├── .editorconfig ├── .gitignore ├── Cargo.toml ├── Makefile ├── CONTRIBUTING.md ├── README.template.md ├── CODE_OF_CONDUCT.md ├── CHANGELOG.md ├── LICENSE.md └── README.md /.github/renovate.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": ["github>ivangabriele/renovate-config"] 3 | } 4 | -------------------------------------------------------------------------------- /tests/setup.rs: -------------------------------------------------------------------------------- 1 | pub fn setup() { 2 | let _ = env_logger::builder().is_test(true).try_init(); 3 | } 4 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # This key is only used for development purposes. 2 | # You'll only need one if you want to contribute to this library. 
3 | export MISTRAL_API_KEY= 4 | -------------------------------------------------------------------------------- /tests/v1_tool_test.rs: -------------------------------------------------------------------------------- 1 | use mistralai_client::v1::client::Client; 2 | 3 | trait _Trait: Send {} 4 | struct _Foo { 5 | _dummy: Client, 6 | } 7 | impl _Trait for _Foo {} 8 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! This crate provides a easy bindings and types for MistralAI's API. 2 | 3 | /// The v1 module contains the types and methods for the v1 API endpoints. 4 | pub mod v1; 5 | -------------------------------------------------------------------------------- /src/v1/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod chat; 2 | pub mod chat_stream; 3 | pub mod client; 4 | pub mod common; 5 | pub mod constants; 6 | pub mod embedding; 7 | pub mod error; 8 | pub mod model_list; 9 | pub mod tool; 10 | pub mod utils; 11 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | ## Supported Versions 4 | 5 | We only support the latest version of this project. 6 | 7 | ## Reporting a Vulnerability 8 | 9 | You can report a vulnerability by opening an issue on this repository. 
10 | -------------------------------------------------------------------------------- /src/v1/common.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | #[derive(Clone, Debug, Deserialize, Serialize)] 4 | pub struct ResponseUsage { 5 | pub prompt_tokens: u32, 6 | pub completion_tokens: u32, 7 | pub total_tokens: u32, 8 | } 9 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | A clear and concise description of what your pull request is about. 4 | 5 | ## Checklist 6 | 7 | - [ ] I updated the documentation accordingly. Or I don't need to. 8 | - [ ] I updated the tests accordingly. Or I don't need to. 9 | -------------------------------------------------------------------------------- /release.toml: -------------------------------------------------------------------------------- 1 | # https://github.com/crate-ci/cargo-release/blob/master/docs/reference.md 2 | allow-branch = ["main"] 3 | pre-release-commit-message = "ci(release): v{{version}}" 4 | pre-release-replacements = [{ file = "CHANGELOG.md", search = "## \\[\\]", replace = "## [{{version}}]" }] 5 | -------------------------------------------------------------------------------- /tests/v1_client_list_models_test.rs: -------------------------------------------------------------------------------- 1 | use jrest::expect; 2 | use mistralai_client::v1::client::Client; 3 | 4 | #[test] 5 | fn test_client_list_models() { 6 | let client = Client::new(None, None, None, None).unwrap(); 7 | 8 | let response = client.list_models().unwrap(); 9 | 10 | expect!(response.object).to_be("list".to_string()); 11 | expect!(response.data.len()).to_be_greater_than(0); 12 | } 13 | -------------------------------------------------------------------------------- /examples/list_models.rs: 
-------------------------------------------------------------------------------- 1 | use mistralai_client::v1::client::Client; 2 | 3 | fn main() { 4 | // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 5 | let client = Client::new(None, None, None, None).unwrap(); 6 | 7 | let result = client.list_models().unwrap(); 8 | println!("First Model ID: {:?}", result.data[0].id); 9 | // => "First Model ID: open-mistral-7b" 10 | } 11 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest a new idea for the project. 4 | title: "" 5 | labels: "enhancement" 6 | assignees: "" 7 | --- 8 | 9 | **Is your feature request related to some problems?** 10 | 11 | - _Ex. I'm always frustrated when..._ 12 | 13 | **What are the solutions you'd like?** 14 | 15 | - _Ex. A new option to..._ 16 | 17 | **Anything else?** 18 | 19 | - ... 
20 | -------------------------------------------------------------------------------- /tests/v1_client_list_models_async_test.rs: -------------------------------------------------------------------------------- 1 | use jrest::expect; 2 | use mistralai_client::v1::client::Client; 3 | 4 | #[tokio::test] 5 | async fn test_client_list_models_async() { 6 | let client = Client::new(None, None, None, None).unwrap(); 7 | 8 | let response = client.list_models_async().await.unwrap(); 9 | 10 | expect!(response.object).to_be("list".to_string()); 11 | expect!(response.data.len()).to_be_greater_than(0); 12 | } 13 | -------------------------------------------------------------------------------- /examples/list_models_async.rs: -------------------------------------------------------------------------------- 1 | use mistralai_client::v1::client::Client; 2 | 3 | #[tokio::main] 4 | async fn main() { 5 | // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 6 | let client = Client::new(None, None, None, None).unwrap(); 7 | 8 | let result = client.list_models_async().await.unwrap(); 9 | println!("First Model ID: {:?}", result.data[0].id); 10 | // => "First Model ID: open-mistral-7b" 11 | } 12 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # https://editorconfig.org 2 | root = true 3 | 4 | [*] 5 | charset = utf-8 6 | end_of_line = lf 7 | indent_size = 2 8 | indent_style = space 9 | insert_final_newline = true 10 | max_line_length = 120 11 | trim_trailing_whitespace = true 12 | 13 | [*.md] 14 | max_line_length = 0 15 | trim_trailing_whitespace = false 16 | 17 | [*.py] 18 | indent_size = 4 19 | 20 | [*.rs] 21 | indent_size = 4 22 | max_line_length = 80 23 | 24 | [*.xml] 25 | trim_trailing_whitespace = false 26 | 27 | [COMMIT_EDITMSG] 28 | max_line_length = 0 29 | 30 | [Makefile] 31 | indent_size = 8 32 | indent_style = tab 33 | 
max_line_length = 80 34 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "" 5 | labels: "bug" 6 | assignees: "" 7 | --- 8 | 9 | **Describe the bug** 10 | 11 | ... 12 | 13 | **Reproduction** 14 | 15 | Steps to reproduce the behavior: 16 | 17 | 1. ... 18 | 2. ... 19 | 20 | **Expected behavior** 21 | 22 | ... 23 | 24 | **Screenshots** 25 | 26 | If applicable, add screenshots to help explain your problem. 27 | 28 | **Version** 29 | 30 | If applicable, what version did you use? 31 | 32 | **Environment** 33 | 34 | If applicable, add relevant information about your config and environment here. 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ######################################## 2 | # Rust 3 | 4 | # Generated by Cargo 5 | # will have compiled files and executables 6 | debug/ 7 | target/ 8 | 9 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 10 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 11 | Cargo.lock 12 | 13 | # These are backup files generated by rustfmt 14 | **/*.rs.bk 15 | 16 | # MSVC Windows builds of rustc generate these, which store debugging information 17 | *.pdb 18 | 19 | ######################################## 20 | # Custom 21 | 22 | # Tarpaulin coverage output 23 | /cobertura.xml 24 | 25 | .env 26 | -------------------------------------------------------------------------------- /examples/embeddings.rs: -------------------------------------------------------------------------------- 1 | use mistralai_client::v1::{client::Client, constants::EmbedModel}; 2 | 3 | fn main() { 4 | // This example suppose 
you have set the `MISTRAL_API_KEY` environment variable. 5 | let client: Client = Client::new(None, None, None, None).unwrap(); 6 | 7 | let model = EmbedModel::MistralEmbed; 8 | let input = vec!["Embed this sentence.", "As well as this one."] 9 | .iter() 10 | .map(|s| s.to_string()) 11 | .collect(); 12 | let options = None; 13 | 14 | let response = client.embeddings(model, input, options).unwrap(); 15 | println!("First Embedding: {:?}", response.data[0]); 16 | // => "First Embedding: {...}" 17 | } 18 | -------------------------------------------------------------------------------- /src/v1/error.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | use std::fmt; 3 | 4 | #[derive(Debug)] 5 | pub struct ApiError { 6 | pub message: String, 7 | } 8 | impl fmt::Display for ApiError { 9 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 10 | write!(f, "ApiError: {}", self.message) 11 | } 12 | } 13 | impl Error for ApiError {} 14 | 15 | #[derive(Debug, PartialEq, thiserror::Error)] 16 | pub enum ClientError { 17 | #[error( 18 | "You must either set the `MISTRAL_API_KEY` environment variable or specify it in `Client::new(api_key, ...)." 19 | )] 20 | MissingApiKey, 21 | #[error("Failed to read the response text.")] 22 | UnreadableResponseText, 23 | } 24 | -------------------------------------------------------------------------------- /examples/embeddings_async.rs: -------------------------------------------------------------------------------- 1 | use mistralai_client::v1::{client::Client, constants::EmbedModel}; 2 | 3 | #[tokio::main] 4 | async fn main() { 5 | // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 
6 | let client: Client = Client::new(None, None, None, None).unwrap(); 7 | 8 | let model = EmbedModel::MistralEmbed; 9 | let input = vec!["Embed this sentence.", "As well as this one."] 10 | .iter() 11 | .map(|s| s.to_string()) 12 | .collect(); 13 | let options = None; 14 | 15 | let response = client 16 | .embeddings_async(model, input, options) 17 | .await 18 | .unwrap(); 19 | println!("First Embedding: {:?}", response.data[0]); 20 | // => "First Embedding: {...}" 21 | } 22 | -------------------------------------------------------------------------------- /examples/chat.rs: -------------------------------------------------------------------------------- 1 | use mistralai_client::v1::{ 2 | chat::{ChatMessage, ChatMessageRole, ChatParams}, 3 | client::Client, 4 | constants::Model, 5 | }; 6 | 7 | fn main() { 8 | // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 9 | let client = Client::new(None, None, None, None).unwrap(); 10 | 11 | let model = Model::OpenMistral7b; 12 | let messages = vec![ChatMessage { 13 | role: ChatMessageRole::User, 14 | content: "Just guess the next word: \"Eiffel ...\"?".to_string(), 15 | tool_calls: None, 16 | }]; 17 | let options = ChatParams { 18 | temperature: 0.0, 19 | random_seed: Some(42), 20 | ..Default::default() 21 | }; 22 | 23 | let result = client.chat(model, messages, Some(options)).unwrap(); 24 | println!("Assistant: {}", result.choices[0].message.content); 25 | // => "Assistant: Tower. The Eiffel Tower is a famous landmark in Paris, France." 26 | } 27 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "mistralai-client" 3 | description = "Mistral AI API client library for Rust (unofficial)." 
4 | license = "Apache-2.0" 5 | version = "0.14.0" 6 | 7 | edition = "2021" 8 | rust-version = "1.76.0" 9 | 10 | authors = ["Ivan Gabriele "] 11 | categories = ["api-bindings"] 12 | homepage = "https://github.com/ivangabriele/mistralai-client-rs#readme" 13 | keywords = ["mistral", "mistralai", "client", "api", "llm"] 14 | readme = "README.md" 15 | repository = "https://github.com/ivangabriele/mistralai-client-rs" 16 | 17 | [dependencies] 18 | async-stream = "0.3.5" 19 | async-trait = "0.1.77" 20 | env_logger = "0.11.3" 21 | futures = "0.3.30" 22 | log = "0.4.21" 23 | reqwest = { version = "0.12.0", features = ["json", "blocking", "stream"] } 24 | serde = { version = "1.0.197", features = ["derive"] } 25 | serde_json = "1.0.114" 26 | strum = "0.26.1" 27 | thiserror = "1.0.57" 28 | tokio = { version = "1.36.0", features = ["full"] } 29 | tokio-stream = "0.1.14" 30 | 31 | [dev-dependencies] 32 | jrest = "0.2.3" 33 | -------------------------------------------------------------------------------- /src/v1/utils.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Debug; 2 | 3 | use log::debug; 4 | use serde::Serialize; 5 | 6 | pub fn prettify_json_string(json: &String) -> String { 7 | match serde_json::from_str::(&json) { 8 | Ok(json_value) => { 9 | serde_json::to_string_pretty(&json_value).unwrap_or_else(|_| json.to_owned()) 10 | } 11 | Err(_) => json.to_owned(), 12 | } 13 | } 14 | 15 | pub fn prettify_json_struct(value: T) -> String { 16 | match serde_json::to_string_pretty(&value) { 17 | Ok(pretty_json) => pretty_json, 18 | Err(_) => format!("{:?}", value), 19 | } 20 | } 21 | 22 | pub fn debug_pretty_json_from_string(label: &str, json: &String) -> () { 23 | let pretty_json = prettify_json_string(json); 24 | 25 | debug!("{label}: {}", pretty_json); 26 | } 27 | 28 | pub fn debug_pretty_json_from_struct(label: &str, value: &T) -> () { 29 | let pretty_json = prettify_json_struct(value); 30 | 31 | debug!("{label}: {}", 
pretty_json); 32 | } 33 | -------------------------------------------------------------------------------- /examples/chat_async.rs: -------------------------------------------------------------------------------- 1 | use mistralai_client::v1::{ 2 | chat::{ChatMessage, ChatMessageRole, ChatParams}, 3 | client::Client, 4 | constants::Model, 5 | }; 6 | 7 | #[tokio::main] 8 | async fn main() { 9 | // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 10 | let client = Client::new(None, None, None, None).unwrap(); 11 | 12 | let model = Model::OpenMistral7b; 13 | let messages = vec![ChatMessage { 14 | role: ChatMessageRole::User, 15 | content: "Just guess the next word: \"Eiffel ...\"?".to_string(), 16 | tool_calls: None, 17 | }]; 18 | let options = ChatParams { 19 | temperature: 0.0, 20 | random_seed: Some(42), 21 | ..Default::default() 22 | }; 23 | 24 | let result = client 25 | .chat_async(model, messages, Some(options)) 26 | .await 27 | .unwrap(); 28 | println!( 29 | "{:?}: {}", 30 | result.choices[0].message.role, result.choices[0].message.content 31 | ); 32 | // => "Assistant: Tower. The Eiffel Tower is a famous landmark in Paris, France." 
33 | } 34 | -------------------------------------------------------------------------------- /tests/v1_client_embeddings_test.rs: -------------------------------------------------------------------------------- 1 | use jrest::expect; 2 | use mistralai_client::v1::{client::Client, constants::EmbedModel}; 3 | 4 | #[test] 5 | fn test_client_embeddings() { 6 | let client: Client = Client::new(None, None, None, None).unwrap(); 7 | 8 | let model = EmbedModel::MistralEmbed; 9 | let input = vec!["Embed this sentence.", "As well as this one."] 10 | .iter() 11 | .map(|s| s.to_string()) 12 | .collect(); 13 | let options = None; 14 | 15 | let response = client.embeddings(model, input, options).unwrap(); 16 | 17 | expect!(response.model).to_be(EmbedModel::MistralEmbed); 18 | expect!(response.object).to_be("list".to_string()); 19 | expect!(response.data.len()).to_be(2); 20 | expect!(response.data[0].index).to_be(0); 21 | expect!(response.data[0].object.clone()).to_be("embedding".to_string()); 22 | expect!(response.data[0].embedding.len()).to_be_greater_than(0); 23 | expect!(response.usage.prompt_tokens).to_be_greater_than(0); 24 | expect!(response.usage.completion_tokens).to_be(0); 25 | expect!(response.usage.total_tokens).to_be_greater_than(0); 26 | } 27 | -------------------------------------------------------------------------------- /src/v1/model_list.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | // ----------------------------------------------------------------------------- 4 | // Response 5 | 6 | #[derive(Clone, Debug, Deserialize, Serialize)] 7 | pub struct ModelListResponse { 8 | pub object: String, 9 | pub data: Vec, 10 | } 11 | 12 | /// See: https://docs.mistral.ai/api/#tag/models 13 | #[derive(Clone, Debug, Deserialize, Serialize)] 14 | pub struct ModelListData { 15 | pub id: String, 16 | pub object: String, 17 | /// Unix timestamp (in seconds). 
18 | pub created: u32, 19 | pub owned_by: String, 20 | pub root: Option, 21 | pub archived: bool, 22 | pub name: String, 23 | pub description: String, 24 | pub capabilities: ModelListDataCapabilies, 25 | pub max_context_length: u32, 26 | pub aliases: Vec, 27 | /// ISO 8601 date (`YYYY-MM-DDTHH:MM:SSZ`). 28 | pub deprecation: Option, 29 | } 30 | 31 | #[derive(Clone, Debug, Deserialize, Serialize)] 32 | pub struct ModelListDataCapabilies { 33 | pub completion_chat: bool, 34 | pub completion_fim: bool, 35 | pub function_calling: bool, 36 | pub fine_tuning: bool, 37 | } 38 | -------------------------------------------------------------------------------- /tests/v1_client_embeddings_async_test.rs: -------------------------------------------------------------------------------- 1 | use jrest::expect; 2 | use mistralai_client::v1::{client::Client, constants::EmbedModel}; 3 | 4 | #[tokio::test] 5 | async fn test_client_embeddings_async() { 6 | let client: Client = Client::new(None, None, None, None).unwrap(); 7 | 8 | let model = EmbedModel::MistralEmbed; 9 | let input = vec!["Embed this sentence.", "As well as this one."] 10 | .iter() 11 | .map(|s| s.to_string()) 12 | .collect(); 13 | let options = None; 14 | 15 | let response = client 16 | .embeddings_async(model, input, options) 17 | .await 18 | .unwrap(); 19 | 20 | expect!(response.model).to_be(EmbedModel::MistralEmbed); 21 | expect!(response.object).to_be("list".to_string()); 22 | expect!(response.data.len()).to_be(2); 23 | expect!(response.data[0].index).to_be(0); 24 | expect!(response.data[0].object.clone()).to_be("embedding".to_string()); 25 | expect!(response.data[0].embedding.len()).to_be_greater_than(0); 26 | expect!(response.usage.prompt_tokens).to_be_greater_than(0); 27 | expect!(response.usage.completion_tokens).to_be(0); 28 | expect!(response.usage.total_tokens).to_be_greater_than(0); 29 | } 30 | -------------------------------------------------------------------------------- /src/v1/constants.rs: 
-------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | pub const API_URL_BASE: &str = "https://api.mistral.ai/v1"; 4 | 5 | #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] 6 | pub enum Model { 7 | #[serde(rename = "open-mistral-7b")] 8 | OpenMistral7b, 9 | #[serde(rename = "open-mixtral-8x7b")] 10 | OpenMixtral8x7b, 11 | #[serde(rename = "open-mixtral-8x22b")] 12 | OpenMixtral8x22b, 13 | #[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo-2407")] 14 | OpenMistralNemo, 15 | #[serde(rename = "mistral-tiny")] 16 | MistralTiny, 17 | #[serde(rename = "mistral-small-latest", alias = "mistral-small-2402")] 18 | MistralSmallLatest, 19 | #[serde(rename = "mistral-medium-latest", alias = "mistral-medium-2312")] 20 | MistralMediumLatest, 21 | #[serde(rename = "mistral-large-latest", alias = "mistral-large-2407")] 22 | MistralLargeLatest, 23 | #[serde(rename = "mistral-large-2402")] 24 | MistralLarge, 25 | #[serde(rename = "codestral-latest", alias = "codestral-2405")] 26 | CodestralLatest, 27 | #[serde(rename = "open-codestral-mamba")] 28 | CodestralMamba, 29 | } 30 | 31 | #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] 32 | pub enum EmbedModel { 33 | #[serde(rename = "mistral-embed")] 34 | MistralEmbed, 35 | } 36 | -------------------------------------------------------------------------------- /tests/v1_constants_test.rs: -------------------------------------------------------------------------------- 1 | use jrest::expect; 2 | use mistralai_client::v1::{ 3 | chat::{ChatMessage, ChatParams}, 4 | client::Client, 5 | constants::Model, 6 | }; 7 | 8 | #[test] 9 | fn test_model_constant() { 10 | let models = vec![ 11 | Model::OpenMistral7b, 12 | Model::OpenMixtral8x7b, 13 | Model::OpenMixtral8x22b, 14 | Model::OpenMistralNemo, 15 | Model::MistralTiny, 16 | Model::MistralSmallLatest, 17 | Model::MistralMediumLatest, 18 | Model::MistralLargeLatest, 19 | 
Model::MistralLarge, 20 | Model::CodestralLatest, 21 | Model::CodestralMamba, 22 | ]; 23 | 24 | let client = Client::new(None, None, None, None).unwrap(); 25 | 26 | let messages = vec![ChatMessage::new_user_message("A number between 0 and 100?")]; 27 | let options = ChatParams { 28 | temperature: 0.0, 29 | random_seed: Some(42), 30 | ..Default::default() 31 | }; 32 | 33 | for model in models { 34 | let response = client 35 | .chat(model.clone(), messages.clone(), Some(options.clone())) 36 | .unwrap(); 37 | 38 | expect!(response.model).to_be(model); 39 | expect!(response.object).to_be("chat.completion".to_string()); 40 | expect!(response.choices.len()).to_be(1); 41 | expect!(response.choices[0].index).to_be(0); 42 | expect!(response.choices[0].message.content.len()).to_be_greater_than(0); 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /examples/chat_with_streaming.rs: -------------------------------------------------------------------------------- 1 | use futures::stream::StreamExt; 2 | use mistralai_client::v1::{ 3 | chat::{ChatMessage, ChatMessageRole, ChatParams}, 4 | client::Client, 5 | constants::Model, 6 | }; 7 | use std::io::{self, Write}; 8 | 9 | #[tokio::main] 10 | async fn main() { 11 | // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 
12 | let client = Client::new(None, None, None, None).unwrap(); 13 | 14 | let model = Model::OpenMistral7b; 15 | let messages = vec![ChatMessage { 16 | role: ChatMessageRole::User, 17 | content: "Tell me a short happy story.".to_string(), 18 | tool_calls: None, 19 | }]; 20 | let options = ChatParams { 21 | temperature: 0.0, 22 | random_seed: Some(42), 23 | ..Default::default() 24 | }; 25 | 26 | let stream_result = client 27 | .chat_stream(model, messages, Some(options)) 28 | .await 29 | .unwrap(); 30 | stream_result 31 | .for_each(|chunk_result| async { 32 | match chunk_result { 33 | Ok(chunks) => chunks.iter().for_each(|chunk| { 34 | print!("{}", chunk.choices[0].delta.content); 35 | io::stdout().flush().unwrap(); 36 | // => "Once upon a time, [...]" 37 | }), 38 | Err(error) => { 39 | eprintln!("Error processing chunk: {:?}", error) 40 | } 41 | } 42 | }) 43 | .await; 44 | print!("\n") // To persist the last chunk output. 45 | } 46 | -------------------------------------------------------------------------------- /tests/v1_client_chat_stream_test.rs: -------------------------------------------------------------------------------- 1 | // use futures::stream::StreamExt; 2 | // use jrest::expect; 3 | // use mistralai_client::v1::{ 4 | // chat_completion::{ChatParams, ChatMessage, ChatMessageRole}, 5 | // client::Client, 6 | // constants::Model, 7 | // }; 8 | 9 | // #[tokio::test] 10 | // async fn test_client_chat_stream() { 11 | // let client = Client::new(None, None, None, None).unwrap(); 12 | 13 | // let model = Model::OpenMistral7b; 14 | // let messages = vec![ChatMessage::new_user_message( 15 | // "Just guess the next word: \"Eiffel ...\"?", 16 | // )]; 17 | // let options = ChatParams { 18 | // temperature: Some(0.0), 19 | // random_seed: Some(42), 20 | // ..Default::default() 21 | // }; 22 | 23 | // let stream_result = client.chat_stream(model, messages, Some(options)).await; 24 | // let mut stream = stream_result.expect("Failed to create stream."); 25 | // 
while let Some(maybe_chunk_result) = stream.next().await { 26 | // match maybe_chunk_result { 27 | // Some(Ok(chunk)) => { 28 | // if chunk.choices[0].delta.role == Some(ChatMessageRole::Assistant) 29 | // || chunk.choices[0].finish_reason == Some("stop".to_string()) 30 | // { 31 | // expect!(chunk.choices[0].delta.content.len()).to_be(0); 32 | // } else { 33 | // expect!(chunk.choices[0].delta.content.len()).to_be_greater_than(0); 34 | // } 35 | // } 36 | // Some(Err(error)) => eprintln!("Error processing chunk: {:?}", error), 37 | // None => (), 38 | // } 39 | // } 40 | // } 41 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | pull_request: 5 | push: 6 | 7 | jobs: 8 | test: 9 | name: Test 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Checkout 13 | uses: actions/checkout@v4 14 | - name: Setup Rust 15 | uses: actions-rs/toolchain@v1 16 | with: 17 | toolchain: 1.76.0 18 | - name: Install cargo-llvm-cov 19 | uses: taiki-e/install-action@cargo-llvm-cov 20 | - name: Run tests (with coverage) 21 | run: cargo llvm-cov --lcov --output-path ./lcov.info 22 | env: 23 | MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }} 24 | - name: Upload tests coverage 25 | uses: codecov/codecov-action@v4 26 | with: 27 | fail_ci_if_error: true 28 | files: ./lcov.info 29 | token: ${{ secrets.CODECOV_TOKEN }} 30 | 31 | test_documentation: 32 | name: Test Documentation 33 | runs-on: ubuntu-latest 34 | steps: 35 | - name: Checkout 36 | uses: actions/checkout@v4 37 | - name: Setup Rust 38 | uses: actions-rs/toolchain@v1 39 | with: 40 | toolchain: 1.76.0 41 | - name: Run documentation tests 42 | run: make test-doc 43 | env: 44 | MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }} 45 | 46 | test_examples: 47 | name: Test Examples 48 | runs-on: ubuntu-latest 49 | steps: 50 | - name: Checkout 51 | uses: actions/checkout@v4 52 | 
- name: Setup Rust 53 | uses: actions-rs/toolchain@v1 54 | with: 55 | toolchain: 1.76.0 56 | - name: Run examples 57 | run: make test-examples 58 | env: 59 | MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }} 60 | -------------------------------------------------------------------------------- /src/v1/chat_stream.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | use serde_json::from_str; 3 | 4 | use crate::v1::{chat, common, constants, error}; 5 | 6 | // ----------------------------------------------------------------------------- 7 | // Response 8 | 9 | #[derive(Clone, Debug, Deserialize, Serialize)] 10 | pub struct ChatStreamChunk { 11 | pub id: String, 12 | pub object: String, 13 | /// Unix timestamp (in seconds). 14 | pub created: u32, 15 | pub model: constants::Model, 16 | pub choices: Vec, 17 | pub usage: Option, 18 | // TODO Check this prop (seen in API responses but undocumented). 19 | // pub logprobs: ???, 20 | } 21 | 22 | #[derive(Clone, Debug, Deserialize, Serialize)] 23 | pub struct ChatStreamChunkChoice { 24 | pub index: u32, 25 | pub delta: ChatStreamChunkChoiceDelta, 26 | pub finish_reason: Option, 27 | // TODO Check this prop (seen in API responses but undocumented). 28 | // pub logprobs: ???, 29 | } 30 | 31 | #[derive(Clone, Debug, Deserialize, Serialize)] 32 | pub struct ChatStreamChunkChoiceDelta { 33 | pub role: Option, 34 | pub content: String, 35 | } 36 | 37 | /// Extracts serialized chunks from a stream message. 
38 | pub fn get_chunk_from_stream_message_line( 39 | line: &str, 40 | ) -> Result>, error::ApiError> { 41 | if line.trim() == "data: [DONE]" { 42 | return Ok(None); 43 | } 44 | 45 | let chunk_as_json = line.trim_start_matches("data: ").trim(); 46 | if chunk_as_json.is_empty() { 47 | return Ok(Some(vec![])); 48 | } 49 | 50 | // Attempt to deserialize the JSON string into ChatStreamChunk 51 | match from_str::(chunk_as_json) { 52 | Ok(chunk) => Ok(Some(vec![chunk])), 53 | Err(e) => Err(error::ApiError { 54 | message: e.to_string(), 55 | }), 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /src/v1/embedding.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | use crate::v1::{common, constants}; 4 | 5 | // ----------------------------------------------------------------------------- 6 | // Request 7 | 8 | #[derive(Debug)] 9 | pub struct EmbeddingRequestOptions { 10 | pub encoding_format: Option, 11 | } 12 | impl Default for EmbeddingRequestOptions { 13 | fn default() -> Self { 14 | Self { 15 | encoding_format: None, 16 | } 17 | } 18 | } 19 | 20 | #[derive(Debug, Serialize, Deserialize)] 21 | pub struct EmbeddingRequest { 22 | pub model: constants::EmbedModel, 23 | pub input: Vec, 24 | #[serde(skip_serializing_if = "Option::is_none")] 25 | pub encoding_format: Option, 26 | } 27 | impl EmbeddingRequest { 28 | pub fn new( 29 | model: constants::EmbedModel, 30 | input: Vec, 31 | options: Option, 32 | ) -> Self { 33 | let EmbeddingRequestOptions { encoding_format } = options.unwrap_or_default(); 34 | 35 | Self { 36 | model, 37 | input, 38 | encoding_format, 39 | } 40 | } 41 | } 42 | 43 | #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] 44 | #[allow(non_camel_case_types)] 45 | pub enum EmbeddingRequestEncodingFormat { 46 | float, 47 | } 48 | 49 | // ----------------------------------------------------------------------------- 50 | 
// Response 51 | 52 | #[derive(Clone, Debug, Deserialize, Serialize)] 53 | pub struct EmbeddingResponse { 54 | pub id: String, 55 | pub object: String, 56 | pub model: constants::EmbedModel, 57 | pub data: Vec, 58 | pub usage: common::ResponseUsage, 59 | } 60 | 61 | #[derive(Clone, Debug, Deserialize, Serialize)] 62 | pub struct EmbeddingResponseDataItem { 63 | pub index: u32, 64 | pub embedding: Vec, 65 | pub object: String, 66 | } 67 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL := /bin/bash 2 | 3 | .PHONY: doc readme test 4 | 5 | define source_env_if_not_ci 6 | @if [ -z "$${CI}" ]; then \ 7 | if [ -f ./.env ]; then \ 8 | source ./.env; \ 9 | else \ 10 | echo "No .env file found"; \ 11 | exit 1; \ 12 | fi \ 13 | fi 14 | endef 15 | 16 | define RELEASE_TEMPLATE 17 | npx conventional-changelog-cli -p conventionalcommits -i ./CHANGELOG.md -s 18 | git add . 19 | git commit -m "docs(changelog): update" 20 | git push origin HEAD 21 | cargo release $(1) --execute 22 | git push origin HEAD --tags 23 | endef 24 | 25 | doc: 26 | cargo doc 27 | open ./target/doc/mistralai_client/index.html 28 | 29 | readme: 30 | @echo "Generating README.md from template..." 31 | @> README.md # Clear README.md content before starting 32 | @while IFS= read -r line || [[ -n "$$line" ]]; do \ 33 | if [[ $$line == *""* && $$line == *""* ]]; then \ 34 | example_path=$$(echo $$line | sed -n 's/.*\(.*\)<\/CODE>.*/\1/p'); \ 35 | if [ -f $$example_path ]; then \ 36 | echo '```rs' >> README.md; \ 37 | cat $$example_path >> README.md; \ 38 | echo '```' >> README.md; \ 39 | else \ 40 | echo "Error: Example $$example_path not found." >&2; \ 41 | fi; \ 42 | else \ 43 | echo "$$line" >> README.md; \ 44 | fi; \ 45 | done < README.template.md 46 | @echo "README.md has been generated." 
47 | 48 | release-patch: 49 | $(call RELEASE_TEMPLATE,patch) 50 | release-minor: 51 | $(call RELEASE_TEMPLATE,minor) 52 | release-major: 53 | $(call RELEASE_TEMPLATE,major) 54 | 55 | test: 56 | @$(source_env_if_not_ci) && \ 57 | cargo test --no-fail-fast 58 | test-cover: 59 | @$(source_env_if_not_ci) && \ 60 | cargo llvm-cov 61 | test-doc: 62 | @$(source_env_if_not_ci) && \ 63 | cargo test --doc --no-fail-fast 64 | test-examples: 65 | @$(source_env_if_not_ci) && \ 66 | for example in $$(ls examples/*.rs | sed 's/examples\/\(.*\)\.rs/\1/'); do \ 67 | echo "Running $$example"; \ 68 | cargo run --example $$example; \ 69 | done 70 | test-watch: 71 | @source ./.env && \ 72 | cargo watch -x "test -- --nocapture" 73 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribute 2 | 3 | - [Getting Started](#getting-started) 4 | - [Requirements](#requirements) 5 | - [First setup](#first-setup) 6 | - [Optional requirements](#optional-requirements) 7 | - [Local Development](#local-development) 8 | - [Test](#test) 9 | - [Documentation](#documentation) 10 | - [Readme](#readme) 11 | - [Code of Conduct](#code-of-conduct) 12 | - [Commit Message Format](#commit-message-format) 13 | 14 | --- 15 | 16 | ## Getting Started 17 | 18 | ### Requirements 19 | 20 | - [Rust](https://www.rust-lang.org/tools/install): v1 21 | 22 | ### First setup 23 | 24 | > [!IMPORTANT] 25 | > If you're under **Windows**, you nust run all CLI commands under a Linux shell-like terminal (i.e.: WSL or Git Bash). 26 | 27 | Then run: 28 | 29 | ```sh 30 | git clone https://github.com/ivangabriele/mistralai-client-rs.git # or your fork 31 | cd ./mistralai-client-rs 32 | cargo build 33 | cp .env.example .env 34 | ``` 35 | 36 | Then edit the `.env` file to set your `MISTRAL_API_KEY`. 
37 | 38 | > [!NOTE] 39 | > All tests use either the `open-mistral-7b` or `mistral-embed` models and only consume a few dozen tokens. 40 | > So you would have to run them thousands of times to even reach a single dollar of usage. 41 | 42 | ### Optional requirements 43 | 44 | - [cargo-llvm-cov](https://github.com/taiki-e/cargo-llvm-cov?tab=readme-ov-file#installation) for `make test-cover` 45 | - [cargo-watch](https://github.com/watchexec/cargo-watch#install) for `make test-watch`. 46 | 47 | ## Local Development 48 | 49 | ### Test 50 | 51 | ```sh 52 | make test 53 | ``` 54 | 55 | or 56 | 57 | ```sh 58 | make test-watch 59 | ``` 60 | 61 | ## Documentation 62 | 63 | ### Readme 64 | 65 | > [!IMPORTANT] 66 | > Do not edit the `README.md` file directly. It is generated from the `README.template.md` file. 67 | 68 | 1. Edit the `README.template.md` file. 69 | 2. Run `make readme` to generate/update the `README.md` file. 70 | 71 | ## Code of Conduct 72 | 73 | Help us keep this project open and inclusive. Please read and follow our [Code of Conduct](./CODE_OF_CONDUCT.md). 74 | 75 | ## Commit Message Format 76 | 77 | This repository follow the [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) specification. 
78 | -------------------------------------------------------------------------------- /examples/chat_with_function_calling.rs: -------------------------------------------------------------------------------- 1 | use mistralai_client::v1::{ 2 | chat::{ChatMessage, ChatMessageRole, ChatParams}, 3 | client::Client, 4 | constants::Model, 5 | tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType}, 6 | }; 7 | use serde::Deserialize; 8 | use std::any::Any; 9 | 10 | #[derive(Debug, Deserialize)] 11 | struct GetCityTemperatureArguments { 12 | city: String, 13 | } 14 | 15 | struct GetCityTemperatureFunction; 16 | #[async_trait::async_trait] 17 | impl Function for GetCityTemperatureFunction { 18 | async fn execute(&self, arguments: String) -> Box { 19 | // Deserialize arguments, perform the logic, and return the result 20 | let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap(); 21 | 22 | let temperature = match city.as_str() { 23 | "Paris" => "20°C", 24 | _ => "Unknown city", 25 | }; 26 | 27 | Box::new(temperature.to_string()) 28 | } 29 | } 30 | 31 | fn main() { 32 | let tools = vec![Tool::new( 33 | "get_city_temperature".to_string(), 34 | "Get the current temperature in a city.".to_string(), 35 | vec![ToolFunctionParameter::new( 36 | "city".to_string(), 37 | "The name of the city.".to_string(), 38 | ToolFunctionParameterType::String, 39 | )], 40 | )]; 41 | 42 | // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 
43 | let mut client = Client::new(None, None, None, None).unwrap(); 44 | client.register_function( 45 | "get_city_temperature".to_string(), 46 | Box::new(GetCityTemperatureFunction), 47 | ); 48 | 49 | let model = Model::MistralSmallLatest; 50 | let messages = vec![ChatMessage { 51 | role: ChatMessageRole::User, 52 | content: "What's the temperature in Paris?".to_string(), 53 | tool_calls: None, 54 | }]; 55 | let options = ChatParams { 56 | temperature: 0.0, 57 | random_seed: Some(42), 58 | tool_choice: Some(ToolChoice::Auto), 59 | tools: Some(tools), 60 | ..Default::default() 61 | }; 62 | 63 | client.chat(model, messages, Some(options)).unwrap(); 64 | let temperature = client 65 | .get_last_function_call_result() 66 | .unwrap() 67 | .downcast::() 68 | .unwrap(); 69 | println!("The temperature in Paris is: {}.", temperature); 70 | // => "The temperature in Paris is: 20°C." 71 | } 72 | -------------------------------------------------------------------------------- /examples/chat_with_function_calling_async.rs: -------------------------------------------------------------------------------- 1 | use mistralai_client::v1::{ 2 | chat::{ChatMessage, ChatMessageRole, ChatParams}, 3 | client::Client, 4 | constants::Model, 5 | tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType}, 6 | }; 7 | use serde::Deserialize; 8 | use std::any::Any; 9 | 10 | #[derive(Debug, Deserialize)] 11 | struct GetCityTemperatureArguments { 12 | city: String, 13 | } 14 | 15 | struct GetCityTemperatureFunction; 16 | #[async_trait::async_trait] 17 | impl Function for GetCityTemperatureFunction { 18 | async fn execute(&self, arguments: String) -> Box { 19 | // Deserialize arguments, perform the logic, and return the result 20 | let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap(); 21 | 22 | let temperature = match city.as_str() { 23 | "Paris" => "20°C", 24 | _ => "Unknown city", 25 | }; 26 | 27 | Box::new(temperature.to_string()) 28 | } 
29 | } 30 | 31 | #[tokio::main] 32 | async fn main() { 33 | let tools = vec![Tool::new( 34 | "get_city_temperature".to_string(), 35 | "Get the current temperature in a city.".to_string(), 36 | vec![ToolFunctionParameter::new( 37 | "city".to_string(), 38 | "The name of the city.".to_string(), 39 | ToolFunctionParameterType::String, 40 | )], 41 | )]; 42 | 43 | // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 44 | let mut client = Client::new(None, None, None, None).unwrap(); 45 | client.register_function( 46 | "get_city_temperature".to_string(), 47 | Box::new(GetCityTemperatureFunction), 48 | ); 49 | 50 | let model = Model::MistralSmallLatest; 51 | let messages = vec![ChatMessage { 52 | role: ChatMessageRole::User, 53 | content: "What's the temperature in Paris?".to_string(), 54 | tool_calls: None, 55 | }]; 56 | let options = ChatParams { 57 | temperature: 0.0, 58 | random_seed: Some(42), 59 | tool_choice: Some(ToolChoice::Auto), 60 | tools: Some(tools), 61 | ..Default::default() 62 | }; 63 | 64 | client 65 | .chat_async(model, messages, Some(options)) 66 | .await 67 | .unwrap(); 68 | let temperature = client 69 | .get_last_function_call_result() 70 | .unwrap() 71 | .downcast::() 72 | .unwrap(); 73 | println!("The temperature in Paris is: {}.", temperature); 74 | // => "The temperature in Paris is: 20°C." 
75 | } 76 | -------------------------------------------------------------------------------- /tests/v1_client_chat_test.rs: -------------------------------------------------------------------------------- 1 | use jrest::expect; 2 | use mistralai_client::v1::{ 3 | chat::{ChatMessage, ChatMessageRole, ChatParams, ChatResponseChoiceFinishReason}, 4 | client::Client, 5 | constants::Model, 6 | tool::{Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType}, 7 | }; 8 | 9 | mod setup; 10 | 11 | #[test] 12 | fn test_client_chat() { 13 | setup::setup(); 14 | 15 | let client = Client::new(None, None, None, None).unwrap(); 16 | 17 | let model = Model::OpenMistral7b; 18 | let messages = vec![ChatMessage::new_user_message( 19 | "Guess the next word: \"Eiffel ...\"?", 20 | )]; 21 | let options = ChatParams { 22 | temperature: 0.0, 23 | random_seed: Some(42), 24 | ..Default::default() 25 | }; 26 | 27 | let response = client.chat(model, messages, Some(options)).unwrap(); 28 | 29 | expect!(response.model).to_be(Model::OpenMistral7b); 30 | expect!(response.object).to_be("chat.completion".to_string()); 31 | expect!(response.choices.len()).to_be(1); 32 | expect!(response.choices[0].index).to_be(0); 33 | expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant); 34 | expect!(response.choices[0] 35 | .message 36 | .content 37 | .clone() 38 | .contains("Tower")) 39 | .to_be(true); 40 | expect!(response.choices[0].finish_reason.clone()).to_be(ChatResponseChoiceFinishReason::Stop); 41 | expect!(response.usage.prompt_tokens).to_be_greater_than(0); 42 | expect!(response.usage.completion_tokens).to_be_greater_than(0); 43 | expect!(response.usage.total_tokens).to_be_greater_than(0); 44 | } 45 | 46 | #[test] 47 | fn test_client_chat_with_function_calling() { 48 | setup::setup(); 49 | 50 | let tools = vec![Tool::new( 51 | "get_city_temperature".to_string(), 52 | "Get the current temperature in a city.".to_string(), 53 | vec![ToolFunctionParameter::new( 54 | 
"city".to_string(), 55 | "The name of the city.".to_string(), 56 | ToolFunctionParameterType::String, 57 | )], 58 | )]; 59 | 60 | let client = Client::new(None, None, None, None).unwrap(); 61 | 62 | let model = Model::MistralSmallLatest; 63 | let messages = vec![ChatMessage::new_user_message( 64 | "What's the current temperature in Paris?", 65 | )]; 66 | let options = ChatParams { 67 | temperature: 0.0, 68 | random_seed: Some(42), 69 | tool_choice: Some(ToolChoice::Auto), 70 | tools: Some(tools), 71 | ..Default::default() 72 | }; 73 | 74 | let response = client.chat(model, messages, Some(options)).unwrap(); 75 | 76 | expect!(response.model).to_be(Model::MistralSmallLatest); 77 | expect!(response.object).to_be("chat.completion".to_string()); 78 | expect!(response.choices.len()).to_be(1); 79 | expect!(response.choices[0].index).to_be(0); 80 | expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant); 81 | expect!(response.choices[0].message.content.clone()).to_be("".to_string()); 82 | expect!(response.choices[0].finish_reason.clone()) 83 | .to_be(ChatResponseChoiceFinishReason::ToolCalls); 84 | expect!(response.usage.prompt_tokens).to_be_greater_than(0); 85 | expect!(response.usage.completion_tokens).to_be_greater_than(0); 86 | expect!(response.usage.total_tokens).to_be_greater_than(0); 87 | } 88 | -------------------------------------------------------------------------------- /README.template.md: -------------------------------------------------------------------------------- 1 | # Mistral AI Rust Client 2 | 3 | [![Crates.io Package](https://img.shields.io/crates/v/mistralai-client?style=for-the-badge)](https://crates.io/crates/mistralai-client) 4 | [![Docs.rs Documentation](https://img.shields.io/docsrs/mistralai-client/latest?style=for-the-badge)](https://docs.rs/mistralai-client/latest/mistralai-client) 5 | [![Test Workflow 
Status](https://img.shields.io/github/actions/workflow/status/ivangabriele/mistralai-client-rs/test.yml?label=CI&style=for-the-badge)](https://github.com/ivangabriele/mistralai-client-rs/actions?query=branch%3Amain+workflow%3ATest++) 6 | [![Code Coverage](https://img.shields.io/codecov/c/github/ivangabriele/mistralai-client-rs/main?label=Cov&style=for-the-badge)](https://app.codecov.io/github/ivangabriele/mistralai-client-rs) 7 | 8 | Rust client for the Mistral AI API. 9 | 10 | > [!IMPORTANT] 11 | > While we are in v0, minor versions may introduce breaking changes. 12 | > Please, refer to the [CHANGELOG.md](./CHANGELOG.md) for more information. 13 | 14 | --- 15 | 16 | - [Supported APIs](#supported-apis) 17 | - [Installation](#installation) 18 | - [Mistral API Key](#mistral-api-key) 19 | - [As an environment variable](#as-an-environment-variable) 20 | - [As a client argument](#as-a-client-argument) 21 | - [Usage](#usage) 22 | - [Chat](#chat) 23 | - [Chat (async)](#chat-async) 24 | - [Chat with streaming (async)](#chat-with-streaming-async) 25 | - [Chat with Function Calling](#chat-with-function-calling) 26 | - [Chat with Function Calling (async)](#chat-with-function-calling-async) 27 | - [Embeddings](#embeddings) 28 | - [Embeddings (async)](#embeddings-async) 29 | - [List models](#list-models) 30 | - [List models (async)](#list-models-async) 31 | - [Contributing](#contributing) 32 | 33 | --- 34 | 35 | ## Supported APIs 36 | 37 | - [x] Chat without streaming 38 | - [x] Chat without streaming (async) 39 | - [x] Chat with streaming 40 | - [x] Embedding 41 | - [x] Embedding (async) 42 | - [x] List models 43 | - [x] List models (async) 44 | - [x] Function Calling 45 | - [x] Function Calling (async) 46 | 47 | ## Installation 48 | 49 | You can install the library in your project using: 50 | 51 | ```sh 52 | cargo add mistralai-client 53 | ``` 54 | 55 | ### Mistral API Key 56 | 57 | You can get your Mistral API Key there: . 
58 | 59 | #### As an environment variable 60 | 61 | Just set the `MISTRAL_API_KEY` environment variable. 62 | 63 | ```rs 64 | use mistralai_client::v1::client::Client; 65 | 66 | fn main() { 67 | let client = Client::new(None, None, None, None); 68 | } 69 | ``` 70 | 71 | ```sh 72 | MISTRAL_API_KEY=your_api_key cargo run 73 | ``` 74 | 75 | #### As a client argument 76 | 77 | ```rs 78 | use mistralai_client::v1::client::Client; 79 | 80 | fn main() { 81 | let api_key = "your_api_key"; 82 | 83 | let client = Client::new(Some(api_key), None, None, None).unwrap(); 84 | } 85 | ``` 86 | 87 | ## Usage 88 | 89 | ### Chat 90 | 91 | examples/chat.rs 92 | 93 | ### Chat (async) 94 | 95 | examples/chat_async.rs 96 | 97 | ### Chat with streaming (async) 98 | 99 | examples/chat_with_streaming.rs 100 | 101 | ### Chat with Function Calling 102 | 103 | examples/chat_with_function_calling.rs 104 | 105 | ### Chat with Function Calling (async) 106 | 107 | examples/chat_with_function_calling_async.rs 108 | 109 | ### Embeddings 110 | 111 | examples/embeddings.rs 112 | 113 | ### Embeddings (async) 114 | 115 | examples/embeddings_async.rs 116 | 117 | ### List models 118 | 119 | examples/list_models.rs 120 | 121 | ### List models (async) 122 | 123 | examples/list_models_async.rs 124 | 125 | ## Contributing 126 | 127 | Please read [CONTRIBUTING.md](./CONTRIBUTING.md) for details on how to contribute to this library. 
128 | -------------------------------------------------------------------------------- /tests/v1_client_chat_async_test.rs: -------------------------------------------------------------------------------- 1 | use jrest::expect; 2 | use mistralai_client::v1::{ 3 | chat::{ChatMessage, ChatMessageRole, ChatParams, ChatResponseChoiceFinishReason}, 4 | client::Client, 5 | constants::Model, 6 | tool::{Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType}, 7 | }; 8 | 9 | mod setup; 10 | 11 | #[tokio::test] 12 | async fn test_client_chat_async() { 13 | setup::setup(); 14 | 15 | let client = Client::new(None, None, None, None).unwrap(); 16 | 17 | let model = Model::OpenMistral7b; 18 | let messages = vec![ChatMessage::new_user_message( 19 | "Guess the next word: \"Eiffel ...\"?", 20 | )]; 21 | let options = ChatParams { 22 | temperature: 0.0, 23 | random_seed: Some(42), 24 | ..Default::default() 25 | }; 26 | 27 | let response = client 28 | .chat_async(model, messages, Some(options)) 29 | .await 30 | .unwrap(); 31 | 32 | expect!(response.model).to_be(Model::OpenMistral7b); 33 | expect!(response.object).to_be("chat.completion".to_string()); 34 | 35 | expect!(response.choices.len()).to_be(1); 36 | expect!(response.choices[0].index).to_be(0); 37 | expect!(response.choices[0].finish_reason.clone()).to_be(ChatResponseChoiceFinishReason::Stop); 38 | 39 | expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant); 40 | expect!(response.choices[0] 41 | .message 42 | .content 43 | .clone() 44 | .contains("Tower")) 45 | .to_be(true); 46 | 47 | expect!(response.usage.prompt_tokens).to_be_greater_than(0); 48 | expect!(response.usage.completion_tokens).to_be_greater_than(0); 49 | expect!(response.usage.total_tokens).to_be_greater_than(0); 50 | } 51 | 52 | #[tokio::test] 53 | async fn test_client_chat_async_with_function_calling() { 54 | setup::setup(); 55 | 56 | let tools = vec![Tool::new( 57 | "get_city_temperature".to_string(), 58 | "Get the 
current temperature in a city.".to_string(), 59 | vec![ToolFunctionParameter::new( 60 | "city".to_string(), 61 | "The name of the city.".to_string(), 62 | ToolFunctionParameterType::String, 63 | )], 64 | )]; 65 | 66 | let client = Client::new(None, None, None, None).unwrap(); 67 | 68 | let model = Model::MistralSmallLatest; 69 | let messages = vec![ChatMessage::new_user_message( 70 | "What's the current temperature in Paris?", 71 | )]; 72 | let options = ChatParams { 73 | temperature: 0.0, 74 | random_seed: Some(42), 75 | tool_choice: Some(ToolChoice::Any), 76 | tools: Some(tools), 77 | ..Default::default() 78 | }; 79 | 80 | let response = client 81 | .chat_async(model, messages, Some(options)) 82 | .await 83 | .unwrap(); 84 | 85 | expect!(response.model).to_be(Model::MistralSmallLatest); 86 | expect!(response.object).to_be("chat.completion".to_string()); 87 | 88 | expect!(response.choices.len()).to_be(1); 89 | expect!(response.choices[0].index).to_be(0); 90 | expect!(response.choices[0].finish_reason.clone()) 91 | .to_be(ChatResponseChoiceFinishReason::ToolCalls); 92 | 93 | expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant); 94 | expect!(response.choices[0].message.content.clone()).to_be("".to_string()); 95 | // expect!(response.choices[0].message.tool_calls.clone()).to_be(Some(vec![ToolCall { 96 | // function: ToolCallFunction { 97 | // name: "get_city_temperature".to_string(), 98 | // arguments: "{\"city\": \"Paris\"}".to_string(), 99 | // }, 100 | // }])); 101 | 102 | expect!(response.usage.prompt_tokens).to_be_greater_than(0); 103 | expect!(response.usage.completion_tokens).to_be_greater_than(0); 104 | expect!(response.usage.total_tokens).to_be_greater_than(0); 105 | } 106 | -------------------------------------------------------------------------------- /tests/v1_client_new_test.rs: -------------------------------------------------------------------------------- 1 | use jrest::expect; 2 | use 
mistralai_client::v1::{client::Client, error::ClientError}; 3 | 4 | #[derive(Debug)] 5 | struct _Foo { 6 | _client: Client, 7 | } 8 | 9 | #[test] 10 | fn test_client_new_with_none_params() { 11 | let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok(); 12 | std::env::remove_var("MISTRAL_API_KEY"); 13 | std::env::set_var("MISTRAL_API_KEY", "test_api_key_from_env"); 14 | 15 | let client = Client::new(None, None, None, None).unwrap(); 16 | 17 | expect!(client.api_key).to_be("test_api_key_from_env".to_string()); 18 | expect!(client.endpoint).to_be("https://api.mistral.ai/v1".to_string()); 19 | expect!(client.max_retries).to_be(5); 20 | expect!(client.timeout).to_be(120); 21 | 22 | match maybe_original_mistral_api_key { 23 | Some(original_mistral_api_key) => { 24 | std::env::set_var("MISTRAL_API_KEY", original_mistral_api_key) 25 | } 26 | None => std::env::remove_var("MISTRAL_API_KEY"), 27 | } 28 | } 29 | 30 | #[test] 31 | fn test_client_new_with_all_params() { 32 | let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok(); 33 | std::env::remove_var("MISTRAL_API_KEY"); 34 | 35 | let api_key = Some("test_api_key_from_param".to_string()); 36 | let endpoint = Some("https://example.org".to_string()); 37 | let max_retries = Some(10); 38 | let timeout = Some(20); 39 | 40 | let client = Client::new( 41 | api_key.clone(), 42 | endpoint.clone(), 43 | max_retries.clone(), 44 | timeout.clone(), 45 | ) 46 | .unwrap(); 47 | 48 | expect!(client.api_key).to_be(api_key.unwrap()); 49 | expect!(client.endpoint).to_be(endpoint.unwrap()); 50 | expect!(client.max_retries).to_be(max_retries.unwrap()); 51 | expect!(client.timeout).to_be(timeout.unwrap()); 52 | 53 | match maybe_original_mistral_api_key { 54 | Some(original_mistral_api_key) => { 55 | std::env::set_var("MISTRAL_API_KEY", original_mistral_api_key) 56 | } 57 | None => std::env::remove_var("MISTRAL_API_KEY"), 58 | } 59 | } 60 | 61 | #[test] 62 | fn 
test_client_new_with_api_key_as_both_env_and_param() { 63 | let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok(); 64 | std::env::remove_var("MISTRAL_API_KEY"); 65 | std::env::set_var("MISTRAL_API_KEY", "test_api_key_from_env"); 66 | 67 | let api_key = Some("test_api_key_from_param".to_string()); 68 | let endpoint = Some("https://example.org".to_string()); 69 | let max_retries = Some(10); 70 | let timeout = Some(20); 71 | 72 | let client = Client::new( 73 | api_key.clone(), 74 | endpoint.clone(), 75 | max_retries.clone(), 76 | timeout.clone(), 77 | ) 78 | .unwrap(); 79 | 80 | expect!(client.api_key).to_be(api_key.unwrap()); 81 | expect!(client.endpoint).to_be(endpoint.unwrap()); 82 | expect!(client.max_retries).to_be(max_retries.unwrap()); 83 | expect!(client.timeout).to_be(timeout.unwrap()); 84 | 85 | match maybe_original_mistral_api_key { 86 | Some(original_mistral_api_key) => { 87 | std::env::set_var("MISTRAL_API_KEY", original_mistral_api_key) 88 | } 89 | None => std::env::remove_var("MISTRAL_API_KEY"), 90 | } 91 | } 92 | 93 | #[test] 94 | fn test_client_new_with_missing_api_key() { 95 | let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok(); 96 | std::env::remove_var("MISTRAL_API_KEY"); 97 | 98 | let call = || Client::new(None, None, None, None); 99 | 100 | match call() { 101 | Ok(_) => panic!("Expected `ClientError::MissingApiKey` but got Ok.`"), 102 | Err(error) => assert_eq!(error, ClientError::MissingApiKey), 103 | } 104 | 105 | match maybe_original_mistral_api_key { 106 | Some(original_mistral_api_key) => { 107 | std::env::set_var("MISTRAL_API_KEY", original_mistral_api_key) 108 | } 109 | None => std::env::remove_var("MISTRAL_API_KEY"), 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /src/v1/tool.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | use serde::{Deserialize, Serialize}; 3 | use 
std::{any::Any, collections::HashMap, fmt::Debug}; 4 | 5 | // ----------------------------------------------------------------------------- 6 | // Definitions 7 | 8 | #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] 9 | pub struct ToolCall { 10 | pub function: ToolCallFunction, 11 | } 12 | 13 | #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] 14 | pub struct ToolCallFunction { 15 | pub name: String, 16 | pub arguments: String, 17 | } 18 | 19 | #[derive(Clone, Debug, Deserialize, Serialize)] 20 | pub struct Tool { 21 | pub r#type: ToolType, 22 | pub function: ToolFunction, 23 | } 24 | impl Tool { 25 | pub fn new( 26 | function_name: String, 27 | function_description: String, 28 | function_parameters: Vec, 29 | ) -> Self { 30 | let properties: HashMap = function_parameters 31 | .into_iter() 32 | .map(|param| { 33 | ( 34 | param.name, 35 | ToolFunctionParameterProperty { 36 | r#type: param.r#type, 37 | description: param.description, 38 | }, 39 | ) 40 | }) 41 | .collect(); 42 | let property_names = properties.keys().cloned().collect(); 43 | 44 | let parameters = ToolFunctionParameters { 45 | r#type: ToolFunctionParametersType::Object, 46 | properties, 47 | required: property_names, 48 | }; 49 | 50 | Self { 51 | r#type: ToolType::Function, 52 | function: ToolFunction { 53 | name: function_name, 54 | description: function_description, 55 | parameters, 56 | }, 57 | } 58 | } 59 | } 60 | 61 | // ----------------------------------------------------------------------------- 62 | // Request 63 | 64 | #[derive(Clone, Debug, Deserialize, Serialize)] 65 | pub struct ToolFunction { 66 | name: String, 67 | description: String, 68 | parameters: ToolFunctionParameters, 69 | } 70 | 71 | #[derive(Clone, Debug, Deserialize, Serialize)] 72 | pub struct ToolFunctionParameter { 73 | name: String, 74 | description: String, 75 | r#type: ToolFunctionParameterType, 76 | } 77 | impl ToolFunctionParameter { 78 | pub fn new(name: String, description: String, r#type: 
ToolFunctionParameterType) -> Self { 79 | Self { 80 | name, 81 | r#type, 82 | description, 83 | } 84 | } 85 | } 86 | 87 | #[derive(Clone, Debug, Deserialize, Serialize)] 88 | pub struct ToolFunctionParameters { 89 | r#type: ToolFunctionParametersType, 90 | properties: HashMap, 91 | required: Vec, 92 | } 93 | 94 | #[derive(Clone, Debug, Deserialize, Serialize)] 95 | pub struct ToolFunctionParameterProperty { 96 | r#type: ToolFunctionParameterType, 97 | description: String, 98 | } 99 | 100 | #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] 101 | pub enum ToolFunctionParametersType { 102 | #[serde(rename = "object")] 103 | Object, 104 | } 105 | 106 | #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] 107 | pub enum ToolFunctionParameterType { 108 | #[serde(rename = "string")] 109 | String, 110 | } 111 | 112 | #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] 113 | pub enum ToolType { 114 | #[serde(rename = "function")] 115 | Function, 116 | } 117 | 118 | /// An enum representing how functions should be called. 119 | #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] 120 | pub enum ToolChoice { 121 | /// The model is forced to call a function. 122 | #[serde(rename = "any")] 123 | Any, 124 | /// The model can choose to either generate a message or call a function. 125 | #[serde(rename = "auto")] 126 | Auto, 127 | /// The model won't call a function and will generate a message instead. 
128 | #[serde(rename = "none")] 129 | None, 130 | } 131 | 132 | // ----------------------------------------------------------------------------- 133 | // Custom 134 | 135 | #[async_trait] 136 | pub trait Function: Send { 137 | async fn execute(&self, arguments: String) -> Box; 138 | } 139 | 140 | impl Debug for dyn Function { 141 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 142 | write!(f, "Function()") 143 | } 144 | } 145 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 
14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 
55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | [INSERT CONTACT METHOD]. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 
99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | [https://www.contributor-covenant.org/version/2/0/code_of_conduct.html][v2.0]. 120 | 121 | Community Impact Guidelines were inspired by 122 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. 123 | 124 | For answers to common questions about this code of conduct, see the FAQ at 125 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available 126 | at [https://www.contributor-covenant.org/translations][translations]. 
127 | 128 | [homepage]: https://www.contributor-covenant.org 129 | [v2.0]: https://www.contributor-covenant.org/version/2/0/code_of_conduct.html 130 | [Mozilla CoC]: https://github.com/mozilla/diversity 131 | [FAQ]: https://www.contributor-covenant.org/faq 132 | [translations]: https://www.contributor-covenant.org/translations 133 | -------------------------------------------------------------------------------- /src/v1/chat.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | use crate::v1::{common, constants, tool}; 4 | 5 | // ----------------------------------------------------------------------------- 6 | // Definitions 7 | 8 | #[derive(Clone, Debug, Deserialize, Serialize)] 9 | pub struct ChatMessage { 10 | pub role: ChatMessageRole, 11 | pub content: String, 12 | #[serde(skip_serializing_if = "Option::is_none")] 13 | pub tool_calls: Option>, 14 | } 15 | impl ChatMessage { 16 | pub fn new_assistant_message(content: &str, tool_calls: Option>) -> Self { 17 | Self { 18 | role: ChatMessageRole::Assistant, 19 | content: content.to_string(), 20 | tool_calls, 21 | } 22 | } 23 | 24 | pub fn new_user_message(content: &str) -> Self { 25 | Self { 26 | role: ChatMessageRole::User, 27 | content: content.to_string(), 28 | tool_calls: None, 29 | } 30 | } 31 | } 32 | 33 | /// See the [Mistral AI API documentation](https://docs.mistral.ai/capabilities/completion/#chat-messages) for more information. 34 | #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] 35 | pub enum ChatMessageRole { 36 | #[serde(rename = "system")] 37 | System, 38 | #[serde(rename = "assistant")] 39 | Assistant, 40 | #[serde(rename = "user")] 41 | User, 42 | #[serde(rename = "tool")] 43 | Tool, 44 | } 45 | 46 | /// The format that the model must output. 47 | /// 48 | /// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information. 
49 | #[derive(Clone, Debug, Serialize, Deserialize)] 50 | pub struct ResponseFormat { 51 | #[serde(rename = "type")] 52 | pub type_: String, 53 | } 54 | impl ResponseFormat { 55 | pub fn json_object() -> Self { 56 | Self { 57 | type_: "json_object".to_string(), 58 | } 59 | } 60 | } 61 | 62 | // ----------------------------------------------------------------------------- 63 | // Request 64 | 65 | /// The parameters for the chat request. 66 | /// 67 | /// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information. 68 | #[derive(Clone, Debug)] 69 | pub struct ChatParams { 70 | /// The maximum number of tokens to generate in the completion. 71 | /// 72 | /// Defaults to `None`. 73 | pub max_tokens: Option, 74 | /// The seed to use for random sampling. If set, different calls will generate deterministic results. 75 | /// 76 | /// Defaults to `None`. 77 | pub random_seed: Option, 78 | /// The format that the model must output. 79 | /// 80 | /// Defaults to `None`. 81 | pub response_format: Option, 82 | /// Whether to inject a safety prompt before all conversations. 83 | /// 84 | /// Defaults to `false`. 85 | pub safe_prompt: bool, 86 | /// What sampling temperature to use, between `Some(0.0)` and `Some(1.0)`. 87 | /// 88 | /// Defaults to `0.7`. 89 | pub temperature: f32, 90 | /// Specifies if/how functions are called. 91 | /// 92 | /// Defaults to `None`. 93 | pub tool_choice: Option, 94 | /// A list of available tools for the model. 95 | /// 96 | /// Defaults to `None`. 97 | pub tools: Option>, 98 | /// Nucleus sampling, where the model considers the results of the tokens with `top_p` probability mass. 99 | /// 100 | /// Defaults to `1.0`. 
101 | pub top_p: f32, 102 | } 103 | impl Default for ChatParams { 104 | fn default() -> Self { 105 | Self { 106 | max_tokens: None, 107 | random_seed: None, 108 | safe_prompt: false, 109 | response_format: None, 110 | temperature: 0.7, 111 | tool_choice: None, 112 | tools: None, 113 | top_p: 1.0, 114 | } 115 | } 116 | } 117 | impl ChatParams { 118 | pub fn json_default() -> Self { 119 | Self { 120 | max_tokens: None, 121 | random_seed: None, 122 | safe_prompt: false, 123 | response_format: None, 124 | temperature: 0.7, 125 | tool_choice: None, 126 | tools: None, 127 | top_p: 1.0, 128 | } 129 | } 130 | } 131 | 132 | #[derive(Debug, Serialize, Deserialize)] 133 | pub struct ChatRequest { 134 | pub messages: Vec, 135 | pub model: constants::Model, 136 | 137 | #[serde(skip_serializing_if = "Option::is_none")] 138 | pub max_tokens: Option, 139 | #[serde(skip_serializing_if = "Option::is_none")] 140 | pub random_seed: Option, 141 | #[serde(skip_serializing_if = "Option::is_none")] 142 | pub response_format: Option, 143 | pub safe_prompt: bool, 144 | pub stream: bool, 145 | pub temperature: f32, 146 | #[serde(skip_serializing_if = "Option::is_none")] 147 | pub tool_choice: Option, 148 | #[serde(skip_serializing_if = "Option::is_none")] 149 | pub tools: Option>, 150 | pub top_p: f32, 151 | } 152 | impl ChatRequest { 153 | pub fn new( 154 | model: constants::Model, 155 | messages: Vec, 156 | stream: bool, 157 | options: Option, 158 | ) -> Self { 159 | let ChatParams { 160 | max_tokens, 161 | random_seed, 162 | safe_prompt, 163 | temperature, 164 | tool_choice, 165 | tools, 166 | top_p, 167 | response_format, 168 | } = options.unwrap_or_default(); 169 | 170 | Self { 171 | messages, 172 | model, 173 | 174 | max_tokens, 175 | random_seed, 176 | safe_prompt, 177 | stream, 178 | temperature, 179 | tool_choice, 180 | tools, 181 | top_p, 182 | response_format, 183 | } 184 | } 185 | } 186 | 187 | // ----------------------------------------------------------------------------- 188 | 
// Response 189 | 190 | #[derive(Clone, Debug, Deserialize, Serialize)] 191 | pub struct ChatResponse { 192 | pub id: String, 193 | pub object: String, 194 | /// Unix timestamp (in seconds). 195 | pub created: u32, 196 | pub model: constants::Model, 197 | pub choices: Vec, 198 | pub usage: common::ResponseUsage, 199 | } 200 | 201 | #[derive(Clone, Debug, Deserialize, Serialize)] 202 | pub struct ChatResponseChoice { 203 | pub index: u32, 204 | pub message: ChatMessage, 205 | pub finish_reason: ChatResponseChoiceFinishReason, 206 | // TODO Check this prop (seen in API responses but undocumented). 207 | // pub logprobs: ??? 208 | } 209 | 210 | #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] 211 | pub enum ChatResponseChoiceFinishReason { 212 | #[serde(rename = "stop")] 213 | Stop, 214 | #[serde(rename = "tool_calls")] 215 | ToolCalls, 216 | } 217 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## [0.14.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.13.0...v) (2024-08-27) 2 | 3 | ### Features 4 | 5 | * **constants:** update model constants ([#17](https://github.com/ivangabriele/mistralai-client-rs/issues/17)) ([161b33c](https://github.com/ivangabriele/mistralai-client-rs/commit/161b33c72539a6e982207349942a436df95399b7)) 6 | ## [0.13.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.12.0...v) (2024-08-21) 7 | 8 | ### ⚠ BREAKING CHANGES 9 | 10 | * **client:** `v1::model_list::ModelListData` struct has been updated. 11 | 12 | ### Bug Fixes 13 | 14 | * **client:** remove the `Content-Type` from the headers of the reqwest builders. 
([#14](https://github.com/ivangabriele/mistralai-client-rs/issues/14)) ([9bfbf2e](https://github.com/ivangabriele/mistralai-client-rs/commit/9bfbf2e9dc7b48103ac56923fb8b3ac9a5e2d9cf)), closes [#13](https://github.com/ivangabriele/mistralai-client-rs/issues/13) 15 | * **client:** update ModelListData struct following API changes ([2114916](https://github.com/ivangabriele/mistralai-client-rs/commit/2114916941e1ff5aa242290df5f092c0d4954afc)) 16 | ## [0.12.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.11.0...v) (2024-07-24) 17 | 18 | ### Features 19 | 20 | * implement the Debug trait for Client ([#11](https://github.com/ivangabriele/mistralai-client-rs/issues/11)) ([3afeec1](https://github.com/ivangabriele/mistralai-client-rs/commit/3afeec1d586022e43c7b10906acec5e65927ba7d)) 21 | * mark Function trait as Send ([#12](https://github.com/ivangabriele/mistralai-client-rs/issues/12)) ([8e9f7a5](https://github.com/ivangabriele/mistralai-client-rs/commit/8e9f7a53863879b2ad618e9e5707b198e4f3b135)) 22 | ## [0.11.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.10.0...v) (2024-06-22) 23 | 24 | ### Features 25 | 26 | * **constants:** add OpenMixtral8x22b, MistralTiny & CodestralLatest to Model enum ([ecd0c30](https://github.com/ivangabriele/mistralai-client-rs/commit/ecd0c3028fdcfab32b867eb1eed86182f5f4ab81)) 27 | 28 | ### Bug Fixes 29 | 30 | * **chat:** implement Clone trait for ChatParams & ResponseFormat ([0df67b1](https://github.com/ivangabriele/mistralai-client-rs/commit/0df67b1b2571fb04b636ce015a2daabe629ff352)) 31 | ## [0.10.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.9.0...v) (2024-06-07) 32 | 33 | ### ⚠ BREAKING CHANGES 34 | 35 | * **chat:** - `Chat::ChatParams.safe_prompt` & `Chat::ChatRequest.safe_prompt` are now `bool` instead of `Option`. Default is `false`. 36 | - `Chat::ChatParams.temperature` & `Chat::ChatRequest.temperature` are now `f32` instead of `Option`. Default is `0.7`. 
37 | - `Chat::ChatParams.top_p` & `Chat::ChatRequest.top_p` are now `f32` instead of `Option`. Default is `1.0`. 38 | 39 | ### Features 40 | 41 | * **chat:** add response_format for JSON return values ([85c3611](https://github.com/ivangabriele/mistralai-client-rs/commit/85c3611afbbe8df30dfc7512cc381ed304ce4024)) 42 | * **chat:** add the 'system' and 'tool' message roles ([#10](https://github.com/ivangabriele/mistralai-client-rs/issues/10)) ([2fc0642](https://github.com/ivangabriele/mistralai-client-rs/commit/2fc0642a5e4c024b15710acaab7735480e8dfe6a)) 43 | * **chat:** change safe_prompt, temperature & top_p to non-Option types ([cf68a77](https://github.com/ivangabriele/mistralai-client-rs/commit/cf68a773201ebe0e802face52af388711acf0c27)) 44 | 45 | ### Bug Fixes 46 | 47 | * **chat:** skip serializing tool_calls if null, to avoid 422 error ([da5fe54](https://github.com/ivangabriele/mistralai-client-rs/commit/da5fe54115ce622379776661a440e2708b24810c)) 48 | ## [0.9.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.8.0...v) (2024-04-13) 49 | 50 | 51 | ### ⚠ BREAKING CHANGES 52 | 53 | * `Model.OpenMistral8x7b` has been renamed to `Model.OpenMixtral8x7b`. 54 | 55 | ### Bug Fixes 56 | 57 | * **deps:** update rust crate reqwest to 0.12.0 ([#6](https://github.com/ivangabriele/mistralai-client-rs/issues/6)) ([fccd59c](https://github.com/ivangabriele/mistralai-client-rs/commit/fccd59c0cc783edddec1b404363faabb009eecd6)) 58 | * fix typo in OpenMixtral8x7b model name ([#8](https://github.com/ivangabriele/mistralai-client-rs/issues/8)) ([6a99eca](https://github.com/ivangabriele/mistralai-client-rs/commit/6a99eca49c0cc8e3764a56f6dfd7762ec44a4c3b)) 59 | 60 | ## [0.8.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.7.0...v) (2024-03-09) 61 | 62 | 63 | ### ⚠ BREAKING CHANGES 64 | 65 | * Too many to count in this version. Check the README examples. 
66 | 67 | ### Features 68 | 69 | * add function calling support to client.chat() & client.chat_async() ([74bf8a9](https://github.com/ivangabriele/mistralai-client-rs/commit/74bf8a96ee31f9d54ee3d7404619e803a182918b)) 70 | 71 | ## [0.7.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.6.0...v) (2024-03-05) 72 | 73 | 74 | ### ⚠ BREAKING CHANGES 75 | 76 | * - Rename `ClientError.ApiKeyError` to `MissingApiKey`. 77 | - Rename `ClientError.ReadResponseTextError` to `ClientError.UnreadableResponseText`. 78 | 79 | ### Bug Fixes 80 | 81 | * fix failure when api key as param and not env ([ef5d475](https://github.com/ivangabriele/mistralai-client-rs/commit/ef5d475e2d0e3fe040c44d6adabf7249e9962835)) 82 | 83 | ## [0.6.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.5.0...v) (2024-03-04) 84 | 85 | 86 | ### ⚠ BREAKING CHANGES 87 | 88 | * You can't set the `stream` option for `client.chat*()`. 89 | 90 | Either use `client.chat_stream()` if you want to use streams 91 | or use `client.chat()` / `client.chat_async()` otherwise. 92 | 93 | ### Features 94 | 95 | * add client.chat_stream() method ([4a4219d](https://github.com/ivangabriele/mistralai-client-rs/commit/4a4219d3eaa8f0ae953ee6182b36bf464d1c4a21)) 96 | 97 | ## [0.5.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.4.0...v) (2024-03-04) 98 | 99 | 100 | ### Features 101 | 102 | * add client.embeddings_async() method ([3c22891](https://github.com/ivangabriele/mistralai-client-rs/commit/3c228914f78b0edd4a592091265b88d0bc55568b)) 103 | * add client.list_models_async() method ([b69f7c6](https://github.com/ivangabriele/mistralai-client-rs/commit/b69f7c617c15dd63abb61d004636512916d766bb)) 104 | 105 | ## [0.4.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.3.0...v) (2024-03-04) 106 | 107 | 108 | ### ⚠ BREAKING CHANGES 109 | 110 | * `Client::new()` now returns a `Result`. 111 | * `APIError` is renamed to `ApiError`. 
112 | 113 | ### Features 114 | 115 | * add client.chat_async() method ([1dd59f6](https://github.com/ivangabriele/mistralai-client-rs/commit/1dd59f67048c10458ab0382af8fdfe4ed21c82fa)) 116 | * add missing api key error ([1deab88](https://github.com/ivangabriele/mistralai-client-rs/commit/1deab88251fc706e0415a5e416ab9aee4b52f6f3)) 117 | * wrap Client::new() return in a Result ([3387618](https://github.com/ivangabriele/mistralai-client-rs/commit/33876183e41340f426aa1dd1b6d8b5c05c8e15b9)) 118 | 119 | ## [0.3.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.2.0...v) (2024-03-04) 120 | 121 | 122 | ### ⚠ BREAKING CHANGES 123 | 124 | * Models are now enforced by `Model` & `EmbedModel` enums. 125 | 126 | ### Features 127 | 128 | * add client.embeddings() method ([f44d951](https://github.com/ivangabriele/mistralai-client-rs/commit/f44d95124767c3a3f14c78c4be3d9c203fac49ad)) 129 | 130 | ## [0.2.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.1.0...v) (2024-03-03) 131 | 132 | 133 | ### ⚠ BREAKING CHANGES 134 | 135 | * Chat completions must now be called directly from client.chat() without building a request in between. 
136 | 137 | ### Features 138 | 139 | * add client.list_models() method ([814b991](https://github.com/ivangabriele/mistralai-client-rs/commit/814b9918b3aca78bfd606b5b9bb470b70ea2a5c6)) 140 | * simplify chat completion call ([7de2b19](https://github.com/ivangabriele/mistralai-client-rs/commit/7de2b19b981f1d65fe5c566fcaf521e4f2a9ced1)) 141 | 142 | ## [0.1.0](https://github.com/ivangabriele/mistralai-client-rs/compare/7d3b438d16e9936591b6454525968c5c2cdfd6ad...v0.1.0) (2024-03-03) 143 | 144 | ### Features 145 | 146 | - add chat completion without streaming ([7d3b438](https://github.com/ivangabriele/mistralai-client-rs/commit/7d3b438d16e9936591b6454525968c5c2cdfd6ad)) 147 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mistral AI Rust Client 2 | 3 | [![Crates.io Package](https://img.shields.io/crates/v/mistralai-client?style=for-the-badge)](https://crates.io/crates/mistralai-client) 4 | [![Docs.rs Documentation](https://img.shields.io/docsrs/mistralai-client/latest?style=for-the-badge)](https://docs.rs/mistralai-client/latest/mistralai-client) 5 | [![Test Workflow Status](https://img.shields.io/github/actions/workflow/status/ivangabriele/mistralai-client-rs/test.yml?label=CI&style=for-the-badge)](https://github.com/ivangabriele/mistralai-client-rs/actions?query=branch%3Amain+workflow%3ATest++) 6 | [![Code Coverage](https://img.shields.io/codecov/c/github/ivangabriele/mistralai-client-rs/main?label=Cov&style=for-the-badge)](https://app.codecov.io/github/ivangabriele/mistralai-client-rs) 7 | 8 | Rust client for the Mistral AI API. 
9 | 10 | > [!IMPORTANT] 11 | > While we are in v0, minor versions may introduce breaking changes. 12 | > Please, refer to the [CHANGELOG.md](./CHANGELOG.md) for more information. 13 | 14 | --- 15 | 16 | - [Supported APIs](#supported-apis) 17 | - [Installation](#installation) 18 | - [Mistral API Key](#mistral-api-key) 19 | - [As an environment variable](#as-an-environment-variable) 20 | - [As a client argument](#as-a-client-argument) 21 | - [Usage](#usage) 22 | - [Chat](#chat) 23 | - [Chat (async)](#chat-async) 24 | - [Chat with streaming (async)](#chat-with-streaming-async) 25 | - [Chat with Function Calling](#chat-with-function-calling) 26 | - [Chat with Function Calling (async)](#chat-with-function-calling-async) 27 | - [Embeddings](#embeddings) 28 | - [Embeddings (async)](#embeddings-async) 29 | - [List models](#list-models) 30 | - [List models (async)](#list-models-async) 31 | - [Contributing](#contributing) 32 | 33 | --- 34 | 35 | ## Supported APIs 36 | 37 | - [x] Chat without streaming 38 | - [x] Chat without streaming (async) 39 | - [x] Chat with streaming 40 | - [x] Embedding 41 | - [x] Embedding (async) 42 | - [x] List models 43 | - [x] List models (async) 44 | - [x] Function Calling 45 | - [x] Function Calling (async) 46 | 47 | ## Installation 48 | 49 | You can install the library in your project using: 50 | 51 | ```sh 52 | cargo add mistralai-client 53 | ``` 54 | 55 | ### Mistral API Key 56 | 57 | You can get your Mistral API Key there: . 58 | 59 | #### As an environment variable 60 | 61 | Just set the `MISTRAL_API_KEY` environment variable. 
62 | 63 | ```rs 64 | use mistralai_client::v1::client::Client; 65 | 66 | fn main() { 67 | let client = Client::new(None, None, None, None); 68 | } 69 | ``` 70 | 71 | ```sh 72 | MISTRAL_API_KEY=your_api_key cargo run 73 | ``` 74 | 75 | #### As a client argument 76 | 77 | ```rs 78 | use mistralai_client::v1::client::Client; 79 | 80 | fn main() { 81 | let api_key = "your_api_key"; 82 | 83 | let client = Client::new(Some(api_key), None, None, None).unwrap(); 84 | } 85 | ``` 86 | 87 | ## Usage 88 | 89 | ### Chat 90 | 91 | ```rs 92 | use mistralai_client::v1::{ 93 | chat::{ChatMessage, ChatMessageRole, ChatParams}, 94 | client::Client, 95 | constants::Model, 96 | }; 97 | 98 | fn main() { 99 | // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 100 | let client = Client::new(None, None, None, None).unwrap(); 101 | 102 | let model = Model::OpenMistral7b; 103 | let messages = vec![ChatMessage { 104 | role: ChatMessageRole::User, 105 | content: "Just guess the next word: \"Eiffel ...\"?".to_string(), 106 | tool_calls: None, 107 | }]; 108 | let options = ChatParams { 109 | temperature: 0.0, 110 | random_seed: Some(42), 111 | ..Default::default() 112 | }; 113 | 114 | let result = client.chat(model, messages, Some(options)).unwrap(); 115 | println!("Assistant: {}", result.choices[0].message.content); 116 | // => "Assistant: Tower. The Eiffel Tower is a famous landmark in Paris, France." 117 | } 118 | ``` 119 | 120 | ### Chat (async) 121 | 122 | ```rs 123 | use mistralai_client::v1::{ 124 | chat::{ChatMessage, ChatMessageRole, ChatParams}, 125 | client::Client, 126 | constants::Model, 127 | }; 128 | 129 | #[tokio::main] 130 | async fn main() { 131 | // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 
132 | let client = Client::new(None, None, None, None).unwrap(); 133 | 134 | let model = Model::OpenMistral7b; 135 | let messages = vec![ChatMessage { 136 | role: ChatMessageRole::User, 137 | content: "Just guess the next word: \"Eiffel ...\"?".to_string(), 138 | tool_calls: None, 139 | }]; 140 | let options = ChatParams { 141 | temperature: 0.0, 142 | random_seed: Some(42), 143 | ..Default::default() 144 | }; 145 | 146 | let result = client 147 | .chat_async(model, messages, Some(options)) 148 | .await 149 | .unwrap(); 150 | println!( 151 | "{:?}: {}", 152 | result.choices[0].message.role, result.choices[0].message.content 153 | ); 154 | // => "Assistant: Tower. The Eiffel Tower is a famous landmark in Paris, France." 155 | } 156 | ``` 157 | 158 | ### Chat with streaming (async) 159 | 160 | ```rs 161 | use futures::stream::StreamExt; 162 | use mistralai_client::v1::{ 163 | chat::{ChatMessage, ChatMessageRole, ChatParams}, 164 | client::Client, 165 | constants::Model, 166 | }; 167 | use std::io::{self, Write}; 168 | 169 | #[tokio::main] 170 | async fn main() { 171 | // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 
172 | let client = Client::new(None, None, None, None).unwrap(); 173 | 174 | let model = Model::OpenMistral7b; 175 | let messages = vec![ChatMessage { 176 | role: ChatMessageRole::User, 177 | content: "Tell me a short happy story.".to_string(), 178 | tool_calls: None, 179 | }]; 180 | let options = ChatParams { 181 | temperature: 0.0, 182 | random_seed: Some(42), 183 | ..Default::default() 184 | }; 185 | 186 | let stream_result = client 187 | .chat_stream(model, messages, Some(options)) 188 | .await 189 | .unwrap(); 190 | stream_result 191 | .for_each(|chunk_result| async { 192 | match chunk_result { 193 | Ok(chunks) => chunks.iter().for_each(|chunk| { 194 | print!("{}", chunk.choices[0].delta.content); 195 | io::stdout().flush().unwrap(); 196 | // => "Once upon a time, [...]" 197 | }), 198 | Err(error) => { 199 | eprintln!("Error processing chunk: {:?}", error) 200 | } 201 | } 202 | }) 203 | .await; 204 | print!("\n") // To persist the last chunk output. 205 | } 206 | ``` 207 | 208 | ### Chat with Function Calling 209 | 210 | ```rs 211 | use mistralai_client::v1::{ 212 | chat::{ChatMessage, ChatMessageRole, ChatParams}, 213 | client::Client, 214 | constants::Model, 215 | tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType}, 216 | }; 217 | use serde::Deserialize; 218 | use std::any::Any; 219 | 220 | #[derive(Debug, Deserialize)] 221 | struct GetCityTemperatureArguments { 222 | city: String, 223 | } 224 | 225 | struct GetCityTemperatureFunction; 226 | #[async_trait::async_trait] 227 | impl Function for GetCityTemperatureFunction { 228 | async fn execute(&self, arguments: String) -> Box { 229 | // Deserialize arguments, perform the logic, and return the result 230 | let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap(); 231 | 232 | let temperature = match city.as_str() { 233 | "Paris" => "20°C", 234 | _ => "Unknown city", 235 | }; 236 | 237 | Box::new(temperature.to_string()) 238 | } 239 | } 240 | 241 | fn 
main() { 242 | let tools = vec![Tool::new( 243 | "get_city_temperature".to_string(), 244 | "Get the current temperature in a city.".to_string(), 245 | vec![ToolFunctionParameter::new( 246 | "city".to_string(), 247 | "The name of the city.".to_string(), 248 | ToolFunctionParameterType::String, 249 | )], 250 | )]; 251 | 252 | // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 253 | let mut client = Client::new(None, None, None, None).unwrap(); 254 | client.register_function( 255 | "get_city_temperature".to_string(), 256 | Box::new(GetCityTemperatureFunction), 257 | ); 258 | 259 | let model = Model::MistralSmallLatest; 260 | let messages = vec![ChatMessage { 261 | role: ChatMessageRole::User, 262 | content: "What's the temperature in Paris?".to_string(), 263 | tool_calls: None, 264 | }]; 265 | let options = ChatParams { 266 | temperature: 0.0, 267 | random_seed: Some(42), 268 | tool_choice: Some(ToolChoice::Auto), 269 | tools: Some(tools), 270 | ..Default::default() 271 | }; 272 | 273 | client.chat(model, messages, Some(options)).unwrap(); 274 | let temperature = client 275 | .get_last_function_call_result() 276 | .unwrap() 277 | .downcast::() 278 | .unwrap(); 279 | println!("The temperature in Paris is: {}.", temperature); 280 | // => "The temperature in Paris is: 20°C." 
281 | } 282 | ``` 283 | 284 | ### Chat with Function Calling (async) 285 | 286 | ```rs 287 | use mistralai_client::v1::{ 288 | chat::{ChatMessage, ChatMessageRole, ChatParams}, 289 | client::Client, 290 | constants::Model, 291 | tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType}, 292 | }; 293 | use serde::Deserialize; 294 | use std::any::Any; 295 | 296 | #[derive(Debug, Deserialize)] 297 | struct GetCityTemperatureArguments { 298 | city: String, 299 | } 300 | 301 | struct GetCityTemperatureFunction; 302 | #[async_trait::async_trait] 303 | impl Function for GetCityTemperatureFunction { 304 | async fn execute(&self, arguments: String) -> Box { 305 | // Deserialize arguments, perform the logic, and return the result 306 | let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap(); 307 | 308 | let temperature = match city.as_str() { 309 | "Paris" => "20°C", 310 | _ => "Unknown city", 311 | }; 312 | 313 | Box::new(temperature.to_string()) 314 | } 315 | } 316 | 317 | #[tokio::main] 318 | async fn main() { 319 | let tools = vec![Tool::new( 320 | "get_city_temperature".to_string(), 321 | "Get the current temperature in a city.".to_string(), 322 | vec![ToolFunctionParameter::new( 323 | "city".to_string(), 324 | "The name of the city.".to_string(), 325 | ToolFunctionParameterType::String, 326 | )], 327 | )]; 328 | 329 | // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 
330 | let mut client = Client::new(None, None, None, None).unwrap(); 331 | client.register_function( 332 | "get_city_temperature".to_string(), 333 | Box::new(GetCityTemperatureFunction), 334 | ); 335 | 336 | let model = Model::MistralSmallLatest; 337 | let messages = vec![ChatMessage { 338 | role: ChatMessageRole::User, 339 | content: "What's the temperature in Paris?".to_string(), 340 | tool_calls: None, 341 | }]; 342 | let options = ChatParams { 343 | temperature: 0.0, 344 | random_seed: Some(42), 345 | tool_choice: Some(ToolChoice::Auto), 346 | tools: Some(tools), 347 | ..Default::default() 348 | }; 349 | 350 | client 351 | .chat_async(model, messages, Some(options)) 352 | .await 353 | .unwrap(); 354 | let temperature = client 355 | .get_last_function_call_result() 356 | .unwrap() 357 | .downcast::() 358 | .unwrap(); 359 | println!("The temperature in Paris is: {}.", temperature); 360 | // => "The temperature in Paris is: 20°C." 361 | } 362 | ``` 363 | 364 | ### Embeddings 365 | 366 | ```rs 367 | use mistralai_client::v1::{client::Client, constants::EmbedModel}; 368 | 369 | fn main() { 370 | // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 371 | let client: Client = Client::new(None, None, None, None).unwrap(); 372 | 373 | let model = EmbedModel::MistralEmbed; 374 | let input = vec!["Embed this sentence.", "As well as this one."] 375 | .iter() 376 | .map(|s| s.to_string()) 377 | .collect(); 378 | let options = None; 379 | 380 | let response = client.embeddings(model, input, options).unwrap(); 381 | println!("First Embedding: {:?}", response.data[0]); 382 | // => "First Embedding: {...}" 383 | } 384 | ``` 385 | 386 | ### Embeddings (async) 387 | 388 | ```rs 389 | use mistralai_client::v1::{client::Client, constants::EmbedModel}; 390 | 391 | #[tokio::main] 392 | async fn main() { 393 | // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 
394 | let client: Client = Client::new(None, None, None, None).unwrap(); 395 | 396 | let model = EmbedModel::MistralEmbed; 397 | let input = vec!["Embed this sentence.", "As well as this one."] 398 | .iter() 399 | .map(|s| s.to_string()) 400 | .collect(); 401 | let options = None; 402 | 403 | let response = client 404 | .embeddings_async(model, input, options) 405 | .await 406 | .unwrap(); 407 | println!("First Embedding: {:?}", response.data[0]); 408 | // => "First Embedding: {...}" 409 | } 410 | ``` 411 | 412 | ### List models 413 | 414 | ```rs 415 | use mistralai_client::v1::client::Client; 416 | 417 | fn main() { 418 | // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 419 | let client = Client::new(None, None, None, None).unwrap(); 420 | 421 | let result = client.list_models().unwrap(); 422 | println!("First Model ID: {:?}", result.data[0].id); 423 | // => "First Model ID: open-mistral-7b" 424 | } 425 | ``` 426 | 427 | ### List models (async) 428 | 429 | ```rs 430 | use mistralai_client::v1::client::Client; 431 | 432 | #[tokio::main] 433 | async fn main() { 434 | // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 435 | let client = Client::new(None, None, None, None).unwrap(); 436 | 437 | let result = client.list_models_async().await.unwrap(); 438 | println!("First Model ID: {:?}", result.data[0].id); 439 | // => "First Model ID: open-mistral-7b" 440 | } 441 | ``` 442 | 443 | ## Contributing 444 | 445 | Please read [CONTRIBUTING.md](./CONTRIBUTING.md) for details on how to contribute to this library. 
-------------------------------------------------------------------------------- /src/v1/client.rs: --------------------------------------------------------------------------------
use futures::stream::StreamExt;
use futures::Stream;
use log::debug;
use reqwest::Error as ReqwestError;
use std::{
    any::Any,
    collections::HashMap,
    sync::{Arc, Mutex},
};

use crate::v1::{chat, chat_stream, constants, embedding, error, model_list, tool, utils};

/// HTTP client for the Mistral v1 REST API.
///
/// Holds the credentials and connection settings, plus the registry of
/// callable functions used by the Function Calling feature.
#[derive(Debug)]
pub struct Client {
    /// Bearer token sent with every request.
    pub api_key: String,
    /// Base URL of the API (defaults to the official endpoint).
    pub endpoint: String,
    /// Maximum number of retries for failed requests.
    /// NOTE(review): stored but not consulted by any request helper in this
    /// file — confirm whether retrying is implemented elsewhere.
    pub max_retries: u32,
    /// Request timeout in seconds.
    /// NOTE(review): stored but not applied to the `reqwest` clients built in
    /// this file — confirm intended behavior.
    pub timeout: u32,

    // Function-calling registry: tool name -> implementation. Behind a mutex
    // so the `&self` chat methods can consult it.
    functions: Arc<Mutex<HashMap<String, Box<dyn tool::Function>>>>,
    // Result of the most recent tool call; retrieved (and cleared) via
    // `get_last_function_call_result`.
    last_function_call_result: Arc<Mutex<Option<Box<dyn Any>>>>,
}

impl Client {
    /// Constructs a new `Client`.
    ///
    /// # Arguments
    ///
    /// * `api_key` - An optional API key.
    ///   If not provided, the method will try to use the `MISTRAL_API_KEY` environment variable.
    /// * `endpoint` - An optional custom API endpoint. Defaults to the official API endpoint if not provided.
    /// * `max_retries` - Optional maximum number of retries for failed requests. Defaults to `5`.
    /// * `timeout` - Optional timeout in seconds for requests. Defaults to `120`.
    ///
    /// # Examples
    ///
    /// ```
    /// use mistralai_client::v1::client::Client;
    ///
    /// let client = Client::new(Some("your_api_key_here".to_string()), None, Some(3), Some(60));
    /// assert!(client.is_ok());
    /// ```
    ///
    /// # Errors
    ///
    /// This method fails whenever neither the `api_key` is provided
    /// nor the `MISTRAL_API_KEY` environment variable is set.
48 | pub fn new( 49 | api_key: Option, 50 | endpoint: Option, 51 | max_retries: Option, 52 | timeout: Option, 53 | ) -> Result { 54 | let api_key = match api_key { 55 | Some(api_key_from_param) => api_key_from_param, 56 | None => { 57 | std::env::var("MISTRAL_API_KEY").map_err(|_| error::ClientError::MissingApiKey)? 58 | } 59 | }; 60 | let endpoint = endpoint.unwrap_or(constants::API_URL_BASE.to_string()); 61 | let max_retries = max_retries.unwrap_or(5); 62 | let timeout = timeout.unwrap_or(120); 63 | 64 | let functions: Arc<_> = Arc::new(Mutex::new(HashMap::new())); 65 | let last_function_call_result = Arc::new(Mutex::new(None)); 66 | 67 | Ok(Self { 68 | api_key, 69 | endpoint, 70 | max_retries, 71 | timeout, 72 | 73 | functions, 74 | last_function_call_result, 75 | }) 76 | } 77 | 78 | /// Synchronously sends a chat completion request and returns the response. 79 | /// 80 | /// # Arguments 81 | /// 82 | /// * `model` - The [Model] to use for the chat completion. 83 | /// * `messages` - A vector of [ChatMessage] to send as part of the chat. 84 | /// * `options` - Optional [ChatParams] to customize the request. 85 | /// 86 | /// # Returns 87 | /// 88 | /// Returns a [Result] containing the `ChatResponse` if the request is successful, 89 | /// or an [ApiError] if there is an error. 
90 | /// 91 | /// # Examples 92 | /// 93 | /// ``` 94 | /// use mistralai_client::v1::{ 95 | /// chat::{ChatMessage, ChatMessageRole}, 96 | /// client::Client, 97 | /// constants::Model, 98 | /// }; 99 | /// 100 | /// let client = Client::new(None, None, None, None).unwrap(); 101 | /// let messages = vec![ChatMessage { 102 | /// role: ChatMessageRole::User, 103 | /// content: "Hello, world!".to_string(), 104 | /// tool_calls: None, 105 | /// }]; 106 | /// let response = client.chat(Model::OpenMistral7b, messages, None).unwrap(); 107 | /// println!("{:?}: {}", response.choices[0].message.role, response.choices[0].message.content); 108 | /// ``` 109 | pub fn chat( 110 | &self, 111 | model: constants::Model, 112 | messages: Vec, 113 | options: Option, 114 | ) -> Result { 115 | let request = chat::ChatRequest::new(model, messages, false, options); 116 | 117 | let response = self.post_sync("/chat/completions", &request)?; 118 | let result = response.json::(); 119 | match result { 120 | Ok(data) => { 121 | utils::debug_pretty_json_from_struct("Response Data", &data); 122 | 123 | self.call_function_if_any(data.clone()); 124 | 125 | Ok(data) 126 | } 127 | Err(error) => Err(self.to_api_error(error)), 128 | } 129 | } 130 | 131 | /// Asynchronously sends a chat completion request and returns the response. 132 | /// 133 | /// # Arguments 134 | /// 135 | /// * `model` - The [Model] to use for the chat completion. 136 | /// * `messages` - A vector of [ChatMessage] to send as part of the chat. 137 | /// * `options` - Optional [ChatParams] to customize the request. 138 | /// 139 | /// # Returns 140 | /// 141 | /// Returns a [Result] containing a `Stream` of `ChatStreamChunk` if the request is successful, 142 | /// or an [ApiError] if there is an error. 
143 | /// 144 | /// # Examples 145 | /// 146 | /// ``` 147 | /// use mistralai_client::v1::{ 148 | /// chat::{ChatMessage, ChatMessageRole}, 149 | /// client::Client, 150 | /// constants::Model, 151 | /// }; 152 | /// 153 | /// #[tokio::main] 154 | /// async fn main() { 155 | /// let client = Client::new(None, None, None, None).unwrap(); 156 | /// let messages = vec![ChatMessage { 157 | /// role: ChatMessageRole::User, 158 | /// content: "Hello, world!".to_string(), 159 | /// tool_calls: None, 160 | /// }]; 161 | /// let response = client.chat_async(Model::OpenMistral7b, messages, None).await.unwrap(); 162 | /// println!("{:?}: {}", response.choices[0].message.role, response.choices[0].message.content); 163 | /// } 164 | /// ``` 165 | pub async fn chat_async( 166 | &self, 167 | model: constants::Model, 168 | messages: Vec, 169 | options: Option, 170 | ) -> Result { 171 | let request = chat::ChatRequest::new(model, messages, false, options); 172 | 173 | let response = self.post_async("/chat/completions", &request).await?; 174 | let result = response.json::().await; 175 | match result { 176 | Ok(data) => { 177 | utils::debug_pretty_json_from_struct("Response Data", &data); 178 | 179 | self.call_function_if_any_async(data.clone()).await; 180 | 181 | Ok(data) 182 | } 183 | Err(error) => Err(self.to_api_error(error)), 184 | } 185 | } 186 | 187 | /// Asynchronously sends a chat completion request and returns a stream of message chunks. 188 | /// 189 | /// # Arguments 190 | /// 191 | /// * `model` - The [Model] to use for the chat completion. 192 | /// * `messages` - A vector of [ChatMessage] to send as part of the chat. 193 | /// * `options` - Optional [ChatParams] to customize the request. 194 | /// 195 | /// # Returns 196 | /// 197 | /// Returns a [Result] containing a `Stream` of `ChatStreamChunk` if the request is successful, 198 | /// or an [ApiError] if there is an error. 
199 | /// 200 | /// # Examples 201 | /// 202 | /// ``` 203 | /// use futures::stream::StreamExt; 204 | /// use mistralai_client::v1::{ 205 | /// chat::{ChatMessage, ChatMessageRole}, 206 | /// client::Client, 207 | /// constants::Model, 208 | /// }; 209 | /// use std::io::{self, Write}; 210 | /// 211 | /// #[tokio::main] 212 | /// async fn main() { 213 | /// let client = Client::new(None, None, None, None).unwrap(); 214 | /// let messages = vec![ChatMessage { 215 | /// role: ChatMessageRole::User, 216 | /// content: "Hello, world!".to_string(), 217 | /// tool_calls: None, 218 | /// }]; 219 | /// 220 | /// let stream_result = client 221 | /// .chat_stream(Model::OpenMistral7b,messages, None) 222 | /// .await 223 | /// .unwrap(); 224 | /// stream_result 225 | /// .for_each(|chunk_result| async { 226 | /// match chunk_result { 227 | /// Ok(chunks) => chunks.iter().for_each(|chunk| { 228 | /// print!("{}", chunk.choices[0].delta.content); 229 | /// io::stdout().flush().unwrap(); 230 | /// // => "Once upon a time, [...]" 231 | /// }), 232 | /// Err(error) => { 233 | /// eprintln!("Error processing chunk: {:?}", error) 234 | /// } 235 | /// } 236 | /// }) 237 | /// .await; 238 | /// print!("\n") // To persist the last chunk output. 
239 | /// } 240 | pub async fn chat_stream( 241 | &self, 242 | model: constants::Model, 243 | messages: Vec, 244 | options: Option, 245 | ) -> Result< 246 | impl Stream, error::ApiError>>, 247 | error::ApiError, 248 | > { 249 | let request = chat::ChatRequest::new(model, messages, true, options); 250 | let response = self 251 | .post_stream("/chat/completions", &request) 252 | .await 253 | .map_err(|e| error::ApiError { 254 | message: e.to_string(), 255 | })?; 256 | if !response.status().is_success() { 257 | let status = response.status(); 258 | let text = response.text().await.unwrap_or_default(); 259 | return Err(error::ApiError { 260 | message: format!("{}: {}", status, text), 261 | }); 262 | } 263 | 264 | let deserialized_stream = response.bytes_stream().then(|bytes_result| async move { 265 | match bytes_result { 266 | Ok(bytes) => match String::from_utf8(bytes.to_vec()) { 267 | Ok(message) => { 268 | let chunks = message 269 | .lines() 270 | .filter_map( 271 | |line| match chat_stream::get_chunk_from_stream_message_line(line) { 272 | Ok(Some(chunks)) => Some(chunks), 273 | Ok(None) => None, 274 | Err(_error) => None, 275 | }, 276 | ) 277 | .flatten() 278 | .collect(); 279 | 280 | Ok(chunks) 281 | } 282 | Err(e) => Err(error::ApiError { 283 | message: e.to_string(), 284 | }), 285 | }, 286 | Err(e) => Err(error::ApiError { 287 | message: e.to_string(), 288 | }), 289 | } 290 | }); 291 | 292 | Ok(deserialized_stream) 293 | } 294 | 295 | pub fn embeddings( 296 | &self, 297 | model: constants::EmbedModel, 298 | input: Vec, 299 | options: Option, 300 | ) -> Result { 301 | let request = embedding::EmbeddingRequest::new(model, input, options); 302 | 303 | let response = self.post_sync("/embeddings", &request)?; 304 | let result = response.json::(); 305 | match result { 306 | Ok(data) => { 307 | utils::debug_pretty_json_from_struct("Response Data", &data); 308 | 309 | Ok(data) 310 | } 311 | Err(error) => Err(self.to_api_error(error)), 312 | } 313 | } 314 | 315 | pub 
async fn embeddings_async( 316 | &self, 317 | model: constants::EmbedModel, 318 | input: Vec, 319 | options: Option, 320 | ) -> Result { 321 | let request = embedding::EmbeddingRequest::new(model, input, options); 322 | 323 | let response = self.post_async("/embeddings", &request).await?; 324 | let result = response.json::().await; 325 | match result { 326 | Ok(data) => { 327 | utils::debug_pretty_json_from_struct("Response Data", &data); 328 | 329 | Ok(data) 330 | } 331 | Err(error) => Err(self.to_api_error(error)), 332 | } 333 | } 334 | 335 | pub fn get_last_function_call_result(&self) -> Option> { 336 | let mut result_lock = self.last_function_call_result.lock().unwrap(); 337 | 338 | result_lock.take() 339 | } 340 | 341 | pub fn list_models(&self) -> Result { 342 | let response = self.get_sync("/models")?; 343 | let result = response.json::(); 344 | match result { 345 | Ok(data) => { 346 | utils::debug_pretty_json_from_struct("Response Data", &data); 347 | 348 | Ok(data) 349 | } 350 | Err(error) => Err(self.to_api_error(error)), 351 | } 352 | } 353 | 354 | pub async fn list_models_async( 355 | &self, 356 | ) -> Result { 357 | let response = self.get_async("/models").await?; 358 | let result = response.json::().await; 359 | match result { 360 | Ok(data) => { 361 | utils::debug_pretty_json_from_struct("Response Data", &data); 362 | 363 | Ok(data) 364 | } 365 | Err(error) => Err(self.to_api_error(error)), 366 | } 367 | } 368 | 369 | pub fn register_function(&mut self, name: String, function: Box) { 370 | let mut functions = self.functions.lock().unwrap(); 371 | 372 | functions.insert(name, function); 373 | } 374 | 375 | fn build_request_sync( 376 | &self, 377 | request: reqwest::blocking::RequestBuilder, 378 | ) -> reqwest::blocking::RequestBuilder { 379 | let user_agent = format!( 380 | "ivangabriele/mistralai-client-rs/{}", 381 | env!("CARGO_PKG_VERSION") 382 | ); 383 | 384 | let request_builder = request 385 | .bearer_auth(&self.api_key) 386 | .header("Accept", 
"application/json") 387 | .header("User-Agent", user_agent); 388 | 389 | request_builder 390 | } 391 | 392 | fn build_request_async(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder { 393 | let user_agent = format!( 394 | "ivangabriele/mistralai-client-rs/{}", 395 | env!("CARGO_PKG_VERSION") 396 | ); 397 | 398 | let request_builder = request 399 | .bearer_auth(&self.api_key) 400 | .header("Accept", "application/json") 401 | .header("User-Agent", user_agent); 402 | 403 | request_builder 404 | } 405 | 406 | fn build_request_stream(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder { 407 | let user_agent = format!( 408 | "ivangabriele/mistralai-client-rs/{}", 409 | env!("CARGO_PKG_VERSION") 410 | ); 411 | 412 | let request_builder = request 413 | .bearer_auth(&self.api_key) 414 | .header("Accept", "text/event-stream") 415 | .header("User-Agent", user_agent); 416 | 417 | request_builder 418 | } 419 | 420 | fn call_function_if_any(&self, response: chat::ChatResponse) -> () { 421 | let next_result = match response.choices.get(0) { 422 | Some(first_choice) => match first_choice.message.tool_calls.to_owned() { 423 | Some(tool_calls) => match tool_calls.get(0) { 424 | Some(first_tool_call) => { 425 | let functions = self.functions.lock().unwrap(); 426 | match functions.get(&first_tool_call.function.name) { 427 | Some(function) => { 428 | let runtime = tokio::runtime::Runtime::new().unwrap(); 429 | let result = runtime.block_on(async { 430 | function 431 | .execute(first_tool_call.function.arguments.to_owned()) 432 | .await 433 | }); 434 | 435 | Some(result) 436 | } 437 | None => None, 438 | } 439 | } 440 | None => None, 441 | }, 442 | None => None, 443 | }, 444 | None => None, 445 | }; 446 | 447 | let mut last_result_lock = self.last_function_call_result.lock().unwrap(); 448 | *last_result_lock = next_result; 449 | } 450 | 451 | async fn call_function_if_any_async(&self, response: chat::ChatResponse) -> () { 452 | let next_result = match 
response.choices.get(0) { 453 | Some(first_choice) => match first_choice.message.tool_calls.to_owned() { 454 | Some(tool_calls) => match tool_calls.get(0) { 455 | Some(first_tool_call) => { 456 | let functions = self.functions.lock().unwrap(); 457 | match functions.get(&first_tool_call.function.name) { 458 | Some(function) => { 459 | let result = function 460 | .execute(first_tool_call.function.arguments.to_owned()) 461 | .await; 462 | 463 | Some(result) 464 | } 465 | None => None, 466 | } 467 | } 468 | None => None, 469 | }, 470 | None => None, 471 | }, 472 | None => None, 473 | }; 474 | 475 | let mut last_result_lock = self.last_function_call_result.lock().unwrap(); 476 | *last_result_lock = next_result; 477 | } 478 | 479 | fn get_sync(&self, path: &str) -> Result { 480 | let reqwest_client = reqwest::blocking::Client::new(); 481 | let url = format!("{}{}", self.endpoint, path); 482 | debug!("Request URL: {}", url); 483 | 484 | let request = self.build_request_sync(reqwest_client.get(url)); 485 | 486 | let result = request.send(); 487 | match result { 488 | Ok(response) => { 489 | if response.status().is_success() { 490 | Ok(response) 491 | } else { 492 | let response_status = response.status(); 493 | let response_body = response.text().unwrap_or_default(); 494 | debug!("Response Status: {}", &response_status); 495 | utils::debug_pretty_json_from_string("Response Data", &response_body); 496 | 497 | Err(error::ApiError { 498 | message: format!("{}: {}", response_status, response_body), 499 | }) 500 | } 501 | } 502 | Err(error) => Err(error::ApiError { 503 | message: error.to_string(), 504 | }), 505 | } 506 | } 507 | 508 | async fn get_async(&self, path: &str) -> Result { 509 | let reqwest_client = reqwest::Client::new(); 510 | let url = format!("{}{}", self.endpoint, path); 511 | debug!("Request URL: {}", url); 512 | 513 | let request_builder = reqwest_client.get(url); 514 | let request = self.build_request_async(request_builder); 515 | 516 | let result = 
request.send().await; 517 | match result { 518 | Ok(response) => { 519 | if response.status().is_success() { 520 | Ok(response) 521 | } else { 522 | let response_status = response.status(); 523 | let response_body = response.text().await.unwrap_or_default(); 524 | debug!("Response Status: {}", &response_status); 525 | utils::debug_pretty_json_from_string("Response Data", &response_body); 526 | 527 | Err(error::ApiError { 528 | message: format!("{}: {}", response_status, response_body), 529 | }) 530 | } 531 | } 532 | Err(error) => Err(error::ApiError { 533 | message: error.to_string(), 534 | }), 535 | } 536 | } 537 | 538 | fn post_sync( 539 | &self, 540 | path: &str, 541 | params: &T, 542 | ) -> Result { 543 | let reqwest_client = reqwest::blocking::Client::new(); 544 | let url = format!("{}{}", self.endpoint, path); 545 | debug!("Request URL: {}", url); 546 | utils::debug_pretty_json_from_struct("Request Body", params); 547 | 548 | let request_builder = reqwest_client.post(url).json(params); 549 | let request = self.build_request_sync(request_builder); 550 | 551 | let result = request.send(); 552 | match result { 553 | Ok(response) => { 554 | if response.status().is_success() { 555 | Ok(response) 556 | } else { 557 | let response_status = response.status(); 558 | let response_body = response.text().unwrap_or_default(); 559 | debug!("Response Status: {}", &response_status); 560 | utils::debug_pretty_json_from_string("Response Data", &response_body); 561 | 562 | Err(error::ApiError { 563 | message: format!("{}: {}", response_body, response_status), 564 | }) 565 | } 566 | } 567 | Err(error) => Err(error::ApiError { 568 | message: error.to_string(), 569 | }), 570 | } 571 | } 572 | 573 | async fn post_async( 574 | &self, 575 | path: &str, 576 | params: &T, 577 | ) -> Result { 578 | let reqwest_client = reqwest::Client::new(); 579 | let url = format!("{}{}", self.endpoint, path); 580 | debug!("Request URL: {}", url); 581 | utils::debug_pretty_json_from_struct("Request 
Body", params); 582 | 583 | let request_builder = reqwest_client.post(url).json(params); 584 | let request = self.build_request_async(request_builder); 585 | 586 | let result = request.send().await; 587 | match result { 588 | Ok(response) => { 589 | if response.status().is_success() { 590 | Ok(response) 591 | } else { 592 | let response_status = response.status(); 593 | let response_body = response.text().await.unwrap_or_default(); 594 | debug!("Response Status: {}", &response_status); 595 | utils::debug_pretty_json_from_string("Response Data", &response_body); 596 | 597 | Err(error::ApiError { 598 | message: format!("{}: {}", response_status, response_body), 599 | }) 600 | } 601 | } 602 | Err(error) => Err(error::ApiError { 603 | message: error.to_string(), 604 | }), 605 | } 606 | } 607 | 608 | async fn post_stream( 609 | &self, 610 | path: &str, 611 | params: &T, 612 | ) -> Result { 613 | let reqwest_client = reqwest::Client::new(); 614 | let url = format!("{}{}", self.endpoint, path); 615 | debug!("Request URL: {}", url); 616 | utils::debug_pretty_json_from_struct("Request Body", params); 617 | 618 | let request_builder = reqwest_client.post(url).json(params); 619 | let request = self.build_request_stream(request_builder); 620 | 621 | let result = request.send().await; 622 | match result { 623 | Ok(response) => { 624 | if response.status().is_success() { 625 | Ok(response) 626 | } else { 627 | let response_status = response.status(); 628 | let response_body = response.text().await.unwrap_or_default(); 629 | debug!("Response Status: {}", &response_status); 630 | utils::debug_pretty_json_from_string("Response Data", &response_body); 631 | 632 | Err(error::ApiError { 633 | message: format!("{}: {}", response_status, response_body), 634 | }) 635 | } 636 | } 637 | Err(error) => Err(error::ApiError { 638 | message: error.to_string(), 639 | }), 640 | } 641 | } 642 | 643 | fn to_api_error(&self, err: ReqwestError) -> error::ApiError { 644 | error::ApiError { 645 | 
message: err.to_string(), 646 | } 647 | } 648 | } 649 | --------------------------------------------------------------------------------