├── rig-core ├── .gitignore ├── tests │ └── data │ │ ├── dummy.pdf │ │ └── pages.pdf ├── examples │ ├── loaders.rs │ ├── simple_model.rs │ ├── agent.rs │ ├── xai_embeddings.rs │ ├── gemini_embeddings.rs │ ├── agent_with_ollama.rs │ ├── anthropic_agent.rs │ ├── sentiment_classifier.rs │ ├── perplexity_agent.rs │ ├── extractor.rs │ ├── agent_with_loaders.rs │ ├── agent_with_context.rs │ ├── gemini_agent.rs │ ├── cohere_connector.rs │ ├── multi_agent.rs │ ├── vector_search_cohere.rs │ ├── vector_search.rs │ ├── agent_with_tools.rs │ ├── debate.rs │ ├── rag.rs │ └── rag_dynamic_tools.rs ├── src │ ├── providers │ │ ├── xai │ │ │ ├── mod.rs │ │ │ └── embedding.rs │ │ ├── anthropic │ │ │ ├── mod.rs │ │ │ └── client.rs │ │ ├── mod.rs │ │ └── gemini │ │ │ └── mod.rs │ ├── embeddings │ │ ├── mod.rs │ │ ├── embedding.rs │ │ └── tool.rs │ ├── json_utils.rs │ ├── loaders │ │ └── mod.rs │ ├── cli_chatbot.rs │ ├── vector_store │ │ └── mod.rs │ ├── lib.rs │ └── extractor.rs ├── rig-core-derive │ ├── Cargo.toml │ └── src │ │ ├── lib.rs │ │ ├── basic.rs │ │ ├── embed.rs │ │ └── custom.rs ├── LICENSE ├── Cargo.toml ├── README.md └── CHANGELOG.md ├── .gitignore ├── Cargo.toml ├── rig-qdrant ├── README.md ├── Cargo.toml ├── CHANGELOG.md ├── LICENSE ├── examples │ └── qdrant_vector_search.rs ├── tests │ └── integration_tests.rs └── src │ └── lib.rs ├── .github ├── ISSUE_TEMPLATE │ ├── new-model-provider.md │ ├── vector-store-integration-request.md │ ├── bug-report.md │ └── feature-or-improvement-request.md ├── workflows │ ├── cd.yaml │ └── ci.yaml └── PULL_REQUEST_TEMPLATE │ ├── new-model-provider.md │ ├── other.md │ └── new-vector-store.md ├── rig-sqlite ├── CHANGELOG.md ├── Cargo.toml ├── LICENSE ├── README.md └── examples │ └── vector_search_sqlite.rs ├── .pre-commit-config.yaml ├── rig-lancedb ├── Cargo.toml ├── LICENSE ├── README.md ├── examples │ ├── vector_search_local_enn.rs │ ├── vector_search_local_ann.rs │ ├── fixtures │ │ └── lib.rs │ └── vector_search_s3_ann.rs ├── CHANGELOG.md ├── tests │ ├── fixtures │ │ └── lib.rs │ └── integration_tests.rs └── src │ └── utils │ └── mod.rs ├── rig-mongodb ├── Cargo.toml ├── LICENSE ├── README.md ├── CHANGELOG.md └── examples │ └── vector_search_mongodb.rs ├── rig-neo4j ├── Cargo.toml ├── LICENSE ├── CHANGELOG.md ├── examples │ ├── display │ │ └── lib.rs │ ├── vector_search_movies_consume.rs │ └── vector_search_simple.rs ├── README.md └── tests │ └── integration_tests.rs ├── LICENSE ├── CONTRIBUTING.md └── img ├── rig_logo.svg └── rig_logo_dark.svg /rig-core/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | *.log 3 | .env 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | target/ 3 | .DS_Store 4 | .idea/ 5 | .vscode/ 6 | -------------------------------------------------------------------------------- /rig-core/tests/data/dummy.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chalingok/rig/HEAD/rig-core/tests/data/dummy.pdf -------------------------------------------------------------------------------- /rig-core/tests/data/pages.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chalingok/rig/HEAD/rig-core/tests/data/pages.pdf -------------------------------------------------------------------------------- /Cargo.toml: 
-------------------------------------------------------------------------------- 1 | [workspace] 2 | resolver = "2" 3 | members = [ 4 | "rig-core", "rig-lancedb", 5 | "rig-mongodb", "rig-neo4j", 6 | "rig-qdrant", "rig-core/rig-core-derive", 7 | "rig-sqlite" 8 | ] 9 | -------------------------------------------------------------------------------- /rig-qdrant/README.md: -------------------------------------------------------------------------------- 1 | # Rig-Qdrant 2 | Vector store index integration for [Qdrant](https://qdrant.tech/). This integration supports dense vector retrieval using Rig's embedding providers. It is also extensible to allow all [hybrid queries](https://qdrant.tech/documentation/concepts/hybrid-queries/) supported by Qdrant. 3 | 4 | You can find end-to-end examples [here](https://github.com/0xPlaygrounds/rig/tree/main/rig-qdrant/examples). 5 | -------------------------------------------------------------------------------- /rig-core/examples/loaders.rs: -------------------------------------------------------------------------------- 1 | use rig::loaders::FileLoader; 2 | 3 | #[tokio::main] 4 | async fn main() -> Result<(), anyhow::Error> { 5 | FileLoader::with_glob("Cargo.toml")? 6 | .read() 7 | .into_iter() 8 | .for_each(|result| match result { 9 | Ok(content) => println!("{}", content), 10 | Err(e) => eprintln!("Error reading file: {}", e), 11 | }); 12 | 13 | Ok(()) 14 | } 15 | -------------------------------------------------------------------------------- /rig-core/src/providers/xai/mod.rs: -------------------------------------------------------------------------------- 1 | //! xAI API client and Rig integration 2 | //! 3 | //! # Example 4 | //! ``` 5 | //! use rig::providers::xai; 6 | //! 7 | //! let client = xai::Client::new("YOUR_API_KEY"); 8 | //! 9 | //! let embedding_model = client.embedding_model(xai::EMBEDDING_V1); 10 | //! ``` 11 | 12 | pub mod client; 13 | pub mod completion; 14 | pub mod embedding; 15 | 16 | pub use client::Client; 17 | pub use completion::GROK_BETA; 18 | pub use embedding::EMBEDDING_V1; 19 | -------------------------------------------------------------------------------- /rig-core/rig-core-derive/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rig-derive" 3 | version = "0.1.0" 4 | edition = "2021" 5 | license = "MIT" 6 | description = "Internal crate that implements Rig derive macros."
7 | repository = "https://github.com/0xPlaygrounds/rig" 8 | 9 | [dependencies] 10 | indoc = "2.0.5" 11 | proc-macro2 = { version = "1.0.87", features = ["proc-macro"] } 12 | quote = "1.0.37" 13 | syn = { version = "2.0.79", features = ["full"]} 14 | 15 | [lib] 16 | proc-macro = true 17 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/new-model-provider.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: New model provider 3 | about: Suggest a new model provider to integrate 4 | title: 'feat: Add support for X' 5 | labels: feat, model 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Model Provider Integration Request 11 | 14 | 15 | ### Resources 16 | 19 | -------------------------------------------------------------------------------- /rig-core/examples/simple_model.rs: -------------------------------------------------------------------------------- 1 | use rig::{completion::Prompt, providers::openai}; 2 | 3 | #[tokio::main] 4 | async fn main() { 5 | // Create OpenAI client and model 6 | let openai_client = openai::Client::from_env(); 7 | 8 | let gpt4 = openai_client.agent("gpt-4").build(); 9 | 10 | // Prompt the model and print its response 11 | let response = gpt4 12 | .prompt("Who are you?") 13 | .await 14 | .expect("Failed to prompt GPT-4"); 15 | 16 | println!("GPT-4: {response}"); 17 | } 18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/vector-store-integration-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Vector store integration request 3 | about: Suggest a new vector store to integrate 4 | title: 'feat: Add support for X vector store' 5 | labels: data store, feat 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Vector Store Integration Request 11 | 14 | 15 | ### Resources 16 | 19 | -------------------------------------------------------------------------------- /rig-core/src/providers/anthropic/mod.rs: -------------------------------------------------------------------------------- 1 | //! Anthropic API client and Rig integration 2 | //! 3 | //! # Example 4 | //! ``` 5 | //! use rig::providers::anthropic; 6 | //! 7 | //! let client = anthropic::ClientBuilder::new("YOUR_API_KEY").build(); 8 | //! 9 | //! let sonnet = client.completion_model(anthropic::CLAUDE_3_5_SONNET); 10 | //! ``` 11 | 12 | pub mod client; 13 | pub mod completion; 14 | 15 | pub use client::{Client, ClientBuilder}; 16 | pub use completion::{ 17 | ANTHROPIC_VERSION_2023_01_01, ANTHROPIC_VERSION_2023_06_01, ANTHROPIC_VERSION_LATEST, 18 | CLAUDE_3_5_SONNET, CLAUDE_3_HAIKU, CLAUDE_3_OPUS, CLAUDE_3_SONNET, 19 | }; 20 | -------------------------------------------------------------------------------- /rig-core/src/embeddings/mod.rs: -------------------------------------------------------------------------------- 1 | //! This module provides functionality for working with embeddings. 2 | //! Embeddings are numerical representations of documents or other objects, typically used in 3 | //! natural language processing (NLP) tasks such as text classification, information retrieval, 4 | //! and document similarity.
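//!
//! # Example
//! A minimal sketch, mirroring the `xai_embeddings` example elsewhere in this
//! repository (the provider client and model constant are interchangeable stand-ins):
//! ```
//! use rig::providers::xai;
//! use rig::Embed;
//!
//! #[derive(Embed)]
//! struct Greeting {
//!     #[embed]
//!     message: String,
//! }
//!
//! let client = xai::Client::from_env();
//!
//! let embeddings = client
//!     .embeddings(xai::embedding::EMBEDDING_V1)
//!     .document(Greeting { message: "Hello, world!".to_string() })?
//!     .build()
//!     .await?;
//! ```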
5 | 6 | pub mod builder; 7 | pub mod embed; 8 | pub mod embedding; 9 | pub mod tool; 10 | 11 | pub mod distance; 12 | pub use builder::EmbeddingsBuilder; 13 | pub use embed::{to_texts, Embed, EmbedError, TextEmbedder}; 14 | pub use embedding::{Embedding, EmbeddingError, EmbeddingModel}; 15 | pub use tool::ToolSchema; 16 | -------------------------------------------------------------------------------- /rig-sqlite/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 7 | 8 | ## [Unreleased] 9 | 10 | ## [0.1.0](https://github.com/0xPlaygrounds/rig/releases/tag/rig-sqlite-v0.1.0) - 2024-12-03 11 | 12 | ### Added 13 | 14 | - Add support for Sqlite vector store ([#122](https://github.com/0xPlaygrounds/rig/pull/122)) 15 | 16 | ### Fixed 17 | 18 | - rig-sqlite missing version in Cargo.toml ([#137](https://github.com/0xPlaygrounds/rig/pull/137)) 19 | - *(rig-sqlite)* Fix missing rig-core version 20 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: 'bug: ' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | - [ ] I have looked for existing issues (including closed) about this 11 | 12 | ## Bug Report 13 | <!-- 14 | A clear and concise description of what the bug is. 15 | ---> 16 | 17 | ## Reproduction 18 | <!-- 19 | Code snippet. 20 | ---> 21 | 22 | ## Expected behavior 23 | <!-- 24 | A clear and concise description of what you expected to happen. 25 | ---> 26 | 27 | ## Screenshots 28 | <!-- 29 | If applicable, add screenshots to help explain your problem. 30 | ---> 31 | 32 | ## Additional context 33 | <!-- 34 | Add any other context about the problem here. 
35 | ---> 36 | -------------------------------------------------------------------------------- /rig-core/src/json_utils.rs: -------------------------------------------------------------------------------- 1 | pub fn merge(a: serde_json::Value, b: serde_json::Value) -> serde_json::Value { // Merges `b` into `a`; on duplicate keys, `b` wins 2 | match (a, b) { 3 | (serde_json::Value::Object(mut a_map), serde_json::Value::Object(b_map)) => { 4 | b_map.into_iter().for_each(|(key, value)| { 5 | a_map.insert(key, value); 6 | }); 7 | serde_json::Value::Object(a_map) 8 | } 9 | (a, _) => a, // if either value is not a JSON object, `a` is returned unchanged and `b` is dropped 10 | } 11 | } 12 | 13 | pub fn merge_inplace(a: &mut serde_json::Value, b: serde_json::Value) { // In-place variant: merges `b` into `a` only when both are objects; otherwise a no-op 14 | if let (serde_json::Value::Object(a_map), serde_json::Value::Object(b_map)) = (a, b) { 15 | b_map.into_iter().for_each(|(key, value)| { 16 | a_map.insert(key, value); 17 | }); 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /rig-core/rig-core-derive/src/lib.rs: -------------------------------------------------------------------------------- 1 | extern crate proc_macro; 2 | use proc_macro::TokenStream; 3 | use syn::{parse_macro_input, DeriveInput}; 4 | 5 | mod basic; 6 | mod custom; 7 | mod embed; 8 | 9 | pub(crate) const EMBED: &str = "embed"; 10 | 11 | /// References: 12 | /// <https://doc.rust-lang.org/book/ch19-06-macros.html#how-to-write-a-custom-derive-macro> 13 | /// <https://doc.rust-lang.org/reference/procedural-macros.html> 14 | #[proc_macro_derive(Embed, attributes(embed))] 15 | pub fn derive_embedding_trait(item: TokenStream) -> TokenStream { 16 | let mut input = parse_macro_input!(item as DeriveInput); 17 | 18 | embed::expand_derive_embedding(&mut input) 19 | .unwrap_or_else(syn::Error::into_compile_error) 20 | .into() 21 | } 22 | -------------------------------------------------------------------------------- /rig-core/examples/agent.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | 3 | use rig::{completion::Prompt, providers}; 4 | 5 | #[tokio::main] 6 | async fn main() -> Result<(), anyhow::Error> { 7 | // Create OpenAI client 8 | let client = providers::openai::Client::new( 9 | &env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"), 10 | ); 11 | 12 | // Create agent with a single context prompt 13 | let comedian_agent = client 14 | .agent("gpt-4o") 15 | .preamble("You are a comedian here to entertain the user using humour and jokes.") 16 | .build(); 17 | 18 | // Prompt the agent and print the response 19 | let response = comedian_agent.prompt("Entertain me!").await?; 20 | println!("{}", response); 21 | 22 | Ok(()) 23 | } 24 | -------------------------------------------------------------------------------- /rig-qdrant/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rig-qdrant" 3 | version = "0.1.3" 4 | edition = "2021" 5 | license = "MIT" 6 | readme = "README.md" 7 | description = "Rig vector store index integration for Qdrant.
https://qdrant.tech" 8 | repository = "https://github.com/0xPlaygrounds/rig" 9 | 10 | [dependencies] 11 | rig-core = { path = "../rig-core", version = "0.5.0" } 12 | serde_json = "1.0.128" 13 | serde = "1.0.210" 14 | qdrant-client = "1.12.1" 15 | 16 | [dev-dependencies] 17 | tokio = { version = "1.40.0", features = ["rt-multi-thread"] } 18 | anyhow = "1.0.89" 19 | testcontainers = "0.23.1" 20 | 21 | [[example]] 22 | name = "qdrant_vector_search" 23 | required-features = ["rig-core/derive"] 24 | 25 | 26 | [[test]] 27 | name = "integration_tests" 28 | required-features = ["rig-core/derive"] 29 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.6.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: end-of-file-fixer 9 | - id: check-yaml 10 | - id: check-added-large-files 11 | - id: check-json 12 | - id: check-case-conflict 13 | - id: check-merge-conflict 14 | 15 | 16 | - repo: https://github.com/doublify/pre-commit-rust 17 | rev: v1.0 18 | hooks: 19 | - id: fmt 20 | - id: cargo-check 21 | - id: clippy 22 | 23 | - repo: https://github.com/commitizen-tools/commitizen 24 | rev: v2.20.0 25 | hooks: 26 | - id: commitizen 27 | stages: [commit-msg] 28 | -------------------------------------------------------------------------------- /rig-core/examples/xai_embeddings.rs: -------------------------------------------------------------------------------- 1 | use rig::providers::xai; 2 | use rig::Embed; 3 | 4 | #[derive(Embed, Debug)] 5 | struct Greetings { 6 | #[embed] 7 | message: String, 8 | } 9 | 10 | #[tokio::main] 11 | async fn main() -> Result<(), anyhow::Error> { 12 | // Initialize the xAI client 13 | let client = xai::Client::from_env(); 14 | 15 | let embeddings = client 16 | .embeddings(xai::embedding::EMBEDDING_V1) 17 | .document(Greetings { 18 | message: "Hello, world!".to_string(), 19 | })? 20 | .document(Greetings { 21 | message: "Goodbye, world!".to_string(), 22 | })? 23 | .build() 24 | .await 25 | .expect("Failed to embed documents"); 26 | 27 | println!("{:?}", embeddings); 28 | 29 | Ok(()) 30 | } 31 | -------------------------------------------------------------------------------- /rig-core/examples/gemini_embeddings.rs: -------------------------------------------------------------------------------- 1 | use rig::providers::gemini; 2 | use rig::Embed; 3 | 4 | #[derive(Embed, Debug)] 5 | struct Greetings { 6 | #[embed] 7 | message: String, 8 | } 9 | 10 | #[tokio::main] 11 | async fn main() -> Result<(), anyhow::Error> { 12 | // Initialize the Google Gemini client 13 | // (reads the GEMINI_API_KEY environment variable) 14 | let client = gemini::Client::from_env(); 15 | 16 | let embeddings = client 17 | .embeddings(gemini::embedding::EMBEDDING_001) 18 | .document(Greetings { 19 | message: "Hello, world!".to_string(), 20 | })? 21 | .document(Greetings { 22 | message: "Goodbye, world!".to_string(), 23 | })?
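        // generate embeddings for every document queued above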
24 | .build() 25 | .await 26 | .expect("Failed to embed documents"); 27 | 28 | println!("{:?}", embeddings); 29 | 30 | Ok(()) 31 | } 32 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-or-improvement-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature or improvement request 3 | about: Suggest an idea for this project 4 | title: 'feat: <title>' 5 | labels: feat 6 | assignees: '' 7 | 8 | --- 9 | 10 | - [ ] I have looked for existing issues (including closed) about this 11 | 12 | ## Feature Request 13 | <!-- 14 | High level description of the requested feature or improvement. 15 | --> 16 | 17 | ### Motivation 18 | <!-- 19 | Please describe the use case(s) or other motivation for the new feature. 20 | --> 21 | 22 | ### Proposal 23 | <!-- 24 | How should the new feature be implemented, and why? Add any considered 25 | drawbacks. 26 | --> 27 | 28 | ### Alternatives 29 | <!-- 30 | Are there other ways to solve this problem that you've considered? What are 31 | their potential drawbacks? Why was the proposed solution chosen over these 32 | alternatives? 33 | --> 34 | -------------------------------------------------------------------------------- /rig-sqlite/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rig-sqlite" 3 | version = "0.1.0" 4 | edition = "2021" 5 | description = "SQLite-based vector store implementation for the rig framework" 6 | license = "MIT" 7 | 8 | [lib] 9 | doctest = false 10 | 11 | [dependencies] 12 | rig-core = { path = "../rig-core", version = "0.5.0", features = ["derive"] } 13 | rusqlite = { version = "0.32", features = ["bundled"] } 14 | serde = { version = "1.0", features = ["derive"] } 15 | serde_json = "1.0" 16 | sqlite-vec = "0.1" 17 | tokio-rusqlite = { git = "https://github.com/programatik29/tokio-rusqlite", version = "0.6.0", features = [ 18 | "bundled", 19 | ] } 20 | tracing = "0.1" 21 | zerocopy = "0.8.10" 22 | chrono = "0.4" 23 | 24 | [dev-dependencies] 25 | anyhow = "1.0.86" 26 | tokio = { version = "1.38.0", features = ["macros", "rt-multi-thread"] } 27 | tracing-subscriber = { version = "0.3", features = ["env-filter"] } 28 | -------------------------------------------------------------------------------- /rig-core/examples/agent_with_ollama.rs: -------------------------------------------------------------------------------- 1 | /// This example requires that you have the [`ollama`](https://ollama.com) server running locally. 
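/// A quick way to get set up, assuming a default Ollama install: run `ollama serve`
/// (if the server isn't already running) and `ollama pull llama3.2` so the model used
/// below is available at the default endpoint, http://localhost:11434.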
2 | use rig::{completion::Prompt, providers}; 3 | 4 | #[tokio::main] 5 | async fn main() -> Result<(), anyhow::Error> { 6 | // Create an OpenAI-compatible client with a custom base URL pointing at a local Ollama endpoint 7 | // An API key is unnecessary for most local endpoints 8 | let client = providers::openai::Client::from_url("ollama", "http://localhost:11434"); 9 | 10 | // Create agent with a single context prompt 11 | let comedian_agent = client 12 | .agent("llama3.2:latest") 13 | .preamble("You are a comedian here to entertain the user using humour and jokes.") 14 | .build(); 15 | 16 | // Prompt the agent and print the response 17 | let response = comedian_agent.prompt("Entertain me!").await?; 18 | println!("{}", response); 19 | 20 | Ok(()) 21 | } 22 | -------------------------------------------------------------------------------- /rig-core/examples/anthropic_agent.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | 3 | use rig::{ 4 | completion::Prompt, 5 | providers::anthropic::{self, CLAUDE_3_5_SONNET}, 6 | }; 7 | 8 | #[tokio::main] 9 | async fn main() -> Result<(), anyhow::Error> { 10 | // Create Anthropic client 11 | let client = anthropic::ClientBuilder::new( 12 | &env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"), 13 | ) 14 | .build(); 15 | 16 | // Create agent with a single context prompt 17 | let agent = client 18 | .agent(CLAUDE_3_5_SONNET) 19 | .preamble("Be precise and concise.") 20 | .temperature(0.5) 21 | .max_tokens(8192) 22 | .build(); 23 | 24 | // Prompt the agent and print the response 25 | let response = agent 26 | .prompt("When and where is the next solar eclipse, and what type will it be?") 27 | .await?; 28 | println!("{}", response); 29 | 30 | Ok(()) 31 | } 32 | -------------------------------------------------------------------------------- /rig-lancedb/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rig-lancedb" 3 | version = "0.2.0" 4 | edition = "2021" 5 | license = "MIT" 6 | readme = "README.md" 7 | description = "Rig vector store index integration for LanceDB." 8 | repository = "https://github.com/0xPlaygrounds/rig" 9 | 10 | [dependencies] 11 | lancedb = "0.10.0" 12 | rig-core = { path = "../rig-core", version = "0.5.0" } 13 | arrow-array = "52.2.0" 14 | serde_json = "1.0.128" 15 | serde = "1.0.210" 16 | futures = "0.3.30" 17 | 18 | [dev-dependencies] 19 | tokio = "1.40.0" 20 | anyhow = "1.0.89" 21 | 22 | [[example]] 23 | name = "vector_search_local_ann" 24 | required-features = ["rig-core/derive"] 25 | 26 | [[example]] 27 | name = "vector_search_local_enn" 28 | required-features = ["rig-core/derive"] 29 | 30 | [[example]] 31 | name = "vector_search_s3_ann" 32 | required-features = ["rig-core/derive"] 33 | 34 | [[test]] 35 | name = "integration_tests" 36 | required-features = ["rig-core/derive"] -------------------------------------------------------------------------------- /rig-mongodb/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rig-mongodb" 3 | version = "0.2.0" 4 | edition = "2021" 5 | license = "MIT" 6 | readme = "README.md" 7 | description = "MongoDB implementation of a Rig vector store."
8 | repository = "https://github.com/0xPlaygrounds/rig" 9 | 10 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 11 | 12 | [dependencies] 13 | futures = "0.3.30" 14 | mongodb = "3.1.0" 15 | rig-core = { path = "../rig-core", version = "0.5.0" } 16 | serde = { version = "1.0.203", features = ["derive"] } 17 | serde_json = "1.0.117" 18 | tracing = "0.1.40" 19 | 20 | [dev-dependencies] 21 | anyhow = "1.0.86" 22 | testcontainers = "0.23.1" 23 | tokio = { version = "1.38.0", features = ["macros"] } 24 | 25 | [[example]] 26 | name = "vector_search_mongodb" 27 | required-features = ["rig-core/derive"] 28 | 29 | [[test]] 30 | name = "integration_tests" 31 | required-features = ["rig-core/derive"] 32 | -------------------------------------------------------------------------------- /rig-qdrant/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 7 | 8 | ## [Unreleased] 9 | 10 | ## [0.1.3](https://github.com/0xPlaygrounds/rig/compare/rig-qdrant-v0.1.2...rig-qdrant-v0.1.3) - 2024-12-03 11 | 12 | ### Added 13 | 14 | - embeddings API overhaul ([#120](https://github.com/0xPlaygrounds/rig/pull/120)) 15 | 16 | ### Other 17 | 18 | - *(integration test)* Neo4J ([#133](https://github.com/0xPlaygrounds/rig/pull/133)) 19 | - *(integration test)* Qdrant ([#134](https://github.com/0xPlaygrounds/rig/pull/134)) 20 | 21 | ## [0.1.2](https://github.com/0xPlaygrounds/rig/compare/rig-qdrant-v0.1.1...rig-qdrant-v0.1.2) - 2024-11-13 22 | 23 | ### Other 24 | 25 | - updated the following local packages: rig-core 26 | -------------------------------------------------------------------------------- /rig-core/src/loaders/mod.rs: -------------------------------------------------------------------------------- 1 | //! This module provides utility structs for loading and preprocessing files. 2 | //! 3 | //! The [FileLoader] struct can be used to define a common interface for loading any type of files from disk, 4 | //! as well as performing minimal preprocessing on the files, such as reading their contents, ignoring errors 5 | //! and keeping track of file paths along with their contents. 6 | //! 7 | //! The [PdfFileLoader] works similarly to the [FileLoader], but is specifically designed to load PDF 8 | //! files. This loader also provides PDF-specific preprocessing methods for splitting the PDF into pages 9 | //! and keeping track of the page numbers along with their contents. 10 | //! 11 | //! Note: The [PdfFileLoader] requires the `pdf` feature to be enabled in the `Cargo.toml` file. 12 | 13 | pub mod file; 14 | 15 | pub use file::FileLoader; 16 | 17 | #[cfg(feature = "pdf")] 18 | pub mod pdf; 19 | 20 | #[cfg(feature = "pdf")] 21 | pub use pdf::PdfFileLoader; 22 | -------------------------------------------------------------------------------- /rig-core/rig-core-derive/src/basic.rs: -------------------------------------------------------------------------------- 1 | use syn::{parse_quote, Attribute, DataStruct, Meta}; 2 | 3 | use crate::EMBED; 4 | 5 | /// Finds and returns fields with simple `#[embed]` attribute tags only. 
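/// For example, given `struct Doc { #[embed] title: String, id: u32 }`, only the
/// `title` field is yielded.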
6 | pub(crate) fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator<Item = &syn::Field> { 7 | data_struct.fields.iter().filter(|field| { 8 | field.attrs.iter().any(|attribute| match attribute { 9 | Attribute { 10 | meta: Meta::Path(path), 11 | .. 12 | } => path.is_ident(EMBED), 13 | _ => false, 14 | }) 15 | }) 16 | } 17 | 18 | /// Adds bounds to where clause that force all fields tagged with `#[embed]` to implement the `Embed` trait. 19 | pub(crate) fn add_struct_bounds(generics: &mut syn::Generics, field_type: &syn::Type) { 20 | let where_clause = generics.make_where_clause(); 21 | 22 | where_clause.predicates.push(parse_quote! { 23 | #field_type: Embed 24 | }); 25 | } 26 | -------------------------------------------------------------------------------- /rig-core/examples/sentiment_classifier.rs: -------------------------------------------------------------------------------- 1 | use rig::providers::openai; 2 | use schemars::JsonSchema; 3 | use serde::{Deserialize, Serialize}; 4 | 5 | #[derive(Debug, Deserialize, JsonSchema, Serialize)] 6 | /// An enum representing the sentiment of a document 7 | enum Sentiment { 8 | Positive, 9 | Negative, 10 | Neutral, 11 | } 12 | 13 | #[derive(Debug, Deserialize, JsonSchema, Serialize)] 14 | struct DocumentSentiment { 15 | /// The sentiment of the document 16 | sentiment: Sentiment, 17 | } 18 | 19 | #[tokio::main] 20 | async fn main() { 21 | // Create OpenAI client 22 | let openai_client = openai::Client::from_env(); 23 | 24 | // Create extractor 25 | let data_extractor = openai_client 26 | .extractor::<DocumentSentiment>("gpt-4") 27 | .build(); 28 | 29 | let sentiment = data_extractor 30 | .extract("I am happy") 31 | .await 32 | .expect("Failed to extract sentiment"); 33 | 34 | println!("GPT-4: {:?}", sentiment); 35 | } 36 | -------------------------------------------------------------------------------- /rig-neo4j/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rig-neo4j" 3 | version = "0.2.0" 4 | edition = "2021" 5 | license = "MIT" 6 | readme = "README.md" 7 | description = "Neo4j implementation of a Rig vector store." 
8 | repository = "https://github.com/0xPlaygrounds/rig" 9 | 10 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 11 | 12 | [dependencies] 13 | futures = "0.3.30" 14 | neo4rs = "0.8.0" 15 | rig-core = { path = "../rig-core", version = "0.5.0" } 16 | serde = { version = "1.0.203", features = ["derive"] } 17 | serde_json = "1.0.117" 18 | tracing = "0.1.40" 19 | 20 | [dev-dependencies] 21 | anyhow = "1.0.86" 22 | tokio = { version = "1.38.0", features = ["macros"] } 23 | textwrap = { version = "0.16.1"} 24 | term_size = { version = "0.3.2"} 25 | testcontainers = "0.23.1" 26 | tracing-subscriber = "0.3.18" 27 | 28 | [[example]] 29 | name = "vector_search_simple" 30 | required-features = ["rig-core/derive"] 31 | 32 | [[test]] 33 | name = "integration_tests" 34 | required-features = ["rig-core/derive"] 35 | -------------------------------------------------------------------------------- /.github/workflows/cd.yaml: -------------------------------------------------------------------------------- 1 | name: "Build & Release" 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | workflow_dispatch: 8 | 9 | jobs: 10 | run-ci: 11 | permissions: 12 | checks: write 13 | uses: ./.github/workflows/ci.yaml 14 | secrets: inherit 15 | 16 | release-plz: 17 | name: Release-plz 18 | needs: run-ci 19 | runs-on: ubuntu-latest 20 | permissions: 21 | pull-requests: write 22 | contents: write 23 | steps: 24 | - name: Checkout 25 | uses: actions/checkout@v4 26 | with: 27 | fetch-depth: 0 28 | 29 | - name: Install Rust toolchain 30 | uses: actions-rust-lang/setup-rust-toolchain@v1 31 | 32 | # Required to compile rig-lancedb 33 | - name: Install Protoc 34 | uses: arduino/setup-protoc@v3 35 | 36 | - name: Run release-plz 37 | uses: MarcoIeni/release-plz-action@v0.5 38 | env: 39 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 40 | CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2024, Playgrounds Analytics Inc. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
8 | -------------------------------------------------------------------------------- /rig-core/examples/perplexity_agent.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | 3 | use rig::{ 4 | completion::Prompt, 5 | providers::{self, perplexity::LLAMA_3_1_70B_INSTRUCT}, 6 | }; 7 | use serde_json::json; 8 | 9 | #[tokio::main] 10 | async fn main() -> Result<(), anyhow::Error> { 11 | // Create Perplexity client 12 | let client = providers::perplexity::Client::new( 13 | &env::var("PERPLEXITY_API_KEY").expect("PERPLEXITY_API_KEY not set"), 14 | ); 15 | 16 | // Create agent with a single context prompt 17 | let agent = client 18 | .agent(LLAMA_3_1_70B_INSTRUCT) 19 | .preamble("Be precise and concise.") 20 | .temperature(0.5) 21 | .additional_params(json!({ 22 | "return_related_questions": true, 23 | "return_images": true 24 | })) 25 | .build(); 26 | 27 | // Prompt the agent and print the response 28 | let response = agent 29 | .prompt("When and where is the next solar eclipse, and what type will it be?") 30 | .await?; 31 | println!("{}", response); 32 | 33 | Ok(()) 34 | } 35 | -------------------------------------------------------------------------------- /rig-core/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2024, Playgrounds Analytics Inc. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /rig-lancedb/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2024, Playgrounds Analytics Inc. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /rig-mongodb/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2024, Playgrounds Analytics Inc. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /rig-neo4j/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2024, Playgrounds Analytics Inc. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /rig-qdrant/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2024, Playgrounds Analytics Inc. 
2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /rig-sqlite/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2024, Playgrounds Analytics Inc. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /rig-core/examples/extractor.rs: -------------------------------------------------------------------------------- 1 | use rig::providers::openai; 2 | use schemars::JsonSchema; 3 | use serde::{Deserialize, Serialize}; 4 | 5 | #[derive(Debug, Deserialize, JsonSchema, Serialize)] 6 | /// A record representing a person 7 | struct Person { 8 | /// The person's first name, if provided (null otherwise) 9 | pub first_name: Option<String>, 10 | /// The person's last name, if provided (null otherwise) 11 | pub last_name: Option<String>, 12 | /// The person's job, if provided (null otherwise) 13 | pub job: Option<String>, 14 | } 15 | 16 | #[tokio::main] 17 | async fn main() -> Result<(), anyhow::Error> { 18 | // Create OpenAI client 19 | let openai_client = openai::Client::from_env(); 20 | 21 | // Create extractor 22 | let data_extractor = openai_client.extractor::<Person>("gpt-4").build(); 23 | 24 | let person = data_extractor 25 | .extract("Hello my name is John Doe! 
I am a software engineer.") 26 | .await?; 27 | 28 | println!("GPT-4: {}", serde_json::to_string_pretty(&person).unwrap()); 29 | 30 | Ok(()) 31 | } 32 | -------------------------------------------------------------------------------- /rig-core/examples/agent_with_loaders.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | 3 | use rig::{ 4 | agent::AgentBuilder, 5 | completion::Prompt, 6 | loaders::FileLoader, 7 | providers::openai::{self, GPT_4O}, 8 | }; 9 | 10 | #[tokio::main] 11 | async fn main() -> Result<(), anyhow::Error> { 12 | let openai_client = 13 | openai::Client::new(&env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set")); 14 | 15 | let model = openai_client.completion_model(GPT_4O); 16 | 17 | // Load in all the rust examples 18 | let examples = FileLoader::with_glob("rig-core/examples/*.rs")? 19 | .read_with_path() 20 | .ignore_errors() 21 | .into_iter(); 22 | 23 | // Create an agent with multiple context documents 24 | let agent = examples 25 | .fold(AgentBuilder::new(model), |builder, (path, content)| { 26 | builder.context(format!("Rust Example {:?}:\n{}", path, content).as_str()) 27 | }) 28 | .build(); 29 | 30 | // Prompt the agent and print the response 31 | let response = agent 32 | .prompt("Which rust example is best suited for the operation 1 + 2") 33 | .await?; 34 | 35 | println!("{}", response); 36 | 37 | Ok(()) 38 | } 39 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE/new-model-provider.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: New model provider 3 | about: Suggest a new model provider to integrate 4 | title: 'feat: Add support for X' 5 | labels: feat, model 6 | assignees: '' 7 | 8 | --- 9 | 10 | # New model provider: <Model Provider Name> 11 | 12 | ## Description 13 | 14 | Please describe the model provider you are adding to the project. Include links to their website and their api documentation. 15 | 16 | Fixes # (issue) 17 | 18 | ## Testing 19 | 20 | Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce your results. 21 | 22 | - [ ] Test A 23 | - [ ] Test B 24 | 25 | ## Checklist: 26 | 27 | - [ ] My code follows the style guidelines of this project 28 | - [ ] I have commented my code, particularly in hard-to-understand areas 29 | - [ ] I have made corresponding changes to the documentation 30 | - [ ] My changes generate no new warnings 31 | - [ ] I have added tests that prove my fix is effective or that my feature works 32 | - [ ] New and existing unit tests pass locally with my changes 33 | - [ ] I've reviewed the provider API documentation and implemented the types of response accurately 34 | 35 | ## Notes 36 | 37 | Any notes you wish to include about the nature of this PR (implementation details, specific questions, etc.) 
38 | -------------------------------------------------------------------------------- /rig-core/examples/agent_with_context.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | 3 | use rig::{agent::AgentBuilder, completion::Prompt, providers::cohere}; 4 | 5 | #[tokio::main] 6 | async fn main() -> Result<(), anyhow::Error> { 7 | // Create a Cohere client (an OpenAI client works the same way; see the commented-out lines) 8 | // let openai_client = openai::Client::new(&env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set")); 9 | let cohere_client = 10 | cohere::Client::new(&env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set")); 11 | 12 | // let model = openai_client.completion_model("gpt-4"); 13 | let model = cohere_client.completion_model("command-r"); 14 | 15 | // Create an agent with multiple context documents 16 | let agent = AgentBuilder::new(model) 17 | .context("Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") 18 | .context("Definition of a *glarb-glarb*: A glarb-glarb is an ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") 19 | .context("Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") 20 | .build(); 21 | 22 | // Prompt the agent and print the response 23 | let response = agent.prompt("What does \"glarb-glarb\" mean?").await?; 24 | 25 | println!("{}", response); 26 | 27 | Ok(()) 28 | } 29 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE/other.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: General pull request 3 | about: Makes a change to the code base 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | # <Pull Request Title> 11 | 12 | ## Description 13 | 14 | Please include a summary of the changes and the related issue. Please also include relevant motivation and context. List any dependencies that are required for this change. 15 | 16 | Fixes # (issue) 17 | 18 | ## Type of change 19 | 20 | Please delete options that are not relevant. 21 | 22 | - [ ] Bug fix 23 | - [ ] New feature 24 | - [ ] Breaking change 25 | - [ ] Documentation update 26 | 27 | ## Testing 28 | 29 | Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce your results. 30 | 31 | - [ ] Test A 32 | - [ ] Test B 33 | 34 | ## Checklist: 35 | 36 | - [ ] My code follows the style guidelines of this project 37 | - [ ] I have commented my code, particularly in hard-to-understand areas 38 | - [ ] I have made corresponding changes to the documentation 39 | - [ ] My changes generate no new warnings 40 | - [ ] I have added tests that prove my fix is effective or that my feature works 41 | - [ ] New and existing unit tests pass locally with my changes 42 | 43 | ## Notes 44 | 45 | Any notes you wish to include about the nature of this PR (implementation details, specific questions, etc.) 46 | -------------------------------------------------------------------------------- /rig-neo4j/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
7 | 8 | ## [Unreleased] 9 | 10 | ## [0.2.0](https://github.com/0xPlaygrounds/rig/compare/rig-neo4j-v0.1.2...rig-neo4j-v0.2.0) - 2024-12-03 11 | 12 | ### Added 13 | 14 | - embeddings API overhaul ([#120](https://github.com/0xPlaygrounds/rig/pull/120)) 15 | 16 | ### Fixed 17 | 18 | - *(neo4j)* remove embeddings from top_n lookup ([#118](https://github.com/0xPlaygrounds/rig/pull/118)) 19 | 20 | ### Other 21 | 22 | - *(integration test)* Neo4J ([#133](https://github.com/0xPlaygrounds/rig/pull/133)) 23 | 24 | ## [0.1.2](https://github.com/0xPlaygrounds/rig/compare/rig-neo4j-v0.1.1...rig-neo4j-v0.1.2) - 2024-11-13 25 | 26 | ### Other 27 | 28 | - updated the following local packages: rig-core 29 | 30 | ## [0.1.1](https://github.com/0xPlaygrounds/rig/compare/rig-neo4j-v0.1.0...rig-neo4j-v0.1.1) - 2024-11-07 31 | 32 | ### Fixed 33 | 34 | - *(neo4j)* last minute doc and const adjustments 35 | 36 | ## [0.1.0](https://github.com/0xPlaygrounds/rig/compare/rig-mongodb-v0.0.7...rig-mongodb-v0.1.0) - 2024-10-22 37 | 38 | ### Features 39 | 40 | - initial implementation 41 | - supports `top_n` search for an existing index and database 42 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE/new-vector-store.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Vector store integration request 3 | about: Suggest a new vector store to integrate 4 | title: 'feat: Add support for X vector store' 5 | labels: data store, feat 6 | assignees: '' 7 | 8 | --- 9 | 10 | # New vector store: <Vector Store Name> 11 | 12 | ## Description 13 | 14 | Please describe the vector store you are adding to the project. Include links to their website and their api documentation. 15 | 16 | Fixes # (issue) 17 | 18 | ## Testing 19 | 20 | Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce your results. 21 | 22 | - [ ] Test A 23 | - [ ] Test B 24 | 25 | ## Checklist: 26 | 27 | - [ ] My code follows the style guidelines of this project 28 | - [ ] I have performed a self-review of my own code 29 | - [ ] I have commented my code, particularly in hard-to-understand areas 30 | - [ ] I have made corresponding changes to the documentation 31 | - [ ] My changes generate no new warnings 32 | - [ ] I have added tests that prove my fix is effective or that my feature works 33 | - [ ] New and existing unit tests pass locally with my changes 34 | - [ ] Any dependent changes have been merged and published in downstream modules 35 | - [ ] I've reviewed the vector store API documentation and implemented the types of response accurately 36 | 37 | ## Notes 38 | 39 | Any notes you wish to include about the nature of this PR (implementation details, specific questions, etc.) 
40 | -------------------------------------------------------------------------------- /rig-lancedb/README.md: -------------------------------------------------------------------------------- 17 | ## Rig-LanceDB 18 | This companion crate implements a Rig vector store based on LanceDB. 19 | 20 | ## Usage 21 | 22 | Add the companion crate to your `Cargo.toml`, along with the rig-core crate: 23 | 24 | ```toml 25 | [dependencies] 26 | rig-lancedb = "0.2.0" 27 | rig-core = "0.5.0" 28 | ``` 29 | 30 | You can also run `cargo add rig-lancedb rig-core` to add the most recent versions of the dependencies to your project. 31 | 32 | See the [`/examples`](./examples) folder for usage examples. 33 | -------------------------------------------------------------------------------- /rig-mongodb/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | <div style="display: flex; align-items: center; justify-content: center;"> 4 | <picture> 5 | <source media="(prefers-color-scheme: dark)" srcset="../img/rig_logo_dark.svg"> 6 | <source media="(prefers-color-scheme: light)" srcset="../img/rig_logo.svg"> 7 | <img src="../img/rig_logo.svg" width="200" alt="Rig logo"> 8 | </picture> 9 | <span style="font-size: 48px; margin: 0 20px; font-weight: regular; font-family: Open Sans, sans-serif;"> + </span> 10 | <picture> 11 | <source media="(prefers-color-scheme: dark)" srcset="https://companieslogo.com/img/orig/MDB_BIG.D-96d632a9.png?t=1720244492"> 12 | <source media="(prefers-color-scheme: light)" srcset="https://cdn.iconscout.com/icon/free/png-256/free-mongodb-logo-icon-download-in-svg-png-gif-file-formats--wordmark-programming-langugae-freebies-pack-logos-icons-1175140.png?f=webp&w=256"> 13 | <img src="https://cdn.iconscout.com/icon/free/png-256/free-mongodb-logo-icon-download-in-svg-png-gif-file-formats--wordmark-programming-langugae-freebies-pack-logos-icons-1175140.png?f=webp&w=256" width="200" alt="MongoDB logo"> 14 | </picture> 15 | </div> 16 | 17 | <br><br> 18 | 19 | ## Rig-MongoDB 20 | This companion crate implements a Rig vector store based on MongoDB.
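It targets MongoDB Atlas Vector Search, so you need an Atlas collection with a vector search index over the embedding field. Below is a rough sketch of querying such an index (the constructor arguments, index name, and field layout are assumptions here; see [`/examples`](./examples) for the authoritative version):

```rust
use rig::providers::openai;
use rig::vector_store::VectorStoreIndex;
use rig_mongodb::{MongoDbVectorIndex, SearchParams};

// `collection` is a `mongodb::Collection` whose documents carry an `embedding`
// field covered by an Atlas vector search index named "vector_index" (both
// names are assumptions).
let model = openai::Client::from_env().embedding_model(openai::TEXT_EMBEDDING_ADA_002);
let index = MongoDbVectorIndex::new(collection, model, "vector_index", SearchParams::new()).await?;

// Retrieve the single closest document to the query.
let results = index.top_n::<serde_json::Value>("What is a glarb-glarb?", 1).await?;
```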
21 | 22 | ## Usage 23 | 24 | Add the companion crate to your `Cargo.toml`, along with the rig-core crate: 25 | 26 | ```toml 27 | [dependencies] 28 | rig-mongodb = "0.2.0" 29 | rig-core = "0.5.0" 30 | ``` 31 | 32 | You can also run `cargo add rig-mongodb rig-core` to add the most recent versions of the dependencies to your project. 33 | 34 | See the [`/examples`](./examples) folder for usage examples. 35 | -------------------------------------------------------------------------------- /rig-core/examples/gemini_agent.rs: -------------------------------------------------------------------------------- 1 | use rig::{ 2 | completion::Prompt, 3 | providers::gemini::{self, completion::gemini_api_types::GenerationConfig}, 4 | }; 5 | #[tracing::instrument(ret)] 6 | #[tokio::main] 7 | 8 | async fn main() -> Result<(), anyhow::Error> { 9 | tracing_subscriber::fmt() 10 | .with_max_level(tracing::Level::DEBUG) 11 | .with_target(false) 12 | .init(); 13 | 14 | // Initialize the Google Gemini client 15 | let client = gemini::Client::from_env(); 16 | 17 | // Create agent with a single context prompt 18 | let agent = client 19 | .agent(gemini::completion::GEMINI_1_5_PRO) 20 | .preamble("Be creative and concise. Answer directly and clearly.") 21 | .temperature(0.5) 22 | // The `GenerationConfig` utility struct helps construct a typesafe `additional_params` 23 | .additional_params(serde_json::to_value(GenerationConfig { 24 | top_k: Some(1), 25 | top_p: Some(0.95), 26 | candidate_count: Some(1), 27 | ..Default::default() 28 | })?) // the `?` propagates any serialization error, leaving the JSON Value 29 | .build(); 30 | 31 | tracing::info!("Prompting the agent..."); 32 | 33 | // Prompt the agent and print the response 34 | let response = agent 35 | .prompt("How much wood would a woodchuck chuck if a woodchuck could chuck wood? Infer an answer.") 36 | .await; 37 | 38 | tracing::info!("Response: {:?}", response); 39 | 40 | match response { 41 | Ok(response) => println!("{}", response), 42 | Err(e) => { 43 | tracing::error!("Error: {:?}", e); 44 | return Err(e.into()); 45 | } 46 | } 47 | 48 | Ok(()) 49 | } 50 | -------------------------------------------------------------------------------- /rig-core/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rig-core" 3 | version = "0.5.0" 4 | edition = "2021" 5 | license = "MIT" 6 | readme = "README.md" 7 | description = "An opinionated library for building LLM powered applications."
8 | repository = "https://github.com/0xPlaygrounds/rig" 9 | 10 | [lib] 11 | name = "rig" 12 | path = "src/lib.rs" 13 | doctest = false 14 | 15 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 16 | 17 | [dependencies] 18 | reqwest = { version = "0.11.22", features = ["json"] } 19 | serde = { version = "1.0.193", features = ["derive"] } 20 | serde_json = "1.0.108" 21 | tracing = "0.1.40" 22 | futures = "0.3.29" 23 | ordered-float = "4.2.0" 24 | schemars = "0.8.16" 25 | thiserror = "1.0.61" 26 | rig-derive = { version = "0.1.0", path = "./rig-core-derive", optional = true } 27 | glob = "0.3.1" 28 | lopdf = { version = "0.34.0", optional = true } 29 | rayon = { version = "1.10.0", optional = true} 30 | 31 | [dev-dependencies] 32 | anyhow = "1.0.75" 33 | assert_fs = "1.1.2" 34 | tokio = { version = "1.34.0", features = ["full"] } 35 | tracing-subscriber = "0.3.18" 36 | tokio-test = "0.4.4" 37 | 38 | [features] 39 | all = ["derive", "pdf", "rayon"] 40 | derive = ["dep:rig-derive"] 41 | pdf = ["dep:lopdf"] 42 | rayon = ["dep:rayon"] 43 | 44 | [[test]] 45 | name = "embed_macro" 46 | required-features = ["derive"] 47 | 48 | [[example]] 49 | name = "rag" 50 | required-features = ["derive"] 51 | 52 | [[example]] 53 | name = "vector_search" 54 | required-features = ["derive"] 55 | 56 | [[example]] 57 | name = "vector_search_cohere" 58 | required-features = ["derive"] 59 | 60 | [[example]] 61 | name = "gemini_embeddings" 62 | required-features = ["derive"] 63 | 64 | [[example]] 65 | name = "xai_embeddings" 66 | required-features = ["derive"] 67 | -------------------------------------------------------------------------------- /rig-core/src/providers/mod.rs: -------------------------------------------------------------------------------- 1 | //! This module contains clients for the different LLM providers that Rig supports. 2 | //! 3 | //! Currently, the following providers are supported: 4 | //! - Cohere 5 | //! - OpenAI 6 | //! - Perplexity 7 | //! - Anthropic 8 | //! - Google Gemini 9 | //! - xAI 10 | //! 11 | //! Each provider has its own module, which contains a `Client` implementation that can 12 | //! be used to initialize completion and embedding models and execute requests to those models. 13 | //! 14 | //! The clients also contain methods to easily create higher level AI constructs such as 15 | //! agents and RAG systems, reducing the need for boilerplate. 16 | //! 17 | //! # Example 18 | //! ``` 19 | //! use rig::{providers::openai, agent::AgentBuilder}; 20 | //! 21 | //! // Initialize the OpenAI client 22 | //! let openai = openai::Client::new("your-openai-api-key"); 23 | //! 24 | //! // Create a model and initialize an agent 25 | //! let gpt_4o = openai.completion_model("gpt-4o"); 26 | //! 27 | //! let agent = AgentBuilder::new(gpt_4o) 28 | //! .preamble("\ 29 | //! You are Gandalf the white and you will be conversing with other \ 30 | //! powerful beings to discuss the fate of Middle Earth.\ 31 | //! ") 32 | //! .build(); 33 | //! 34 | //! // Alternatively, you can initialize an agent directly 35 | //! let agent = openai.agent("gpt-4o") 36 | //! .preamble("\ 37 | //! You are Gandalf the white and you will be conversing with other \ 38 | //! powerful beings to discuss the fate of Middle Earth.\ 39 | //! ") 40 | //! .build(); 41 | //! ``` 42 | //! Note: The example above uses the OpenAI provider client, but the same pattern can 43 | //! be used with the Cohere provider client.
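//! For instance, a minimal sketch of the same setup against Cohere (the
//! `command-r` model name is borrowed from the Cohere connector example in
//! this repo):
//! ```
//! use rig::providers::cohere;
//!
//! let cohere_client = cohere::Client::new("your-cohere-api-key");
//!
//! // Initialize an agent directly from the provider client, as above
//! let agent = cohere_client.agent("command-r")
//!     .preamble("You are Gandalf the white ...")
//!     .build();
//! ```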
44 | pub mod anthropic; 45 | pub mod cohere; 46 | pub mod gemini; 47 | pub mod openai; 48 | pub mod perplexity; 49 | pub mod xai; 50 | -------------------------------------------------------------------------------- /rig-sqlite/README.md: -------------------------------------------------------------------------------- 1 | <div style="display: flex; align-items: center; justify-content: center;"> 2 | <picture> 3 | <source media="(prefers-color-scheme: dark)" srcset="../img/rig_logo_dark.svg"> 4 | <source media="(prefers-color-scheme: light)" srcset="../img/rig_logo.svg"> 5 | <img src="../img/rig_logo.svg" width="200" alt="Rig logo"> 6 | </picture> 7 | <span style="font-size: 48px; margin: 0 20px; font-weight: regular; font-family: Open Sans, sans-serif;"> + </span> 8 | <picture> 9 | <source media="(prefers-color-scheme: dark)" srcset="https://www.sqlite.org/images/sqlite370_banner.gif"> 10 | <source media="(prefers-color-scheme: light)" srcset="https://www.sqlite.org/images/sqlite370_banner.gif"> 11 | <img src="https://www.sqlite.org/images/sqlite370_banner.gif" width="200" alt="SQLite logo"> 12 | </picture> 13 | </div> 14 | 15 | <br><br> 16 | 17 | ## Rig-SQLite 18 | 19 | This companion crate implements a Rig vector store based on SQLite. 20 | 21 | ## Usage 22 | 23 | Add the companion crate to your `Cargo.toml`, along with the rig-core crate: 24 | 25 | ```toml 26 | [dependencies] 27 | rig-sqlite = "0.1.3" 28 | rig-core = "0.4.0" 29 | ``` 30 | 31 | You can also run `cargo add rig-sqlite rig-core` to add the most recent versions of the dependencies to your project. 32 | 33 | See the [`/examples`](./examples) folder for usage examples. 34 | 35 | ## Important Note 36 | 37 | Before using the SQLite vector store, you must [initialize the SQLite vector extension](https://alexgarcia.xyz/sqlite-vec/rust.html). Add this code before creating your connection: 38 | 39 | ```rust 40 | use rusqlite::ffi::sqlite3_auto_extension; 41 | use sqlite_vec::sqlite3_vec_init; 42 | 43 | unsafe { 44 | sqlite3_auto_extension(Some(std::mem::transmute(sqlite3_vec_init as *const ()))); 45 | } 46 | ``` 47 | -------------------------------------------------------------------------------- /rig-core/examples/cohere_connector.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | 3 | use rig::{ 4 | completion::{Completion, Prompt}, 5 | providers::cohere::Client as CohereClient, 6 | }; 7 | use serde_json::json; 8 | 9 | #[tokio::main] 10 | async fn main() -> Result<(), anyhow::Error> { 11 | // Create Cohere client 12 | let cohere_api_key = env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set"); 13 | let cohere_client = CohereClient::new(&cohere_api_key); 14 | 15 | let klimadao_agent = cohere_client 16 | .agent("command-r") 17 | .temperature(0.0) 18 | .additional_params(json!({ 19 | "connectors": [{"id":"web-search", "options":{"site": "https://docs.klimadao.finance"}}] 20 | })) 21 | .build(); 22 | 23 | // Prompt the model and print the response 24 | // We use `prompt` to get a simple response from the model as a String 25 | let response = klimadao_agent.prompt("Tell me about BCT tokens?").await?; 26 | 27 | println!("\n\nCoral: {:?}", response); 28 | 29 | // Prompt the model and get the citations 30 | // We use `completion` to allow us to customize the request further and 31 | // get a more detailed response from the model.
32 | // Here the response is of type CompletionResponse<cohere::CompletionResponse> 33 | // which contains `choice` (Message or ToolCall) as well as `raw_response`, 34 | // the underlying provider's raw response. 35 | let response = klimadao_agent 36 | .completion("Tell me about BCT tokens?", vec![]) 37 | .await? 38 | .additional_params(json!({ 39 | "connectors": [{"id":"web-search", "options":{"site": "https://docs.klimadao.finance"}}] 40 | })) 41 | .send() 42 | .await?; 43 | 44 | println!( 45 | "\n\nCoral: {:?}\n\nCitations:\n{:?}", 46 | response.choice, response.raw_response.citations 47 | ); 48 | 49 | Ok(()) 50 | } 51 | -------------------------------------------------------------------------------- /rig-core/README.md: -------------------------------------------------------------------------------- 1 | # Rig 2 | Rig is a Rust library for building LLM-powered applications that focuses on ergonomics and modularity. 3 | 4 | More information about this crate can be found in the [crate documentation](https://docs.rs/rig-core/latest/rig/). 5 | ## Table of contents 6 | 7 | - [Rig](#rig) 8 | - [Table of contents](#table-of-contents) 9 | - [High-level features](#high-level-features) 10 | - [Installation](#installation) 11 | - [Simple example:](#simple-example) 12 | - [Integrations](#integrations) 13 | 14 | ## High-level features 15 | - Full support for LLM completion and embedding workflows 16 | - Simple but powerful common abstractions over LLM providers (e.g. OpenAI, Cohere) and vector stores (e.g. MongoDB, in-memory) 17 | - Integrate LLMs in your app with minimal boilerplate 18 | 19 | ## Installation 20 | ```bash 21 | cargo add rig-core 22 | ``` 23 | 24 | ## Simple example: 25 | ```rust 26 | use rig::{completion::Prompt, providers::openai}; 27 | 28 | #[tokio::main] 29 | async fn main() { 30 | // Create OpenAI client and model 31 | // This requires the `OPENAI_API_KEY` environment variable to be set. 32 | let openai_client = openai::Client::from_env(); 33 | 34 | let gpt4 = openai_client.model("gpt-4").build(); 35 | 36 | // Prompt the model and print its response 37 | let response = gpt4 38 | .prompt("Who are you?") 39 | .await 40 | .expect("Failed to prompt GPT-4"); 41 | 42 | println!("GPT-4: {response}"); 43 | } 44 | ``` 45 | Note: using `#[tokio::main]` requires you to enable tokio's `macros` and `rt-multi-thread` features 46 | or just `full` to enable all features (`cargo add tokio --features macros,rt-multi-thread`). 47 | 48 | ## Integrations 49 | Rig supports the following LLM providers natively: 50 | - OpenAI 51 | - Cohere 52 | - Google Gemini 53 | - xAI 54 | - Perplexity 55 | - Anthropic 56 | 57 | Additionally, Rig currently has the following integration sub-libraries: 58 | - MongoDB vector store: `rig-mongodb` 59 | - LanceDB vector store: `rig-lancedb` 60 | - Neo4j vector store: `rig-neo4j` 61 | - Qdrant vector store: `rig-qdrant` 62 | - SQLite vector store: `rig-sqlite` 63 | -------------------------------------------------------------------------------- /rig-core/src/cli_chatbot.rs: -------------------------------------------------------------------------------- 1 | use std::io::{self, Write}; 2 | 3 | use crate::completion::{Chat, Message, PromptError}; 4 | 5 | /// Utility function to create a simple REPL CLI chatbot from a type that implements the 6 | /// `Chat` trait. 7 | pub async fn cli_chatbot(chatbot: impl Chat) -> Result<(), PromptError> { 8 | let stdin = io::stdin(); 9 | let mut stdout = io::stdout(); 10 | let mut chat_log = vec![]; 11 | 12 | println!("Welcome to the chatbot!
Type 'exit' to quit."); 13 | loop { 14 | print!("> "); 15 | // Flush stdout to ensure the prompt appears before input 16 | stdout.flush().unwrap(); 17 | 18 | let mut input = String::new(); 19 | match stdin.read_line(&mut input) { 20 | Ok(_) => { 21 | // Remove the newline character from the input 22 | let input = input.trim(); 23 | // Check for a command to exit 24 | if input == "exit" { 25 | break; 26 | } 27 | tracing::info!("Prompt:\n{}\n", input); 28 | 29 | let response = chatbot.chat(input, chat_log.clone()).await?; 30 | chat_log.push(Message { 31 | role: "user".into(), 32 | content: input.into(), 33 | }); 34 | chat_log.push(Message { 35 | role: "assistant".into(), 36 | content: response.clone(), 37 | }); 38 | 39 | println!("========================== Response ============================"); 40 | println!("{response}"); 41 | println!("================================================================\n\n"); 42 | 43 | tracing::info!("Response:\n{}\n", response); 44 | } 45 | Err(error) => println!("Error reading input: {}", error), 46 | } 47 | } 48 | 49 | Ok(()) 50 | } 51 | -------------------------------------------------------------------------------- /rig-lancedb/examples/vector_search_local_enn.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use arrow_array::RecordBatchIterator; 4 | use fixture::{as_record_batch, schema, words}; 5 | use rig::{ 6 | embeddings::{EmbeddingModel, EmbeddingsBuilder}, 7 | providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, 8 | vector_store::VectorStoreIndexDyn, 9 | }; 10 | use rig_lancedb::{LanceDbVectorIndex, SearchParams}; 11 | 12 | #[path = "./fixtures/lib.rs"] 13 | mod fixture; 14 | 15 | #[tokio::main] 16 | async fn main() -> Result<(), anyhow::Error> { 17 | // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). 18 | let openai_client = Client::from_env(); 19 | 20 | // Select the embedding model and generate our embeddings 21 | let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); 22 | 23 | // Generate embeddings for the test data. 24 | let embeddings = EmbeddingsBuilder::new(model.clone()) 25 | .documents(words())? 26 | .build() 27 | .await?; 28 | 29 | // Define the search params that will be used by the vector store to perform the vector search. 30 | let search_params = SearchParams::default(); 31 | 32 | // Initialize LanceDB locally. 33 | let db = lancedb::connect("data/lancedb-store").execute().await?; 34 | 35 | let table = db 36 | .create_table( 37 | "definitions", 38 | RecordBatchIterator::new( 39 | vec![as_record_batch(embeddings, model.ndims())], 40 | Arc::new(schema(model.ndims())), 41 | ), 42 | ) 43 | .execute() 44 | .await?; 45 | 46 | let vector_store = LanceDbVectorIndex::new(table, model, "id", search_params).await?; 47 | 48 | // Query the index 49 | let results = vector_store 50 | .top_n_ids("My boss says I zindle too much, what does that mean?", 1) 51 | .await?; 52 | 53 | println!("Results: {:?}", results); 54 | 55 | Ok(()) 56 | } 57 | -------------------------------------------------------------------------------- /rig-mongodb/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
7 | 8 | ## [Unreleased] 9 | 10 | ## [0.2.0](https://github.com/0xPlaygrounds/rig/compare/rig-mongodb-v0.1.5...rig-mongodb-v0.2.0) - 2024-12-03 11 | 12 | ### Added 13 | 14 | - embeddings API overhaul ([#120](https://github.com/0xPlaygrounds/rig/pull/120)) 15 | 16 | ### Fixed 17 | 18 | - *(rig-mongodb)* remove embeddings from `top_n` lookup ([#115](https://github.com/0xPlaygrounds/rig/pull/115)) 19 | 20 | ### Other 21 | 22 | - *(integration test)* MongoDB ([#126](https://github.com/0xPlaygrounds/rig/pull/126)) 23 | 24 | ## [0.1.5](https://github.com/0xPlaygrounds/rig/compare/rig-mongodb-v0.1.4...rig-mongodb-v0.1.5) - 2024-11-13 25 | 26 | ### Other 27 | 28 | - updated the following local packages: rig-core 29 | 30 | ## [0.1.4](https://github.com/0xPlaygrounds/rig/compare/rig-mongodb-v0.1.3...rig-mongodb-v0.1.4) - 2024-11-07 31 | 32 | ### Added 33 | 34 | - Qdrant support 35 | 36 | ### Fixed 37 | 38 | - wrong reference to companion crate 39 | - missing qdrant readme reference 40 | 41 | ### Other 42 | 43 | - update deps version 44 | - add coloured logos for integrations 45 | 46 | ## [0.1.3](https://github.com/0xPlaygrounds/rig/compare/rig-mongodb-v0.1.2...rig-mongodb-v0.1.3) - 2024-10-24 47 | 48 | ### Fixed 49 | 50 | - make PR changes pt I 51 | - mongodb vector search - use num_candidates from search params 52 | 53 | ### Other 54 | 55 | - Merge branch 'main' into feat(vector-store)/lancedb 56 | 57 | ## [0.1.2](https://github.com/0xPlaygrounds/rig/compare/rig-mongodb-v0.1.1...rig-mongodb-v0.1.2) - 2024-10-01 58 | 59 | ### Other 60 | 61 | - updated the following local packages: rig-core 62 | 63 | ## [0.1.1](https://github.com/0xPlaygrounds/rig/compare/rig-mongodb-v0.1.0...rig-mongodb-v0.1.1) - 2024-10-01 64 | 65 | ### Other 66 | 67 | - updated the following local packages: rig-core 68 | 69 | ## [0.1.0](https://github.com/0xPlaygrounds/rig/compare/rig-mongodb-v0.0.7...rig-mongodb-v0.1.0) - 2024-09-16 70 | 71 | ### Other 72 | 73 | - fmt code 74 | -------------------------------------------------------------------------------- /rig-core/src/providers/gemini/mod.rs: -------------------------------------------------------------------------------- 1 | //! Google Gemini API client and Rig integration 2 | //! 3 | //! # Example 4 | //! ``` 5 | //! use rig::providers::gemini; 6 | //! 7 | //! let client = gemini::Client::new("YOUR_API_KEY"); 8 | //! 9 | //! let gemini_embedding_model = client.embedding_model(gemini::EMBEDDING_001); 10 | //! ``` 11 | 12 | pub mod client; 13 | pub mod completion; 14 | pub mod embedding; 15 | pub use client::Client; 16 | 17 | pub mod gemini_api_types { 18 | use serde::{Deserialize, Serialize}; 19 | 20 | #[derive(Serialize, Deserialize, Debug)] 21 | #[serde(rename_all = "SCREAMING_SNAKE_CASE")] 22 | pub enum ExecutionLanguage { 23 | /// Unspecified language. This value should not be used. 24 | LanguageUnspecified, 25 | /// Python >= 3.10, with numpy and sympy available. 26 | Python, 27 | } 28 | 29 | /// Code generated by the model that is meant to be executed, and the result returned to the model. 30 | /// Only generated when using the CodeExecution tool, in which the code will be automatically executed, 31 | /// and a corresponding CodeExecutionResult will also be generated. 32 | #[derive(Debug, Deserialize, Serialize)] 33 | pub struct ExecutableCode { 34 | /// Programming language of the code. 35 | pub language: ExecutionLanguage, 36 | /// The code to be executed.
37 | pub code: String, 38 | } 39 | #[derive(Serialize, Deserialize, Debug)] 40 | pub struct CodeExecutionResult { 41 | /// Outcome of the code execution. 42 | pub outcome: CodeExecutionOutcome, 43 | /// Contains stdout when code execution is successful, stderr or other description otherwise. 44 | pub output: Option<String>, 45 | } 46 | 47 | #[derive(Serialize, Deserialize, Debug)] 48 | #[serde(rename_all = "SCREAMING_SNAKE_CASE")] 49 | pub enum CodeExecutionOutcome { 50 | /// Unspecified status. This value should not be used. 51 | Unspecified, 52 | /// Code execution completed successfully. 53 | Ok, 54 | /// Code execution finished but with a failure. stderr should contain the reason. 55 | Failed, 56 | /// Code execution ran for too long, and was cancelled. There may or may not be a partial output present. 57 | DeadlineExceeded, 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /rig-neo4j/examples/display/lib.rs: -------------------------------------------------------------------------------- 1 | // ===================================================== 2 | // Utilities to print results from a Neo4j vector search 3 | // ===================================================== 4 | 5 | use std::fmt::Display; 6 | 7 | #[allow(dead_code)] 8 | #[derive(Debug)] 9 | pub struct SearchResult { 10 | pub title: String, 11 | pub id: String, 12 | pub description: String, 13 | pub score: f64, 14 | } 15 | 16 | pub struct SearchResults<'a>(pub &'a Vec<SearchResult>); 17 | 18 | impl<'a> Display for SearchResults<'a> { 19 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 20 | let width = term_size::dimensions().map(|(w, _)| w).unwrap_or(150); 21 | let title_width = 40; 22 | let id_width = 10; 23 | let description_width = width - title_width - id_width - 2; // 2 for spaces 24 | 25 | writeln!( 26 | f, 27 | "{:<title_width$} {:<id_width$} {:<description_width$}", 28 | "Title", "ID", "Description" 29 | )?; 30 | writeln!(f, "{}", "-".repeat(width))?; 31 | for result in self.0 { 32 | let wrapped_title = textwrap::fill(&result.title, title_width); 33 | let wrapped_description = textwrap::fill(&result.description, description_width); 34 | let title_lines: Vec<&str> = wrapped_title.lines().collect(); 35 | let description_lines: Vec<&str> = wrapped_description.lines().collect(); 36 | let max_lines = title_lines.len().max(description_lines.len()); 37 | 38 | for i in 0..max_lines { 39 | let title_line = title_lines.get(i).unwrap_or(&""); 40 | let description_line = description_lines.get(i).unwrap_or(&""); 41 | if i == 0 { 42 | writeln!( 43 | f, 44 | "{:<title_width$} {:<id_width$} {:<description_width$}", 45 | title_line, result.id, description_line 46 | )?; 47 | } else { 48 | writeln!( 49 | f, 50 | "{:<title_width$} {:<id_width$} {:<description_width$}", 51 | title_line, "", description_line 52 | )?; 53 | } 54 | } 55 | } 56 | Ok(()) 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Rig 2 | 3 | Thank you for considering contributing to Rig! Here are some guidelines to help you get started. 4 | 5 | ## Issues 6 | 7 | Before reporting an issue, please check existing or similar issues that are currently tracked. 8 | 9 | ## Pull Requests 10 | 11 | Contributions are always encouraged and welcome. 
Before creating a pull request, create a new issue that tracks that pull request, describing the problem in more detail. Pull request descriptions should include information about its implementation, especially if it makes changes to existing abstractions. 12 | 13 | PRs should be small and focused and should avoid interacting with multiple facets of the library. This may result in a larger PR being split into two or more smaller PRs. Commit messages should follow the [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) format (prefixing with `feat`, `fix`, etc.) as this integrates into our auto-releases via a [release-plz](https://github.com/MarcoIeni/release-plz) GitHub action. 14 | 15 | ## Project Structure 16 | 17 | Rig is split up into multiple crates in a monorepo structure. The main crate `rig-core` contains all of the foundational abstractions for building with LLMs. This crate avoids adding many new dependencies to keep it lean and only really contains simple provider integrations on top of the base layer of abstractions. Side crates are leveraged to help add important first-party behavior without overburdening the main library with dependencies. For example, `rig-mongodb` contains extra dependencies to be able to interact with `mongodb` as a vector store. 18 | 19 | If you are unsure whether a side-crate should live in the main repo, you can spin up a personal repo containing your crate and create an issue in our repo making the case for integrating that side-crate into the main repo, to be maintained by the Rig team. 20 | 21 | 22 | ## Developing 23 | 24 | ### Setup 25 | 26 | This should be similar to most Rust projects. 27 | 28 | ```bash 29 | git clone https://github.com/0xplaygrounds/rig 30 | cd rig 31 | cargo test 32 | ``` 33 | 34 | ### Clippy and Fmt 35 | 36 | We enforce both `clippy` and `fmt` for all pull requests. 37 | 38 | ```bash 39 | cargo clippy -- -D warnings 40 | ``` 41 | 42 | ```bash 43 | cargo fmt 44 | ``` 45 | 46 | 47 | ### Tests 48 | 49 | Make sure to run the test suite before making a pull request. 50 | 51 | ```bash 52 | cargo test 53 | ``` 54 | -------------------------------------------------------------------------------- /rig-core/examples/multi_agent.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | 3 | use rig::{ 4 | agent::{Agent, AgentBuilder}, 5 | cli_chatbot::cli_chatbot, 6 | completion::{Chat, CompletionModel, Message, PromptError}, 7 | providers::openai::Client as OpenAIClient, 8 | }; 9 | 10 | /// Represents a multi agent application that consists of two components: 11 | /// an agent specialized in translating prompts into English and a simple GPT-4 model. 12 | /// When prompted, the application will use the translator agent to translate the 13 | /// prompt into English, before answering it with GPT-4. The answer, in English, is returned. 14 | struct EnglishTranslator<M: CompletionModel> { 15 | translator_agent: Agent<M>, 16 | gpt4: Agent<M>, 17 | } 18 | 19 | impl<M: CompletionModel> EnglishTranslator<M> { 20 | fn new(model: M) -> Self { 21 | Self { 22 | // Create the translator agent 23 | translator_agent: AgentBuilder::new(model.clone()) 24 | .preamble("\ 25 | You are a translator assistant that will translate any input text into English. \ 26 | If the text is already in English, simply respond with the original text but fix any mistakes (grammar, syntax, etc.).
\ 27 | ") 28 | .build(), 29 | 30 | // Create the GPT4 model 31 | gpt4: AgentBuilder::new(model).build() 32 | } 33 | } 34 | } 35 | 36 | impl<M: CompletionModel> Chat for EnglishTranslator<M> { 37 | async fn chat(&self, prompt: &str, chat_history: Vec<Message>) -> Result<String, PromptError> { 38 | // Translate the prompt using the translator agent 39 | let translated_prompt = self 40 | .translator_agent 41 | .chat(prompt, chat_history.clone()) 42 | .await?; 43 | 44 | println!("Translated prompt: {}", translated_prompt); 45 | 46 | // Answer the prompt using gpt4 47 | self.gpt4.chat(&translated_prompt, chat_history).await 48 | } 49 | } 50 | 51 | #[tokio::main] 52 | async fn main() -> Result<(), anyhow::Error> { 53 | // Create OpenAI client 54 | let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); 55 | let openai_client = OpenAIClient::new(&openai_api_key); 56 | let model = openai_client.completion_model("gpt-4"); 57 | 58 | // Alternatively, create a Cohere client 59 | // let cohere_api_key = env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set"); 60 | // let cohere_client = CohereClient::new(&cohere_api_key); 61 | // let model = cohere_client.completion_model("command-r"); 62 | 63 | // Create the multi-agent application 64 | let translator = EnglishTranslator::new(model); 65 | 66 | // Spin up a chatbot using the agent 67 | cli_chatbot(translator).await?; 68 | 69 | Ok(()) 70 | } 71 | -------------------------------------------------------------------------------- /rig-lancedb/examples/vector_search_local_ann.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use arrow_array::RecordBatchIterator; 4 | use fixture::{as_record_batch, schema, words, Word}; 5 | use lancedb::index::vector::IvfPqIndexBuilder; 6 | use rig::{ 7 | embeddings::{EmbeddingModel, EmbeddingsBuilder}, 8 | providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, 9 | vector_store::VectorStoreIndex, 10 | }; 11 | use rig_lancedb::{LanceDbVectorIndex, SearchParams}; 12 | 13 | #[path = "./fixtures/lib.rs"] 14 | mod fixture; 15 | 16 | #[tokio::main] 17 | async fn main() -> Result<(), anyhow::Error> { 18 | // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). 19 | let openai_client = Client::from_env(); 20 | 21 | // Select an embedding model. 22 | let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); 23 | 24 | // Initialize LanceDB locally. 25 | let db = lancedb::connect("data/lancedb-store").execute().await?; 26 | 27 | // Generate embeddings for the test data. 28 | let embeddings = EmbeddingsBuilder::new(model.clone()) 29 | .documents(words())? 30 | // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. 31 | .documents( 32 | (0..256) 33 | .map(|i| Word { 34 | id: format!("doc{}", i), 35 | definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() 36 | }) 37 | )?
38 | .build() 39 | .await?; 40 | 41 | let table = db 42 | .create_table( 43 | "definitions", 44 | RecordBatchIterator::new( 45 | vec![as_record_batch(embeddings, model.ndims())], 46 | Arc::new(schema(model.ndims())), 47 | ), 48 | ) 49 | .execute() 50 | .await?; 51 | 52 | // See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information 53 | table 54 | .create_index( 55 | &["embedding"], 56 | lancedb::index::Index::IvfPq(IvfPqIndexBuilder::default()), 57 | ) 58 | .execute() 59 | .await?; 60 | 61 | // Define the search params that will be used by the vector store to perform the vector search. 62 | let search_params = SearchParams::default(); 63 | let vector_store_index = LanceDbVectorIndex::new(table, model, "id", search_params).await?; 64 | 65 | // Query the index 66 | let results = vector_store_index 67 | .top_n::<Word>("My boss says I zindle too much, what does that mean?", 1) 68 | .await?; 69 | 70 | println!("Results: {:?}", results); 71 | 72 | Ok(()) 73 | } 74 | -------------------------------------------------------------------------------- /rig-core/src/embeddings/embedding.rs: -------------------------------------------------------------------------------- 1 | //! The module defines the [EmbeddingModel] trait, which represents an embedding model that can 2 | //! generate embeddings for documents. 3 | //! 4 | //! The module also defines the [Embedding] struct, which represents a single document embedding. 5 | //! 6 | //! Finally, the module defines the [EmbeddingError] enum, which represents various errors that 7 | //! can occur during embedding generation or processing. 8 | 9 | use serde::{Deserialize, Serialize}; 10 | 11 | #[derive(Debug, thiserror::Error)] 12 | pub enum EmbeddingError { 13 | /// Http error (e.g.: connection error, timeout, etc.) 14 | #[error("HttpError: {0}")] 15 | HttpError(#[from] reqwest::Error), 16 | 17 | /// Json error (e.g.: serialization, deserialization) 18 | #[error("JsonError: {0}")] 19 | JsonError(#[from] serde_json::Error), 20 | 21 | /// Error processing the document for embedding 22 | #[error("DocumentError: {0}")] 23 | DocumentError(Box<dyn std::error::Error + Send + Sync + 'static>), 24 | 25 | /// Error parsing the completion response 26 | #[error("ResponseError: {0}")] 27 | ResponseError(String), 28 | 29 | /// Error returned by the embedding model provider 30 | #[error("ProviderError: {0}")] 31 | ProviderError(String), 32 | } 33 | 34 | /// Trait for embedding models that can generate embeddings for documents. 35 | pub trait EmbeddingModel: Clone + Sync + Send { 36 | /// The maximum number of documents that can be embedded in a single request. 37 | const MAX_DOCUMENTS: usize; 38 | 39 | /// The number of dimensions in the embedding vector. 40 | fn ndims(&self) -> usize; 41 | 42 | /// Embed multiple text documents in a single request 43 | fn embed_texts( 44 | &self, 45 | texts: impl IntoIterator<Item = String> + Send, 46 | ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send; 47 | 48 | /// Embed a single text document. 49 | fn embed_text( 50 | &self, 51 | text: &str, 52 | ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + Send { 53 | async { 54 | Ok(self 55 | .embed_texts(vec![text.to_string()]) 56 | .await? 57 | .pop() 58 | .expect("There should be at least one embedding")) 59 | } 60 | } 61 | } 62 | 63 | /// Struct that holds a single document and its embedding.
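///
/// A minimal sketch of constructing one by hand (real embeddings come from an
/// [EmbeddingModel]; the vector below is a placeholder):
/// ```rust
/// use rig::embeddings::Embedding;
///
/// let embedding = Embedding {
///     document: "glarb-glarb".to_string(),
///     vec: vec![0.1, -0.4, 0.7],
/// };
///
/// // Equality compares only the `document` field (see the `PartialEq` impl below).
/// assert_eq!(embedding.document, "glarb-glarb");
/// ```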
64 | #[derive(Clone, Default, Deserialize, Serialize, Debug)] 65 | pub struct Embedding { 66 | /// The document that was embedded. Used for debugging. 67 | pub document: String, 68 | /// The embedding vector 69 | pub vec: Vec<f64>, 70 | } 71 | 72 | impl PartialEq for Embedding { 73 | fn eq(&self, other: &Self) -> bool { 74 | self.document == other.document 75 | } 76 | } 77 | 78 | impl Eq for Embedding {} 79 | -------------------------------------------------------------------------------- /rig-lancedb/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 7 | 8 | ## [Unreleased] 9 | 10 | ## [0.2.0](https://github.com/0xPlaygrounds/rig/compare/rig-lancedb-v0.1.2...rig-lancedb-v0.2.0) - 2024-12-03 11 | 12 | ### Added 13 | 14 | - embeddings API overhaul ([#120](https://github.com/0xPlaygrounds/rig/pull/120)) 15 | 16 | ### Fixed 17 | 18 | - *(rig-lancedb)* rag embedding filtering ([#104](https://github.com/0xPlaygrounds/rig/pull/104)) 19 | 20 | ## [0.1.2](https://github.com/0xPlaygrounds/rig/compare/rig-lancedb-v0.1.1...rig-lancedb-v0.1.2) - 2024-11-13 21 | 22 | ### Other 23 | 24 | - update Cargo.lock dependencies 25 | 26 | ## [0.1.1](https://github.com/0xPlaygrounds/rig/compare/rig-lancedb-v0.1.0...rig-lancedb-v0.1.1) - 2024-11-07 27 | 28 | ### Fixed 29 | 30 | - wrong reference to companion crate 31 | - missing qdrant readme reference 32 | 33 | ### Other 34 | 35 | - update deps version 36 | - add coloured logos for integrations 37 | - *(readme)* test new logo coloration 38 | 39 | ## [0.1.0](https://github.com/0xPlaygrounds/rig/releases/tag/rig-lancedb-v0.1.0) - 2024-10-24 40 | 41 | ### Added 42 | 43 | - update examples to use new version of VectorStoreIndex trait 44 | - replace document embeddings with serde json value 45 | - merge all arrow columns into JSON document in deserializer 46 | - finish implementing deserialiser for record batch 47 | - implement deserialization for any recordbatch returned from lanceDB 48 | - add indexes and tables for simple search 49 | - create enum for embedding models 50 | - add vector_search_s3_ann example 51 | - implement ANN search example 52 | - start implementing top_n_from_query for trait VectorStoreIndex 53 | - implement get_document method of VectorStore trait 54 | - implement search by id for VectorStore trait 55 | - implement add_documents on VectorStore trait 56 | - start implementing VectorStore trait for lancedb 57 | 58 | ### Fixed 59 | 60 | - update lancedb examples test data 61 | - make PR changes Pt II 62 | - make PR changes pt I 63 | - *(lancedb)* replace VectorStoreIndexDyn with VectorStoreIndex in examples 64 | - mongodb vector search - use num_candidates from search params 65 | - fix bug in deserializing type run end 66 | - make PR requested changes 67 | - reduce opanai generated content in ANN examples 68 | 69 | ### Other 70 | 71 | - cargo fmt 72 | - lance db examples 73 | - add example docstring 74 | - add doc strings 75 | - update rig core version on lancedb crate, remove implementation of VectorStore trait 76 | - remove print statement 77 | - use constants instead of enum for model names 78 | - remove associated type on VectorStoreIndex trait 79 | - cargo fmt 80 | - conversions from arrow types to primitive types 81 | - Add doc 
strings to utility methods 82 | - add doc string to mongodb search params struct 83 | - Merge branch 'main' into feat(vector-store)/lancedb 84 | - create wrapper for vec<DocumentEmbeddings> for from/tryfrom traits 85 | -------------------------------------------------------------------------------- /rig-lancedb/tests/fixtures/lib.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use arrow_array::{types::Float64Type, ArrayRef, FixedSizeListArray, RecordBatch, StringArray}; 4 | use lancedb::arrow::arrow_schema::{DataType, Field, Fields, Schema}; 5 | use rig::embeddings::Embedding; 6 | use rig::{Embed, OneOrMany}; 7 | use serde::Deserialize; 8 | 9 | #[derive(Embed, Clone, Deserialize, Debug)] 10 | pub struct Word { 11 | pub id: String, 12 | #[embed] 13 | pub definition: String, 14 | } 15 | 16 | pub fn words() -> Vec<Word> { 17 | vec![ 18 | Word { 19 | id: "doc0".to_string(), 20 | definition: "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.".to_string() 21 | }, 22 | Word { 23 | id: "doc1".to_string(), 24 | definition: "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive.".to_string() 25 | }, 26 | Word { 27 | id: "doc2".to_string(), 28 | definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string() 29 | } 30 | ] 31 | } 32 | 33 | // Schema of table in LanceDB. 34 | pub fn schema(dims: usize) -> Schema { 35 | Schema::new(Fields::from(vec![ 36 | Field::new("id", DataType::Utf8, false), 37 | Field::new("definition", DataType::Utf8, false), 38 | Field::new( 39 | "embedding", 40 | DataType::FixedSizeList( 41 | Arc::new(Field::new("item", DataType::Float64, true)), 42 | dims as i32, 43 | ), 44 | false, 45 | ), 46 | ])) 47 | } 48 | 49 | // Convert Word objects and their embedding to a RecordBatch. 50 | pub fn as_record_batch( 51 | records: Vec<(Word, OneOrMany<Embedding>)>, 52 | dims: usize, 53 | ) -> Result<RecordBatch, lancedb::arrow::arrow_schema::ArrowError> { 54 | let id = StringArray::from_iter_values( 55 | records 56 | .iter() 57 | .map(|(Word { id, .. }, _)| id) 58 | .collect::<Vec<_>>(), 59 | ); 60 | 61 | let definition = StringArray::from_iter_values( 62 | records 63 | .iter() 64 | .map(|(Word { definition, .. 
}, _)| definition) 65 | .collect::<Vec<_>>(), 66 | ); 67 | 68 | let embedding = FixedSizeListArray::from_iter_primitive::<Float64Type, _, _>( 69 | records 70 | .into_iter() 71 | .map(|(_, embeddings)| { 72 | Some( 73 | embeddings 74 | .first() 75 | .vec 76 | .into_iter() 77 | .map(Some) 78 | .collect::<Vec<_>>(), 79 | ) 80 | }) 81 | .collect::<Vec<_>>(), 82 | dims as i32, 83 | ); 84 | 85 | RecordBatch::try_from_iter(vec![ 86 | ("id", Arc::new(id) as ArrayRef), 87 | ("definition", Arc::new(definition) as ArrayRef), 88 | ("embedding", Arc::new(embedding) as ArrayRef), 89 | ]) 90 | } 91 | -------------------------------------------------------------------------------- /rig-lancedb/examples/fixtures/lib.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use arrow_array::{types::Float64Type, ArrayRef, FixedSizeListArray, RecordBatch, StringArray}; 4 | use lancedb::arrow::arrow_schema::{DataType, Field, Fields, Schema}; 5 | use rig::embeddings::Embedding; 6 | use rig::{Embed, OneOrMany}; 7 | use serde::Deserialize; 8 | 9 | #[derive(Embed, Clone, Deserialize, Debug)] 10 | pub struct Word { 11 | pub id: String, 12 | #[embed] 13 | pub definition: String, 14 | } 15 | 16 | pub fn words() -> Vec<Word> { 17 | vec![ 18 | Word { 19 | id: "doc0".to_string(), 20 | definition: "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.".to_string() 21 | }, 22 | Word { 23 | id: "doc1".to_string(), 24 | definition: "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive.".to_string() 25 | }, 26 | Word { 27 | id: "doc2".to_string(), 28 | definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string() 29 | } 30 | ] 31 | } 32 | 33 | // Schema of table in LanceDB. 34 | pub fn schema(dims: usize) -> Schema { 35 | Schema::new(Fields::from(vec![ 36 | Field::new("id", DataType::Utf8, false), 37 | Field::new("definition", DataType::Utf8, false), 38 | Field::new( 39 | "embedding", 40 | DataType::FixedSizeList( 41 | Arc::new(Field::new("item", DataType::Float64, true)), 42 | dims as i32, 43 | ), 44 | false, 45 | ), 46 | ])) 47 | } 48 | 49 | // Convert Word objects and their embedding to a RecordBatch. 50 | pub fn as_record_batch( 51 | records: Vec<(Word, OneOrMany<Embedding>)>, 52 | dims: usize, 53 | ) -> Result<RecordBatch, lancedb::arrow::arrow_schema::ArrowError> { 54 | let id = StringArray::from_iter_values( 55 | records 56 | .iter() 57 | .map(|(Word { id, .. }, _)| id) 58 | .collect::<Vec<_>>(), 59 | ); 60 | 61 | let definition = StringArray::from_iter_values( 62 | records 63 | .iter() 64 | .map(|(Word { definition, .. 
}, _)| definition) 65 | .collect::<Vec<_>>(), 66 | ); 67 | 68 | let embedding = FixedSizeListArray::from_iter_primitive::<Float64Type, _, _>( 69 | records 70 | .into_iter() 71 | .map(|(_, embeddings)| { 72 | Some( 73 | embeddings 74 | .first() 75 | .vec 76 | .into_iter() 77 | .map(Some) 78 | .collect::<Vec<_>>(), 79 | ) 80 | }) 81 | .collect::<Vec<_>>(), 82 | dims as i32, 83 | ); 84 | 85 | RecordBatch::try_from_iter(vec![ 86 | ("id", Arc::new(id) as ArrayRef), 87 | ("definition", Arc::new(definition) as ArrayRef), 88 | ("embedding", Arc::new(embedding) as ArrayRef), 89 | ]) 90 | } 91 | -------------------------------------------------------------------------------- /rig-core/examples/vector_search_cohere.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | 3 | use rig::{ 4 | embeddings::EmbeddingsBuilder, 5 | providers::cohere::{Client, EMBED_ENGLISH_V3}, 6 | vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, 7 | Embed, 8 | }; 9 | use serde::{Deserialize, Serialize}; 10 | 11 | // Shape of data that needs to be RAG'ed. 12 | // The definition field will be used to generate embeddings. 13 | #[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] 14 | struct WordDefinition { 15 | id: String, 16 | word: String, 17 | #[embed] 18 | definitions: Vec<String>, 19 | } 20 | 21 | #[tokio::main] 22 | async fn main() -> Result<(), anyhow::Error> { 23 | // Create Cohere client 24 | let cohere_api_key = env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set"); 25 | let cohere_client = Client::new(&cohere_api_key); 26 | 27 | let document_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_document"); 28 | let search_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_query"); 29 | 30 | let embeddings = EmbeddingsBuilder::new(document_model.clone()) 31 | .documents(vec![ 32 | WordDefinition { 33 | id: "doc0".to_string(), 34 | word: "flurbo".to_string(), 35 | definitions: vec![ 36 | "A green alien that lives on cold planets.".to_string(), 37 | "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() 38 | ] 39 | }, 40 | WordDefinition { 41 | id: "doc1".to_string(), 42 | word: "glarb-glarb".to_string(), 43 | definitions: vec![ 44 | "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), 45 | "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() 46 | ] 47 | }, 48 | WordDefinition { 49 | id: "doc2".to_string(), 50 | word: "linglingdong".to_string(), 51 | definitions: vec![ 52 | "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), 53 | "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() 54 | ] 55 | }, 56 | ])? 57 | .build() 58 | .await?; 59 | 60 | // Create vector store with the embeddings 61 | let vector_store = 62 | InMemoryVectorStore::from_documents_with_id_f(embeddings, |doc| doc.id.clone()); 63 | 64 | // Create vector store index 65 | let index = vector_store.index(search_model); 66 | 67 | let results = index 68 | .top_n::<WordDefinition>( 69 | "Which instrument is found in the Nebulon Mountain Ranges?", 70 | 1, 71 | ) 72 | .await? 
73 | .into_iter() 74 | .map(|(score, id, doc)| (score, id, doc.word)) 75 | .collect::<Vec<_>>(); 76 | 77 | println!("Results: {:?}", results); 78 | 79 | Ok(()) 80 | } 81 | -------------------------------------------------------------------------------- /rig-core/src/embeddings/tool.rs: -------------------------------------------------------------------------------- 1 | //! The module defines the [ToolSchema] struct, which is used to embed an object that implements [crate::tool::ToolEmbedding] 2 | 3 | use crate::{tool::ToolEmbeddingDyn, Embed}; 4 | use serde::Serialize; 5 | 6 | use super::embed::EmbedError; 7 | 8 | /// Embeddable document that is used as an intermediate representation of a tool when 9 | /// RAGging tools. 10 | #[derive(Clone, Serialize, Default, Eq, PartialEq)] 11 | pub struct ToolSchema { 12 | pub name: String, 13 | pub context: serde_json::Value, 14 | pub embedding_docs: Vec<String>, 15 | } 16 | 17 | impl Embed for ToolSchema { 18 | fn embed(&self, embedder: &mut super::embed::TextEmbedder) -> Result<(), EmbedError> { 19 | for doc in &self.embedding_docs { 20 | embedder.embed(doc.clone()); 21 | } 22 | Ok(()) 23 | } 24 | } 25 | 26 | impl ToolSchema { 27 | /// Convert item that implements [ToolEmbeddingDyn] to an [ToolSchema]. 28 | /// 29 | /// # Example 30 | /// ```rust 31 | /// use rig::{ 32 | /// completion::ToolDefinition, 33 | /// embeddings::ToolSchema, 34 | /// tool::{Tool, ToolEmbedding, ToolEmbeddingDyn}, 35 | /// }; 36 | /// use serde_json::json; 37 | /// 38 | /// #[derive(Debug, thiserror::Error)] 39 | /// #[error("Math error")] 40 | /// struct NothingError; 41 | /// 42 | /// #[derive(Debug, thiserror::Error)] 43 | /// #[error("Init error")] 44 | /// struct InitError; 45 | /// 46 | /// struct Nothing; 47 | /// impl Tool for Nothing { 48 | /// const NAME: &'static str = "nothing"; 49 | /// 50 | /// type Error = NothingError; 51 | /// type Args = (); 52 | /// type Output = (); 53 | /// 54 | /// async fn definition(&self, _prompt: String) -> ToolDefinition { 55 | /// serde_json::from_value(json!({ 56 | /// "name": "nothing", 57 | /// "description": "nothing", 58 | /// "parameters": {} 59 | /// })) 60 | /// .expect("Tool Definition") 61 | /// } 62 | /// 63 | /// async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> { 64 | /// Ok(()) 65 | /// } 66 | /// } 67 | /// 68 | /// impl ToolEmbedding for Nothing { 69 | /// type InitError = InitError; 70 | /// type Context = (); 71 | /// type State = (); 72 | /// 73 | /// fn init(_state: Self::State, _context: Self::Context) -> Result<Self, Self::InitError> { 74 | /// Ok(Nothing) 75 | /// } 76 | /// 77 | /// fn embedding_docs(&self) -> Vec<String> { 78 | /// vec!["Do nothing.".into()] 79 | /// } 80 | /// 81 | /// fn context(&self) -> Self::Context {} 82 | /// } 83 | /// 84 | /// let tool = ToolSchema::try_from(&Nothing).unwrap(); 85 | /// 86 | /// assert_eq!(tool.name, "nothing".to_string()); 87 | /// assert_eq!(tool.embedding_docs, vec!["Do nothing.".to_string()]); 88 | /// ``` 89 | pub fn try_from(tool: &dyn ToolEmbeddingDyn) -> Result<Self, EmbedError> { 90 | Ok(ToolSchema { 91 | name: tool.name(), 92 | context: tool.context().map_err(EmbedError::new)?, 93 | embedding_docs: tool.embedding_docs(), 94 | }) 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /rig-core/examples/vector_search.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | 3 | use rig::{ 4 | embeddings::EmbeddingsBuilder, 5 | 
providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, 6 | vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, 7 | Embed, 8 | }; 9 | use serde::{Deserialize, Serialize}; 10 | 11 | // Shape of data that needs to be RAG'ed. 12 | // The definition field will be used to generate embeddings. 13 | #[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] 14 | struct WordDefinition { 15 | id: String, 16 | word: String, 17 | #[embed] 18 | definitions: Vec<String>, 19 | } 20 | 21 | #[tokio::main] 22 | async fn main() -> Result<(), anyhow::Error> { 23 | // Create OpenAI client 24 | let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); 25 | let openai_client = Client::new(&openai_api_key); 26 | 27 | let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); 28 | 29 | let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) 30 | .documents(vec![ 31 | WordDefinition { 32 | id: "doc0".to_string(), 33 | word: "flurbo".to_string(), 34 | definitions: vec![ 35 | "A green alien that lives on cold planets.".to_string(), 36 | "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() 37 | ] 38 | }, 39 | WordDefinition { 40 | id: "doc1".to_string(), 41 | word: "glarb-glarb".to_string(), 42 | definitions: vec![ 43 | "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), 44 | "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() 45 | ] 46 | }, 47 | WordDefinition { 48 | id: "doc2".to_string(), 49 | word: "linglingdong".to_string(), 50 | definitions: vec![ 51 | "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), 52 | "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() 53 | ] 54 | }, 55 | ])? 56 | .build() 57 | .await?; 58 | 59 | // Create vector store with the embeddings 60 | let vector_store = 61 | InMemoryVectorStore::from_documents_with_id_f(embeddings, |doc| doc.id.clone()); 62 | 63 | // Create vector store index 64 | let index = vector_store.index(embedding_model); 65 | 66 | let results = index 67 | .top_n::<WordDefinition>("I need to buy something in a fictional universe. What type of money can I use for this?", 1) 68 | .await? 69 | .into_iter() 70 | .map(|(score, id, doc)| (score, id, doc.word)) 71 | .collect::<Vec<_>>(); 72 | 73 | println!("Results: {:?}", results); 74 | 75 | let id_results = index 76 | .top_n_ids("I need to buy something in a fictional universe. What type of money can I use for this?", 1) 77 | .await? 
78 | .into_iter() 79 | .collect::<Vec<_>>(); 80 | 81 | println!("ID results: {:?}", id_results); 82 | 83 | Ok(()) 84 | } 85 | -------------------------------------------------------------------------------- /rig-lancedb/tests/integration_tests.rs: -------------------------------------------------------------------------------- 1 | use serde_json::json; 2 | 3 | use arrow_array::RecordBatchIterator; 4 | use fixture::{as_record_batch, schema, words, Word}; 5 | use lancedb::index::vector::IvfPqIndexBuilder; 6 | use rig::{ 7 | embeddings::{EmbeddingModel, EmbeddingsBuilder}, 8 | providers::openai::{self, Client}, 9 | vector_store::VectorStoreIndex, 10 | }; 11 | use rig_lancedb::{LanceDbVectorIndex, SearchParams}; 12 | use std::sync::Arc; 13 | 14 | #[path = "./fixtures/lib.rs"] 15 | mod fixture; 16 | 17 | #[tokio::test] 18 | async fn vector_search_test() { 19 | // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). 20 | let openai_client = Client::from_env(); 21 | 22 | // Select an embedding model. 23 | let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); 24 | 25 | // Initialize LanceDB locally. 26 | let db = lancedb::connect("data/lancedb-store") 27 | .execute() 28 | .await 29 | .unwrap(); 30 | 31 | // Generate embeddings for the test data. 32 | let embeddings = EmbeddingsBuilder::new(model.clone()) 33 | .documents(words()).unwrap() 34 | // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. 35 | .documents( 36 | (0..256) 37 | .map(|i| Word { 38 | id: format!("doc{}", i), 39 | definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() 40 | }) 41 | ).unwrap() 42 | .build() 43 | .await.unwrap(); 44 | 45 | let table = db 46 | .create_table( 47 | "words", 48 | RecordBatchIterator::new( 49 | vec![as_record_batch(embeddings, model.ndims())], 50 | Arc::new(schema(model.ndims())), 51 | ), 52 | ) 53 | .execute() 54 | .await 55 | .unwrap(); 56 | 57 | // See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information 58 | table 59 | .create_index( 60 | &["embedding"], 61 | lancedb::index::Index::IvfPq(IvfPqIndexBuilder::default()), 62 | ) 63 | .execute() 64 | .await 65 | .unwrap(); 66 | 67 | // Define the search params that will be used by the vector store to perform the vector search.
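// (A sketch of a non-default configuration, mirroring the S3 ANN example: the distance type
// can be overridden with `SearchParams::default().distance_type(DistanceType::Cosine)`, as long
// as it matches the distance type the index was created with.)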
68 | let search_params = SearchParams::default(); 69 | let vector_store_index = LanceDbVectorIndex::new(table, model, "id", search_params) 70 | .await 71 | .unwrap(); 72 | 73 | // Query the index 74 | let results = vector_store_index 75 | .top_n::<serde_json::Value>( 76 | "My boss says I zindle too much, what does that mean?", 77 | 1, 78 | ) 79 | .await 80 | .unwrap(); 81 | 82 | let (distance, _, value) = &results.first().unwrap(); 83 | 84 | assert_eq!( 85 | *value, 86 | json!({ 87 | "_distance": distance, 88 | "definition": "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive.", 89 | "id": "doc1" 90 | }) 91 | ); 92 | 93 | db.drop_db().await.unwrap(); 94 | } 95 | -------------------------------------------------------------------------------- /rig-neo4j/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | <div style="display: flex; align-items: center; justify-content: center;"> 4 | <picture> 5 | <source media="(prefers-color-scheme: dark)" srcset="../img/rig_logo_dark.svg"> 6 | <source media="(prefers-color-scheme: light)" srcset="../img/rig_logo.svg"> 7 | <img src="../img/rig_logo.svg" width="200" alt="Rig logo"> 8 | </picture> 9 | <span style="font-size: 48px; margin: 0 20px; font-weight: regular; font-family: Open Sans, sans-serif;"> + </span> 10 | <picture> 11 | <source media="(prefers-color-scheme: dark)" srcset="https://cdn.prod.website-files.com/653986a9412d138f23c5b8cb/65c3ee6c93dc929503742ff6_1_E5u7PfGGOQ32_H5dUVGerQ%402x.png"> 12 | <source media="(prefers-color-scheme: light)" srcset="https://commons.wikimedia.org/wiki/File:Neo4j-logo_color.png"> 13 | <img src="https://commons.wikimedia.org/wiki/File:Neo4j-logo_color.png" width="200" alt="Neo4j logo"> 14 | </picture> 15 | 16 | </div> 17 | 18 | <br><br> 19 | 20 | This companion crate implements a Rig vector store based on the Neo4j graph database. It uses the [neo4rs](https://github.com/neo4j-labs/neo4rs) crate to interact with Neo4j. Note that the neo4rs crate is a work in progress and does not yet support all Neo4j features. Further documentation on Neo4j & vector search integration can be found in the [Neo4j docs](https://neo4j.com/docs/cypher-manual/current/indexes/semantic-indexes/vector-indexes/). 21 | 22 | ## Prerequisites 23 | 24 | The GenAI plugin is enabled by default in Neo4j Aura. 25 | 26 | The plugin needs to be installed on self-managed instances. This is done by moving the `neo4j-genai.jar` file from `/products` to `/plugins` in the Neo4j home directory, or, if you are using Docker, by starting the Docker container with the extra parameter `--env NEO4J_PLUGINS='["genai"]'`. For more information, see Operations Manual → Configure plugins. 27 | 28 | 29 | ## Usage 30 | 31 | Add the companion crate to your `Cargo.toml`, along with the rig-core crate: 32 | 33 | ```toml 34 | [dependencies] 35 | rig-neo4j = "0.1" 36 | rig-core = "0.4.0" 37 | ``` 38 | 39 | You can also run `cargo add rig-neo4j rig-core` to add the most recent versions of the dependencies to your project. 40 | 41 | See the [examples](./examples) folder for usage examples. 42 | 43 | - [examples/vector_search_simple.rs](examples/vector_search_simple.rs) shows how to create an index on simple data. 44 | - [examples/vector_search_movies_consume.rs](examples/vector_search_movies_consume.rs) shows how to query an existing index.
45 | - [examples/vector_search_movies_create.rs](examples/vector_search_movies_create.rs) shows how to create embeddings & index on a large DB and query it in one go. 46 | 47 | ## Notes 48 | 49 | - The `rig-neo4j::vector_index` module offers utility functions to create and query a Neo4j vector index. You can also create indexes using the Neo4j browser or run Cypher queries directly with the neo4rs crate. See the [Neo4j documentation](https://neo4j.com/docs/genai/tutorials/embeddings-vector-indexes/setup/vector-index/) for more information. Example [examples/vector_search_simple.rs](examples/vector_search_simple.rs) shows how to create an index on existing data. 50 | 51 | ```cypher 52 | CREATE VECTOR INDEX moviePlots 53 | FOR (m:Movie) 54 | ON m.embedding 55 | OPTIONS {indexConfig: { 56 | `vector.dimensions`: 1536, 57 | `vector.similarity_function`: 'cosine' 58 | }} 59 | ``` 60 | 61 | ## Roadmap 62 | 63 | - Add support for creating the vector index through Rig. 64 | - Add support for adding embeddings to an existing database. 65 | - Add support for uploading documents to an existing database. 66 | -------------------------------------------------------------------------------- /rig-lancedb/examples/vector_search_s3_ann.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use arrow_array::RecordBatchIterator; 4 | use fixture::{as_record_batch, schema, words, Word}; 5 | use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; 6 | use rig::{ 7 | embeddings::{EmbeddingModel, EmbeddingsBuilder}, 8 | providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, 9 | vector_store::VectorStoreIndex, 10 | }; 11 | use rig_lancedb::{LanceDbVectorIndex, SearchParams}; 12 | 13 | #[path = "./fixtures/lib.rs"] 14 | mod fixture; 15 | 16 | // Note: see docs to deploy LanceDB on other cloud providers such as Google and Azure. 17 | // https://lancedb.github.io/lancedb/guides/storage/ 18 | #[tokio::main] 19 | async fn main() -> Result<(), anyhow::Error> { 20 | // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). 21 | let openai_client = Client::from_env(); 22 | 23 | // Select the embedding model and generate our embeddings 24 | let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); 25 | 26 | // Initialize LanceDB on S3. 27 | // Note: see below docs for more options and IAM permissions required to read/write to S3. 28 | // https://lancedb.github.io/lancedb/guides/storage/#aws-s3 29 | let db = lancedb::connect("s3://lancedb-test-829666124233") 30 | .execute() 31 | .await?; 32 | 33 | // Generate embeddings for the test data. 34 | let embeddings = EmbeddingsBuilder::new(model.clone()) 35 | .documents(words())? 36 | // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. 37 | .documents( 38 | (0..256) 39 | .map(|i| Word { 40 | id: format!("doc{}", i), 41 | definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() 42 | }) 43 | )?
44 | .build() 45 | .await?; 46 | 47 | let table = db 48 | .create_table( 49 | "definitions", 50 | RecordBatchIterator::new( 51 | vec![as_record_batch(embeddings, model.ndims())], 52 | Arc::new(schema(model.ndims())), 53 | ), 54 | ) 55 | .execute() 56 | .await?; 57 | 58 | // See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information 59 | table 60 | .create_index( 61 | &["embedding"], 62 | lancedb::index::Index::IvfPq( 63 | IvfPqIndexBuilder::default() 64 | // This overrides the default distance type of L2. 65 | // Needs to be the same distance type as the one used in search params. 66 | .distance_type(DistanceType::Cosine), 67 | ), 68 | ) 69 | .execute() 70 | .await?; 71 | 72 | // Define the search params that will be used by the vector store to perform the vector search. 73 | let search_params = SearchParams::default().distance_type(DistanceType::Cosine); 74 | 75 | let vector_store = LanceDbVectorIndex::new(table, model, "id", search_params).await?; 76 | 77 | // Query the index 78 | let results = vector_store 79 | .top_n::<Word>("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. What's the word for this feeling?", 1) 80 | .await?; 81 | 82 | println!("Results: {:?}", results); 83 | 84 | Ok(()) 85 | } 86 | -------------------------------------------------------------------------------- /rig-core/examples/agent_with_tools.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use rig::{ 3 | completion::{Prompt, ToolDefinition}, 4 | providers, 5 | tool::Tool, 6 | }; 7 | use serde::{Deserialize, Serialize}; 8 | use serde_json::json; 9 | 10 | #[derive(Deserialize)] 11 | struct OperationArgs { 12 | x: i32, 13 | y: i32, 14 | } 15 | 16 | #[derive(Debug, thiserror::Error)] 17 | #[error("Math error")] 18 | struct MathError; 19 | 20 | #[derive(Deserialize, Serialize)] 21 | struct Adder; 22 | impl Tool for Adder { 23 | const NAME: &'static str = "add"; 24 | 25 | type Error = MathError; 26 | type Args = OperationArgs; 27 | type Output = i32; 28 | 29 | async fn definition(&self, _prompt: String) -> ToolDefinition { 30 | ToolDefinition { 31 | name: "add".to_string(), 32 | description: "Add x and y together".to_string(), 33 | parameters: json!({ 34 | "type": "object", 35 | "properties": { 36 | "x": { 37 | "type": "number", 38 | "description": "The first number to add" 39 | }, 40 | "y": { 41 | "type": "number", 42 | "description": "The second number to add" 43 | } 44 | } 45 | }), 46 | } 47 | } 48 | 49 | async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> { 50 | let result = args.x + args.y; 51 | Ok(result) 52 | } 53 | } 54 | 55 | #[derive(Deserialize, Serialize)] 56 | struct Subtract; 57 | impl Tool for Subtract { 58 | const NAME: &'static str = "subtract"; 59 | 60 | type Error = MathError; 61 | type Args = OperationArgs; 62 | type Output = i32; 63 | 64 | async fn definition(&self, _prompt: String) -> ToolDefinition { 65 | serde_json::from_value(json!({ 66 | "name": "subtract", 67 | "description": "Subtract y from x (i.e.: x - y)", 68 | "parameters": { 69 | "type": "object", 70 | "properties": { 71 | "x": { 72 | "type": "number", 73 | "description": "The number to subtract from" 74 | }, 75 | "y": { 76 | "type": "number", 77 | "description": "The number to subtract" 78 | } 79 | } 80 | } 81 | })) 82 | .expect("Tool Definition") 83 | } 84 | 85 | async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> { 86 | let result = args.x - args.y; 87 | Ok(result) 88 | } 89 | }
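// Note: `Adder` builds its `ToolDefinition` directly as a struct literal, while
// `Subtract` deserializes one from a `serde_json::json!` value; the two
// approaches are interchangeable ways of producing the same definition.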
90 | 91 | #[tokio::main] 92 | async fn main() -> Result<(), anyhow::Error> { 93 | // Create OpenAI client 94 | let openai_client = providers::openai::Client::from_env(); 95 | 96 | // Create agent with a single context prompt and two tools 97 | let calculator_agent = openai_client 98 | .agent(providers::openai::GPT_4O) 99 | .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.") 100 | .max_tokens(1024) 101 | .tool(Adder) 102 | .tool(Subtract) 103 | .build(); 104 | 105 | // Prompt the agent and print the response 106 | println!("Calculate 2 - 5"); 107 | println!( 108 | "Calculator Agent: {}", 109 | calculator_agent.prompt("Calculate 2 - 5").await? 110 | ); 111 | 112 | Ok(()) 113 | } 114 | -------------------------------------------------------------------------------- /rig-core/src/vector_store/mod.rs: -------------------------------------------------------------------------------- 1 | use futures::future::BoxFuture; 2 | use serde::Deserialize; 3 | use serde_json::Value; 4 | 5 | use crate::embeddings::EmbeddingError; 6 | 7 | pub mod in_memory_store; 8 | 9 | #[derive(Debug, thiserror::Error)] 10 | pub enum VectorStoreError { 11 | #[error("Embedding error: {0}")] 12 | EmbeddingError(#[from] EmbeddingError), 13 | 14 | /// Json error (e.g.: serialization, deserialization, etc.) 15 | #[error("Json error: {0}")] 16 | JsonError(#[from] serde_json::Error), 17 | 18 | #[error("Datastore error: {0}")] 19 | DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>), 20 | 21 | #[error("Missing Id: {0}")] 22 | MissingIdError(String), 23 | } 24 | 25 | /// Trait for vector store indexes 26 | pub trait VectorStoreIndex: Send + Sync { 27 | /// Get the top n documents based on the distance to the given query. 28 | /// The result is a list of tuples of the form `(score, id, document)` 29 | fn top_n<T: for<'a> Deserialize<'a> + Send>( 30 | &self, 31 | query: &str, 32 | n: usize, 33 | ) -> impl std::future::Future<Output = Result<Vec<(f64, String, T)>, VectorStoreError>> + Send; 34 | 35 | /// Same as `top_n` but returns the document ids only. 36 | fn top_n_ids( 37 | &self, 38 | query: &str, 39 | n: usize, 40 | ) -> impl std::future::Future<Output = Result<Vec<(f64, String)>, VectorStoreError>> + Send; 41 | } 42 | 43 | pub type TopNResults = Result<Vec<(f64, String, Value)>, VectorStoreError>; 44 | 45 | pub trait VectorStoreIndexDyn: Send + Sync { 46 | fn top_n<'a>(&'a self, query: &'a str, n: usize) -> BoxFuture<'a, TopNResults>; 47 | 48 | fn top_n_ids<'a>( 49 | &'a self, 50 | query: &'a str, 51 | n: usize, 52 | ) -> BoxFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>>; 53 | } 54 | 55 | impl<I: VectorStoreIndex> VectorStoreIndexDyn for I { 56 | fn top_n<'a>( 57 | &'a self, 58 | query: &'a str, 59 | n: usize, 60 | ) -> BoxFuture<'a, Result<Vec<(f64, String, Value)>, VectorStoreError>> { 61 | Box::pin(async move { 62 | Ok(self 63 | .top_n::<serde_json::Value>(query, n) 64 | .await?
65 | .into_iter() 66 | .map(|(score, id, doc)| (score, id, prune_document(doc).unwrap_or_default())) 67 | .collect::<Vec<_>>()) 68 | }) 69 | } 70 | 71 | fn top_n_ids<'a>( 72 | &'a self, 73 | query: &'a str, 74 | n: usize, 75 | ) -> BoxFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> { 76 | Box::pin(self.top_n_ids(query, n)) 77 | } 78 | } 79 | 80 | fn prune_document(document: serde_json::Value) -> Option<serde_json::Value> { 81 | match document { 82 | Value::Object(mut map) => { 83 | let new_map = map 84 | .iter_mut() 85 | .filter_map(|(key, value)| { 86 | prune_document(value.take()).map(|value| (key.clone(), value)) 87 | }) 88 | .collect::<serde_json::Map<_, _>>(); 89 | 90 | Some(Value::Object(new_map)) 91 | } 92 | Value::Array(vec) if vec.len() > 400 => None, 93 | Value::Array(vec) => Some(Value::Array( 94 | vec.into_iter().filter_map(prune_document).collect(), 95 | )), 96 | Value::Number(num) => Some(Value::Number(num)), 97 | Value::String(s) => Some(Value::String(s)), 98 | Value::Bool(b) => Some(Value::Bool(b)), 99 | Value::Null => Some(Value::Null), 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /rig-core/examples/debate.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | 3 | use anyhow::Result; 4 | use rig::{ 5 | agent::Agent, 6 | completion::{Chat, Message}, 7 | providers::{cohere, openai}, 8 | }; 9 | 10 | struct Debater { 11 | gpt_4: Agent<openai::CompletionModel>, 12 | coral: Agent<cohere::CompletionModel>, 13 | } 14 | 15 | impl Debater { 16 | fn new(position_a: &str, position_b: &str) -> Self { 17 | let openai_client = 18 | openai::Client::new(&env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set")); 19 | let cohere_client = 20 | cohere::Client::new(&env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set")); 21 | 22 | Self { 23 | gpt_4: openai_client.agent("gpt-4").preamble(position_a).build(), 24 | coral: cohere_client 25 | .agent("command-r") 26 | .preamble(position_b) 27 | .build(), 28 | } 29 | } 30 | 31 | async fn rounds(&self, n: usize) -> Result<()> { 32 | let mut history_a: Vec<Message> = vec![]; 33 | let mut history_b: Vec<Message> = vec![]; 34 | 35 | let mut last_resp_b: Option<String> = None; 36 | 37 | for _ in 0..n { 38 | let prompt_a = if let Some(msg_b) = &last_resp_b { 39 | msg_b.clone() 40 | } else { 41 | "Plead your case!".into() 42 | }; 43 | 44 | let resp_a = self.gpt_4.chat(&prompt_a, history_a.clone()).await?; 45 | println!("GPT-4:\n{}", resp_a); 46 | history_a.push(Message { 47 | role: "user".into(), 48 | content: prompt_a.clone(), 49 | }); 50 | history_a.push(Message { 51 | role: "assistant".into(), 52 | content: resp_a.clone(), 53 | }); 54 | println!("================================================================"); 55 | 56 | let resp_b = self.coral.chat(&resp_a, history_b.clone()).await?; 57 | println!("Coral:\n{}", resp_b); 58 | println!("================================================================"); 59 | 60 | history_b.push(Message { 61 | role: "user".into(), 62 | content: resp_a.clone(), 63 | }); 64 | history_b.push(Message { 65 | role: "assistant".into(), 66 | content: resp_b.clone(), 67 | }); 68 | 69 | last_resp_b = Some(resp_b) 70 | } 71 | 72 | Ok(()) 73 | } 74 | } 75 |
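// Each debater keeps its own chat history: one agent's reply is pushed as the
// *user* message onto the other agent's history (and vice versa), so each
// model always sees the debate from its own side of the conversation.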
76 | #[tokio::main] 77 | async fn main() -> Result<(), anyhow::Error> { 78 | // Create the two debaters 79 | let debater = Debater::new( 80 | "\ 81 | You believe that religion is a useful concept. \ 82 | This could be for security, financial, ethical, philosophical, metaphysical, religious or any other kind of reason. \ 83 | You choose what your arguments are. \ 84 | I will argue against you and you must rebuke me and try to convince me that I am wrong. \ 85 | Make your statements short and concise. \ 86 | ", 87 | "\ 88 | You believe that religion is a harmful concept. \ 89 | This could be for security, financial, ethical, philosophical, metaphysical, religious or any other kind of reason. \ 90 | You choose what your arguments are. \ 91 | I will argue against you and you must rebuke me and try to convince me that I am wrong. \ 92 | Make your statements short and concise. \ 93 | ", 94 | ); 95 | 96 | // Run the debate for 4 rounds 97 | debater.rounds(4).await?; 98 | 99 | Ok(()) 100 | } 101 | -------------------------------------------------------------------------------- /rig-core/examples/rag.rs: -------------------------------------------------------------------------------- 1 | use std::{env, vec}; 2 | 3 | use rig::{ 4 | completion::Prompt, 5 | embeddings::EmbeddingsBuilder, 6 | providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, 7 | vector_store::in_memory_store::InMemoryVectorStore, 8 | Embed, 9 | }; 10 | use serde::Serialize; 11 | 12 | // Data to be RAGged. 13 | // A vector search needs to be performed on the `definitions` field, so we derive the `Embed` trait for `WordDefinition` 14 | // and tag that field with `#[embed]`. 15 | #[derive(Embed, Serialize, Clone, Debug, Eq, PartialEq, Default)] 16 | struct WordDefinition { 17 | id: String, 18 | word: String, 19 | #[embed] 20 | definitions: Vec<String>, 21 | } 22 | 23 | #[tokio::main] 24 | async fn main() -> Result<(), anyhow::Error> { 25 | // Initialize tracing 26 | tracing_subscriber::fmt() 27 | .with_max_level(tracing::Level::INFO) 28 | .with_target(false) 29 | .init(); 30 | 31 | // Create OpenAI client 32 | let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); 33 | let openai_client = Client::new(&openai_api_key); 34 | 35 | let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); 36 | 37 | // Generate embeddings for the definitions of all the documents using the specified embedding model. 38 | let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) 39 | .documents(vec![ 40 | WordDefinition { 41 | id: "doc0".to_string(), 42 | word: "flurbo".to_string(), 43 | definitions: vec![ 44 | "1. *flurbo* (name): A flurbo is a green alien that lives on cold planets.".to_string(), 45 | "2. *flurbo* (name): A fictional digital currency that originated in the animated series Rick and Morty.".to_string() 46 | ] 47 | }, 48 | WordDefinition { 49 | id: "doc1".to_string(), 50 | word: "glarb-glarb".to_string(), 51 | definitions: vec![ 52 | "1. *glarb-glarb* (noun): A glarb-glarb is an ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), 53 | "2. *glarb-glarb* (noun): A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() 54 | ] 55 | }, 56 | WordDefinition { 57 | id: "doc2".to_string(), 58 | word: "linglingdong".to_string(), 59 | definitions: vec![ 60 | "1. *linglingdong* (noun): A term used by inhabitants of the far side of the moon to describe humans.".to_string(), 61 | "2. *linglingdong* (noun): A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() 62 | ] 63 | }, 64 | ])?
65 | .build() 66 | .await?; 67 | 68 | // Create vector store with the embeddings 69 | let vector_store = InMemoryVectorStore::from_documents(embeddings); 70 | 71 | // Create vector store index 72 | let index = vector_store.index(embedding_model); 73 | 74 | let rag_agent = openai_client.agent("gpt-4") 75 | .preamble(" 76 | You are a dictionary assistant here to assist the user in understanding the meaning of words. 77 | You will find additional non-standard word definitions that could be useful below. 78 | ") 79 | .dynamic_context(1, index) 80 | .build(); 81 | 82 | // Prompt the agent and print the response 83 | let response = rag_agent.prompt("What does \"glarb-glarb\" mean?").await?; 84 | 85 | println!("{}", response); 86 | 87 | Ok(()) 88 | } 89 | -------------------------------------------------------------------------------- /rig-qdrant/examples/qdrant_vector_search.rs: -------------------------------------------------------------------------------- 1 | // To run this example: 2 | // 3 | // export OPENAI_API_KEY=<YOUR-API-KEY> 4 | // docker run -p 6333:6333 -p 6334:6334 qdrant/qdrant 5 | // cargo run --release --example qdrant_vector_search 6 | // 7 | // You can view the data at http://localhost:6333/dashboard 8 | 9 | use std::env; 10 | 11 | use qdrant_client::{ 12 | qdrant::{ 13 | CreateCollectionBuilder, Distance, PointStruct, QueryPointsBuilder, UpsertPointsBuilder, 14 | VectorParamsBuilder, 15 | }, 16 | Payload, Qdrant, 17 | }; 18 | use rig::{ 19 | embeddings::EmbeddingsBuilder, 20 | providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, 21 | vector_store::VectorStoreIndex, 22 | Embed, 23 | }; 24 | use rig_qdrant::QdrantVectorStore; 25 | 26 | #[derive(Embed, serde::Deserialize, serde::Serialize, Debug)] 27 | struct Word { 28 | id: String, 29 | #[embed] 30 | definition: String, 31 | } 32 | 33 | #[tokio::main] 34 | async fn main() -> Result<(), anyhow::Error> { 35 | const COLLECTION_NAME: &str = "rig-collection"; 36 | 37 | let client = Qdrant::from_url("http://localhost:6334").build()?; 38 | 39 | // Create a collection with 1536 dimensions if it doesn't exist 40 | // Note: Make sure the dimensions match the size of the embeddings returned by the 41 | // model you are using 42 | if !client.collection_exists(COLLECTION_NAME).await? { 43 | client 44 | .create_collection( 45 | CreateCollectionBuilder::new(COLLECTION_NAME) 46 | .vectors_config(VectorParamsBuilder::new(1536, Distance::Cosine)), 47 | ) 48 | .await?; 49 | } 50 | 51 | // Initialize OpenAI client. 52 | // Get your API key from https://platform.openai.com/api-keys 53 | let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); 54 | let openai_client = Client::new(&openai_api_key); 55 | 56 | let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); 57 | 58 | let documents = EmbeddingsBuilder::new(model.clone()) 59 | .document(Word { 60 | id: "0981d983-a5f8-49eb-89ea-f7d3b2196d2e".to_string(), 61 | definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), 62 | })? 63 | .document(Word { 64 | id: "62a36d43-80b6-4fd6-990c-f75bb02287d1".to_string(), 65 | definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), 66 | })? 67 | .document(Word { 68 | id: "f9e17d59-32e5-440c-be02-b2759a654824".to_string(), 69 | definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), 70 | })? 
71 | .build() 72 | .await?; 73 | 74 | let points: Vec<PointStruct> = documents 75 | .into_iter() 76 | .map(|(d, embeddings)| { 77 | let vec: Vec<f32> = embeddings.first().vec.iter().map(|&x| x as f32).collect(); 78 | PointStruct::new( 79 | d.id.clone(), 80 | vec, 81 | Payload::try_from(serde_json::to_value(&d).unwrap()).unwrap(), 82 | ) 83 | }) 84 | .collect(); 85 | 86 | client 87 | .upsert_points(UpsertPointsBuilder::new(COLLECTION_NAME, points)) 88 | .await?; 89 | 90 | let query_params = QueryPointsBuilder::new(COLLECTION_NAME).with_payload(true); 91 | let vector_store = QdrantVectorStore::new(client, model, query_params.build()); 92 | 93 | let results = vector_store 94 | .top_n::<Word>("What is a linglingdong?", 1) 95 | .await?; 96 | 97 | println!("Results: {:?}", results); 98 | 99 | Ok(()) 100 | } 101 | -------------------------------------------------------------------------------- /rig-core/rig-core-derive/src/embed.rs: -------------------------------------------------------------------------------- 1 | use proc_macro2::TokenStream; 2 | use quote::quote; 3 | use syn::DataStruct; 4 | 5 | use crate::{ 6 | basic::{add_struct_bounds, basic_embed_fields}, 7 | custom::custom_embed_fields, 8 | }; 9 | 10 | pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result<TokenStream> { 11 | let name = &input.ident; 12 | let data = &input.data; 13 | let generics = &mut input.generics; 14 | 15 | let target_stream = match data { 16 | syn::Data::Struct(data_struct) => { 17 | let (basic_targets, basic_target_size) = data_struct.basic(generics); 18 | let (custom_targets, custom_target_size) = data_struct.custom()?; 19 | 20 | // If there are no fields tagged with `#[embed]` or `#[embed(embed_with = "...")]`, return an error: 21 | // the `Embed` trait cannot be implemented for a struct with no embeddable fields. 22 | if basic_target_size + custom_target_size == 0 { 23 | return Err(syn::Error::new_spanned( 24 | name, 25 | "Add at least one field tagged with #[embed] or #[embed(embed_with = \"...\")].", 26 | )); 27 | } 28 | 29 | quote! { 30 | #basic_targets; 31 | #custom_targets; 32 | } 33 | } 34 | _ => { 35 | return Err(syn::Error::new_spanned( 36 | input, 37 | "Embed derive macro should only be used on structs", 38 | )) 39 | } 40 | }; 41 | 42 | let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); 43 | 44 | let gen = quote! { 45 | // Note: `Embed` trait is imported with the macro. 46 | 47 | impl #impl_generics Embed for #name #ty_generics #where_clause { 48 | fn embed(&self, embedder: &mut rig::embeddings::embed::TextEmbedder) -> Result<(), rig::embeddings::embed::EmbedError> { 49 | #target_stream; 50 | 51 | Ok(()) 52 | } 53 | } 54 | }; 55 | 56 | Ok(gen) 57 | } 58 | 59 | trait StructParser { 60 | // Handles fields tagged with `#[embed]` 61 | fn basic(&self, generics: &mut syn::Generics) -> (TokenStream, usize); 62 | 63 | // Handles fields tagged with `#[embed(embed_with = "...")]` 64 | fn custom(&self) -> syn::Result<(TokenStream, usize)>; 65 | } 66 | 67 | impl StructParser for DataStruct { 68 | fn basic(&self, generics: &mut syn::Generics) -> (TokenStream, usize) { 69 | let embed_targets = basic_embed_fields(self) 70 | // Iterate over every field tagged with `#[embed]` 71 | .map(|field| { 72 | add_struct_bounds(generics, &field.ty); 73 | 74 | let field_name = &field.ident; 75 | 76 | quote! { 77 | self.#field_name 78 | } 79 | }) 80 | .collect::<Vec<_>>(); 81 | 82 | ( 83 | quote!
{ 84 | #(#embed_targets.embed(embedder)?;)* 85 | }, 86 | embed_targets.len(), 87 | ) 88 | } 89 | 90 | fn custom(&self) -> syn::Result<(TokenStream, usize)> { 91 | let embed_targets = custom_embed_fields(self)? 92 | // Iterate over every field tagged with `#[embed(embed_with = "...")]` 93 | .into_iter() 94 | .map(|(field, custom_func_path)| { 95 | let field_name = &field.ident; 96 | 97 | quote! { 98 | #custom_func_path(embedder, self.#field_name.clone())?; 99 | } 100 | }) 101 | .collect::<Vec<_>>(); 102 | 103 | Ok(( 104 | quote! { 105 | #(#embed_targets)* 106 | }, 107 | embed_targets.len(), 108 | )) 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | # Inspired by https://github.com/joshka/github-workflows/blob/main/.github/workflows/rust-check.yml 2 | name: Lint & Test 3 | 4 | on: 5 | pull_request: 6 | branches: 7 | - "**" 8 | workflow_call: 9 | 10 | env: 11 | CARGO_TERM_COLOR: always 12 | 13 | # ensure that the workflow is only triggered once per PR, subsequent pushes to the PR will cancel 14 | # and restart the workflow. See https://docs.github.com/en/actions/using-jobs/using-concurrency 15 | concurrency: 16 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 17 | cancel-in-progress: true 18 | 19 | jobs: 20 | fmt: 21 | name: stable / fmt 22 | runs-on: ubuntu-latest 23 | steps: 24 | - name: Checkout 25 | uses: actions/checkout@v4 26 | 27 | - name: Install Rust stable 28 | uses: actions-rust-lang/setup-rust-toolchain@v1 29 | with: 30 | components: rustfmt 31 | 32 | - name: Run cargo fmt 33 | run: cargo fmt -- --check 34 | 35 | clippy: 36 | name: stable / clippy 37 | runs-on: ubuntu-latest 38 | permissions: 39 | checks: write 40 | steps: 41 | - name: Checkout 42 | uses: actions/checkout@v4 43 | 44 | - name: Install Rust stable 45 | uses: actions-rust-lang/setup-rust-toolchain@v1 46 | with: 47 | components: clippy 48 | 49 | # Required to compile rig-lancedb 50 | - name: Install Protoc 51 | uses: arduino/setup-protoc@v3 52 | 53 | - name: Run clippy action 54 | uses: clechasseur/rs-clippy-check@v3 55 | with: 56 | args: --all-features 57 | 58 | test: 59 | name: stable / test 60 | runs-on: ubuntu-latest 61 | steps: 62 | - name: Checkout 63 | uses: actions/checkout@v4 64 | 65 | - name: Install Rust stable 66 | uses: actions-rust-lang/setup-rust-toolchain@v1 67 | 68 | - name: Install nextest 69 | uses: taiki-e/install-action@v2 70 | with: 71 | tool: nextest 72 | 73 | # Required to compile rig-lancedb 74 | - name: Install Protoc 75 | uses: arduino/setup-protoc@v3 76 | 77 | - name: Test with latest nextest release 78 | uses: actions-rs/cargo@v1 79 | with: 80 | command: nextest 81 | args: run --all-features 82 | env: 83 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} 84 | ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} 85 | GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} 86 | COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }} 87 | PERPLEXITY_API_KEY: ${{ secrets.PERPLEXITY_API_KEY }} 88 | 89 | doc: 90 | name: stable / doc 91 | runs-on: ubuntu-latest 92 | steps: 93 | - name: Checkout 94 | uses: actions/checkout@v4 95 | 96 | - name: Install Rust stable 97 | uses: actions-rust-lang/setup-rust-toolchain@v1 98 | with: 99 | components: rust-docs 100 | 101 | # Required to compile rig-lancedb 102 | - name: Install Protoc 103 | uses: arduino/setup-protoc@v3 104 | 105 | - name: Run cargo doc 106 | run: cargo doc --no-deps 
--all-features 107 | env: 108 | RUSTDOCFLAGS: -D warnings 109 | 110 | publish-check: 111 | name: stable / publish dry-run 112 | runs-on: ubuntu-latest 113 | steps: 114 | - name: Checkout 115 | uses: actions/checkout@v4 116 | 117 | - name: Install Rust stable 118 | uses: actions-rust-lang/setup-rust-toolchain@v1 119 | with: 120 | components: rust-docs 121 | 122 | # Required to compile rig-lancedb 123 | - name: Install Protoc 124 | uses: arduino/setup-protoc@v3 125 | 126 | - name: Run cargo publish --dry-run 127 | run: find . | grep -v "target" | grep "\./.*/Cargo\.toml" | xargs -n 1 sh -c 'cargo publish --manifest-path $0 --dry-run || exit 255' -------------------------------------------------------------------------------- /rig-sqlite/examples/vector_search_sqlite.rs: -------------------------------------------------------------------------------- 1 | use rig::{ 2 | embeddings::EmbeddingsBuilder, 3 | providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, 4 | vector_store::VectorStoreIndex, 5 | Embed, 6 | }; 7 | use rig_sqlite::{Column, ColumnValue, SqliteVectorStore, SqliteVectorStoreTable}; 8 | use rusqlite::ffi::sqlite3_auto_extension; 9 | use serde::Deserialize; 10 | use sqlite_vec::sqlite3_vec_init; 11 | use std::env; 12 | use tokio_rusqlite::Connection; 13 | 14 | #[derive(Embed, Clone, Debug, Deserialize)] 15 | struct Document { 16 | id: String, 17 | #[embed] 18 | content: String, 19 | } 20 | 21 | impl SqliteVectorStoreTable for Document { 22 | fn name() -> &'static str { 23 | "documents" 24 | } 25 | 26 | fn schema() -> Vec<Column> { 27 | vec![ 28 | Column::new("id", "TEXT PRIMARY KEY"), 29 | Column::new("content", "TEXT"), 30 | ] 31 | } 32 | 33 | fn id(&self) -> String { 34 | self.id.clone() 35 | } 36 | 37 | fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> { 38 | vec![ 39 | ("id", Box::new(self.id.clone())), 40 | ("content", Box::new(self.content.clone())), 41 | ] 42 | } 43 | } 44 | 45 | #[tokio::main] 46 | async fn main() -> Result<(), anyhow::Error> { 47 | tracing_subscriber::fmt() 48 | .with_env_filter( 49 | tracing_subscriber::EnvFilter::from_default_env() 50 | .add_directive(tracing::Level::DEBUG.into()), 51 | ) 52 | .init(); 53 | 54 | // Initialize OpenAI client 55 | let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); 56 | let openai_client = Client::new(&openai_api_key); 57 | 58 | // Initialize the `sqlite-vec` extension 59 | // See: https://alexgarcia.xyz/sqlite-vec/rust.html 60 | unsafe { 61 | sqlite3_auto_extension(Some(std::mem::transmute(sqlite3_vec_init as *const ()))); 62 | } 63 | 64 | // Initialize SQLite connection 65 | let conn = Connection::open("vector_store.db").await?; 66 | 67 | // Select the embedding model and generate our embeddings 68 | let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); 69 | 70 | let documents = vec![ 71 | Document { 72 | id: "doc0".to_string(), 73 | content: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), 74 | }, 75 | Document { 76 | id: "doc1".to_string(), 77 | content: "Definition of a *glarb-glarb*: A glarb-glarb is an ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), 78 | }, 79 | Document { 80 | id: "doc2".to_string(), 81 | content: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), 82 | }, 83 | ]; 84 | 85 | let embeddings = EmbeddingsBuilder::new(model.clone()) 86 | .documents(documents)?
87 | .build() 88 | .await?; 89 | 90 | // Initialize SQLite vector store 91 | let vector_store = SqliteVectorStore::new(conn, &model).await?; 92 | 93 | // Add embeddings to vector store 94 | vector_store.add_rows(embeddings).await?; 95 | 96 | // Create a vector index on our vector store 97 | let index = vector_store.index(model); 98 | 99 | // Query the index 100 | let results = index 101 | .top_n::<Document>("What is a linglingdong?", 1) 102 | .await? 103 | .into_iter() 104 | .map(|(score, id, doc)| (score, id, doc)) 105 | .collect::<Vec<_>>(); 106 | 107 | println!("Results: {:?}", results); 108 | 109 | let id_results = index 110 | .top_n_ids("What is a linglingdong?", 1) 111 | .await? 112 | .into_iter() 113 | .collect::<Vec<_>>(); 114 | 115 | println!("ID results: {:?}", id_results); 116 | 117 | Ok(()) 118 | } 119 | -------------------------------------------------------------------------------- /rig-lancedb/src/utils/mod.rs: -------------------------------------------------------------------------------- 1 | mod deserializer; 2 | 3 | use std::sync::Arc; 4 | 5 | use deserializer::RecordBatchDeserializer; 6 | use futures::TryStreamExt; 7 | use lancedb::{ 8 | arrow::arrow_schema::{DataType, Schema}, 9 | query::ExecutableQuery, 10 | }; 11 | use rig::vector_store::VectorStoreError; 12 | 13 | use crate::lancedb_to_rig_error; 14 | 15 | /// Trait that facilitates the conversion of columnar data returned by a lanceDb query to serde_json::Value. 16 | /// Used whenever a lanceDb table is queried. 17 | pub(crate) trait QueryToJson { 18 | async fn execute_query(&self) -> Result<Vec<serde_json::Value>, VectorStoreError>; 19 | } 20 | 21 | impl QueryToJson for lancedb::query::VectorQuery { 22 | async fn execute_query(&self) -> Result<Vec<serde_json::Value>, VectorStoreError> { 23 | let record_batches = self 24 | .execute() 25 | .await 26 | .map_err(lancedb_to_rig_error)? 27 | .try_collect::<Vec<_>>() 28 | .await 29 | .map_err(lancedb_to_rig_error)?; 30 | 31 | record_batches.deserialize() 32 | } 33 | } 34 | 35 | /// Filter out the columns from a table that do not include embeddings. Return the vector of column names. 36 | pub(crate) trait FilterTableColumns { 37 | fn filter_embeddings(self) -> Vec<String>; 38 | } 39 | 40 | impl FilterTableColumns for Arc<Schema> { 41 | fn filter_embeddings(self) -> Vec<String> { 42 | self.fields() 43 | .iter() 44 | .filter_map(|field| match field.data_type() { 45 | DataType::FixedSizeList(inner, ..) 
=> match inner.data_type() { 46 | DataType::Float64 => None, 47 | _ => Some(field.name().to_string()), 48 | }, 49 | _ => Some(field.name().to_string()), 50 | }) 51 | .collect() 52 | } 53 | } 54 | 55 | #[cfg(test)] 56 | mod tests { 57 | use std::sync::Arc; 58 | 59 | use lancedb::arrow::arrow_schema::{DataType, Field, Schema}; 60 | 61 | use super::FilterTableColumns; 62 | 63 | #[tokio::test] 64 | async fn test_column_filtering() { 65 | let field_a = Field::new("id", DataType::Int64, false); 66 | let field_b = Field::new("my_bool", DataType::Boolean, false); 67 | let field_c = Field::new( 68 | "my_embeddings", 69 | DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, true)), 10), 70 | false, 71 | ); 72 | let field_d = Field::new( 73 | "my_list", 74 | DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 10), 75 | false, 76 | ); 77 | 78 | let schema = Schema::new(vec![field_a, field_b, field_c, field_d]); 79 | 80 | let columns = Arc::new(schema).filter_embeddings(); 81 | 82 | assert_eq!( 83 | columns, 84 | vec![ 85 | "id".to_string(), 86 | "my_bool".to_string(), 87 | "my_list".to_string() 88 | ] 89 | ) 90 | } 91 | 92 | #[tokio::test] 93 | async fn test_column_filtering_2() { 94 | let field_a = Field::new("id", DataType::Int64, false); 95 | let field_b = Field::new("my_bool", DataType::Boolean, false); 96 | let field_c = Field::new( 97 | "my_embeddings", 98 | DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, true)), 10), 99 | false, 100 | ); 101 | let field_d = Field::new( 102 | "my_other_embeddings", 103 | DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, true)), 10), 104 | false, 105 | ); 106 | 107 | let schema = Schema::new(vec![field_a, field_b, field_c, field_d]); 108 | 109 | let columns = Arc::new(schema).filter_embeddings(); 110 | 111 | assert_eq!(columns, vec!["id".to_string(), "my_bool".to_string()]) 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /rig-mongodb/examples/vector_search_mongodb.rs: -------------------------------------------------------------------------------- 1 | use mongodb::{ 2 | bson::{self, doc}, 3 | options::ClientOptions, 4 | Client as MongoClient, Collection, 5 | }; 6 | use rig::providers::openai::TEXT_EMBEDDING_ADA_002; 7 | use serde::Deserialize; 8 | use std::env; 9 | 10 | use rig::{ 11 | embeddings::EmbeddingsBuilder, providers::openai::Client, vector_store::VectorStoreIndex, Embed, 12 | }; 13 | use rig_mongodb::{MongoDbVectorIndex, SearchParams}; 14 | 15 | // Shape of data that needs to be RAG'ed. 16 | // The definition field will be used to generate embeddings. 
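// Note: the `#[serde(rename = "_id")]` attribute below maps the struct's `id`
// field onto MongoDB's `_id` document key when search results are deserialized.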
17 | #[derive(Embed, Clone, Deserialize, Debug)] 18 | struct Word { 19 | #[serde(rename = "_id")] 20 | id: String, 21 | #[embed] 22 | definition: String, 23 | } 24 | 25 | #[tokio::main] 26 | async fn main() -> Result<(), anyhow::Error> { 27 | // Initialize OpenAI client 28 | let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); 29 | let openai_client = Client::new(&openai_api_key); 30 | 31 | // Initialize MongoDB client 32 | let mongodb_connection_string = 33 | env::var("MONGODB_CONNECTION_STRING").expect("MONGODB_CONNECTION_STRING not set"); 34 | let options = ClientOptions::parse(mongodb_connection_string) 35 | .await 36 | .expect("MongoDB connection string should be valid"); 37 | 38 | let mongodb_client = 39 | MongoClient::with_options(options).expect("MongoDB client options should be valid"); 40 | 41 | // Initialize MongoDB vector store 42 | let collection: Collection<bson::Document> = mongodb_client 43 | .database("knowledgebase") 44 | .collection("context"); 45 | 46 | // Select the embedding model and generate our embeddings 47 | let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); 48 | 49 | let words = vec![ 50 | Word { 51 | id: "doc0".to_string(), 52 | definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), 53 | }, 54 | Word { 55 | id: "doc1".to_string(), 56 | definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), 57 | }, 58 | Word { 59 | id: "doc2".to_string(), 60 | definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), 61 | } 62 | ]; 63 | 64 | let embeddings = EmbeddingsBuilder::new(model.clone()) 65 | .documents(words)? 66 | .build() 67 | .await?; 68 | 69 | let mongo_documents = embeddings 70 | .iter() 71 | .map(|(Word { id, definition, .. }, embedding)| { 72 | doc! { 73 | "id": id.clone(), 74 | "definition": definition.clone(), 75 | "embedding": embedding.first().vec.clone(), 76 | } 77 | }) 78 | .collect::<Vec<_>>(); 79 | 80 | match collection.insert_many(mongo_documents).await { 81 | Ok(_) => println!("Documents added successfully"), 82 | Err(e) => println!("Error adding documents: {:?}", e), 83 | }; 84 | 85 | // Create a vector index on our vector store. 86 | // Note: a vector index called "vector_index" must exist on the MongoDB collection you are querying. 87 | // IMPORTANT: Reuse the same model that was used to generate the embeddings 88 | let index = 89 | MongoDbVectorIndex::new(collection, model, "vector_index", SearchParams::new()).await?; 90 | 91 | // Query the index 92 | let results = index.top_n::<Word>("What is a linglingdong?", 1).await?; 93 | 94 | println!("Results: {:?}", results); 95 | 96 | let id_results = index 97 | .top_n_ids("What is a linglingdong?", 1) 98 | .await? 99 | .into_iter() 100 | .collect::<Vec<_>>(); 101 | 102 | println!("ID results: {:?}", id_results); 103 | 104 | Ok(()) 105 | } 106 | -------------------------------------------------------------------------------- /rig-neo4j/examples/vector_search_movies_consume.rs: -------------------------------------------------------------------------------- 1 | //! This example demonstrates how to perform a vector search on a Neo4j database. 2 | //! It is based on the [Neo4j Embeddings & Vector Index Tutorial](https://neo4j.com/docs/genai/tutorials/embeddings-vector-indexes/). 3 | //! 
The tutorial uses the `recommendations` dataset and the `moviePlotsEmbedding` index, which is created in the tutorial. 4 | //! Both need to be configured, and the database must be running, before you run this example. 5 | //! 6 | //! Neo4j provides a demo database for the `recommendations` dataset (see [Github Neo4j-Graph-Examples/recommendations](https://github.com/neo4j-graph-examples/recommendations/tree/main?tab=readme-ov-file#setup)). 7 | //! 8 | //! const NEO4J_URI: &str = "neo4j+s://demo.neo4jlabs.com:7687"; 9 | //! const NEO4J_DB: &str = "recommendations"; 10 | //! const NEO4J_USERNAME: &str = "recommendations"; 11 | //! const NEO4J_PASSWORD: &str = "recommendations"; 12 | //! 13 | //! [examples/vector_search_simple.rs](examples/vector_search_simple.rs) provides an example starting from an empty database. 14 | //! [examples/vector_search_movies_add_embeddings.rs](examples/vector_search_movies_add_embeddings.rs) provides an example of 15 | //! how to add embeddings to an existing `recommendations` database. 16 | use neo4rs::ConfigBuilder; 17 | use rig_neo4j::{vector_index::SearchParams, Neo4jClient}; 18 | 19 | use std::env; 20 | 21 | use rig::{ 22 | providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, 23 | vector_store::VectorStoreIndex, 24 | }; 25 | use serde::{Deserialize, Serialize}; 26 | 27 | #[path = "./display/lib.rs"] 28 | mod display; 29 | 30 | #[tokio::main] 31 | async fn main() -> Result<(), anyhow::Error> { 32 | tracing_subscriber::fmt() 33 | .with_max_level(tracing::Level::DEBUG) 34 | .with_target(false) 35 | .init(); 36 | 37 | const INDEX_NAME: &str = "moviePlotsEmbedding"; 38 | 39 | // Initialize OpenAI client 40 | let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); 41 | let openai_client = Client::new(&openai_api_key); 42 | 43 | let neo4j_uri = "neo4j+s://demo.neo4jlabs.com:7687"; 44 | let neo4j_username = "recommendations"; 45 | let neo4j_password = "recommendations"; 46 | 47 | let neo4j_client = Neo4jClient::from_config( 48 | ConfigBuilder::default() 49 | .uri(neo4j_uri) 50 | .user(neo4j_username) 51 | .password(neo4j_password) 52 | .db("recommendations") 53 | .build() 54 | .unwrap(), 55 | ) 56 | .await?; 57 | 58 | // Select the embedding model and generate our embeddings 59 | let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); 60 | 61 | // Define the properties that will be retrieved from querying the graph nodes 62 | #[derive(Debug, Deserialize, Serialize)] 63 | struct Movie { 64 | title: String, 65 | plot: String, 66 | } 67 | 68 | // Get a handle to the existing vector index on our vector store 69 | // ❗IMPORTANT: Reuse the same model that was used to generate the embeddings 70 | let index = neo4j_client 71 | .get_index( 72 | model, 73 | INDEX_NAME, 74 | SearchParams::new(Some("node.year > 1990".to_string())), 75 | ) 76 | .await?; 77 | 78 | // Query the index 79 | let results = index 80 | .top_n::<Movie>("a historical movie on quebec", 5) 81 | .await? 82 | .into_iter() 83 | .map(|(score, id, doc)| display::SearchResult { 84 | title: doc.title, 85 | id, 86 | description: doc.plot, 87 | score, 88 | }) 89 | .collect::<Vec<_>>(); 90 | 91 | println!("{:#}", display::SearchResults(&results)); 92 | 93 | let id_results = index 94 | .top_n_ids("A movie where the bad guy wins", 1) 95 | .await?
96 | .into_iter() 97 | .map(|(score, id)| (score, id)) 98 | .collect::<Vec<_>>(); 99 | 100 | println!("ID results: {:?}", id_results); 101 | 102 | Ok(()) 103 | } 104 | -------------------------------------------------------------------------------- /rig-core/src/providers/xai/embedding.rs: -------------------------------------------------------------------------------- 1 | // ================================================================ 2 | //! xAI Embeddings Integration 3 | //! From [xAI Reference](https://docs.x.ai/api/endpoints#create-embeddings) 4 | // ================================================================ 5 | 6 | use serde::Deserialize; 7 | use serde_json::json; 8 | 9 | use crate::embeddings::{self, EmbeddingError}; 10 | 11 | use super::{ 12 | client::xai_api_types::{ApiErrorResponse, ApiResponse}, 13 | Client, 14 | }; 15 | 16 | // ================================================================ 17 | // xAI Embedding API 18 | // ================================================================ 19 | /// `v1` embedding model 20 | pub const EMBEDDING_V1: &str = "v1"; 21 | 22 | #[derive(Debug, Deserialize)] 23 | pub struct EmbeddingResponse { 24 | pub object: String, 25 | pub data: Vec<EmbeddingData>, 26 | pub model: String, 27 | pub usage: Usage, 28 | } 29 | 30 | impl From<ApiErrorResponse> for EmbeddingError { 31 | fn from(err: ApiErrorResponse) -> Self { 32 | EmbeddingError::ProviderError(err.message()) 33 | } 34 | } 35 | 36 | impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> { 37 | fn from(value: ApiResponse<EmbeddingResponse>) -> Self { 38 | match value { 39 | ApiResponse::Ok(response) => Ok(response), 40 | ApiResponse::Error(err) => Err(EmbeddingError::ProviderError(err.message())), 41 | } 42 | } 43 | } 44 | 45 | #[derive(Debug, Deserialize)] 46 | pub struct EmbeddingData { 47 | pub object: String, 48 | pub embedding: Vec<f64>, 49 | pub index: usize, 50 | } 51 | 52 | #[derive(Debug, Deserialize)] 53 | pub struct Usage { 54 | pub prompt_tokens: usize, 55 | pub total_tokens: usize, 56 | } 57 | 58 | #[derive(Clone)] 59 | pub struct EmbeddingModel { 60 | client: Client, 61 | pub model: String, 62 | ndims: usize, 63 | } 64 | 65 | impl embeddings::EmbeddingModel for EmbeddingModel { 66 | const MAX_DOCUMENTS: usize = 1024; 67 | 68 | fn ndims(&self) -> usize { 69 | self.ndims 70 | } 71 | 72 | async fn embed_texts( 73 | &self, 74 | documents: impl IntoIterator<Item = String>, 75 | ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> { 76 | let documents = documents.into_iter().collect::<Vec<_>>(); 77 | 78 | let response = self 79 | .client 80 | .post("/v1/embeddings") 81 | .json(&json!({ 82 | "model": self.model, 83 | "input": documents, 84 | })) 85 | .send() 86 | .await?; 87 | 88 | if response.status().is_success() { 89 | match response.json::<ApiResponse<EmbeddingResponse>>().await? 
{ 90 | ApiResponse::Ok(response) => { 91 | if response.data.len() != documents.len() { 92 | return Err(EmbeddingError::ResponseError( 93 | "Response data length does not match input length".into(), 94 | )); 95 | } 96 | 97 | Ok(response 98 | .data 99 | .into_iter() 100 | .zip(documents.into_iter()) 101 | .map(|(embedding, document)| embeddings::Embedding { 102 | document, 103 | vec: embedding.embedding, 104 | }) 105 | .collect()) 106 | } 107 | ApiResponse::Error(err) => Err(EmbeddingError::ProviderError(err.message())), 108 | } 109 | } else { 110 | Err(EmbeddingError::ProviderError(response.text().await?)) 111 | } 112 | } 113 | } 114 | 115 | impl EmbeddingModel { 116 | pub fn new(client: Client, model: &str, ndims: usize) -> Self { 117 | Self { 118 | client, 119 | model: model.to_string(), 120 | ndims, 121 | } 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /rig-core/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! Rig is a Rust library for building LLM-powered applications that focuses on ergonomics and modularity. 2 | //! 3 | //! # Table of contents 4 | //! - [High-level features](#high-level-features) 5 | //! - [Simple Example](#simple-example) 6 | //! - [Core Concepts](#core-concepts) 7 | //! - [Integrations](#integrations) 8 | //! 9 | //! # High-level features 10 | //! - Full support for LLM completion and embedding workflows 11 | //! - Simple but powerful common abstractions over LLM providers (e.g. OpenAI, Cohere) and vector stores (e.g. MongoDB, in-memory) 12 | //! - Integrate LLMs in your app with minimal boilerplate 13 | //! 14 | //! # Simple example: 15 | //! ``` 16 | //! use rig::{completion::Prompt, providers::openai}; 17 | //! 18 | //! #[tokio::main] 19 | //! async fn main() { 20 | //! // Create OpenAI client and agent. 21 | //! // This requires the `OPENAI_API_KEY` environment variable to be set. 22 | //! let openai_client = openai::Client::from_env(); 23 | //! 24 | //! let gpt4 = openai_client.agent("gpt-4").build(); 25 | //! 26 | //! // Prompt the model and print its response 27 | //! let response = gpt4 28 | //! .prompt("Who are you?") 29 | //! .await 30 | //! .expect("Failed to prompt GPT-4"); 31 | //! 32 | //! println!("GPT-4: {response}"); 33 | //! } 34 | //! ``` 35 | //! Note: using `#[tokio::main]` requires you enable tokio's `macros` and `rt-multi-thread` features 36 | //! or just `full` to enable all features (`cargo add tokio --features macros,rt-multi-thread`). 37 | //! 38 | //! # Core concepts 39 | //! ## Completion and embedding models 40 | //! Rig provides a consistent API for working with LLMs and embeddings. Specifically, 41 | //! each provider (e.g. OpenAI, Cohere) has a `Client` struct that can be used to initialize completion 42 | //! and embedding models. These models implement the [CompletionModel](crate::completion::CompletionModel) 43 | //! and [EmbeddingModel](crate::embeddings::EmbeddingModel) traits respectively, which provide a common, 44 | //! low-level interface for creating completion and embedding requests and executing them. 45 | //! 46 | //! ## Agents 47 | //! Rig also provides high-level abstractions over LLMs in the form of the [Agent](crate::agent::Agent) type. 48 | //! 49 | //! The [Agent](crate::agent::Agent) type can be used to create anything from simple agents that use vanilla models to full blown 50 | //! RAG systems that can be used to answer questions using a knowledge base. 51 | //! 52 | //! 
## Vector stores and indexes 53 | //! Rig provides a common interface for working with vector stores and indexes. Specifically, the library 54 | //! provides the [VectorStoreIndex](crate::vector_store::VectorStoreIndex) 55 | //! trait, which can be implemented to define vector store indices. 56 | //! Those can then be used as the knowledge base for a RAG-enabled [Agent](crate::agent::Agent), or 57 | //! as a source of context documents in a custom architecture that uses multiple LLMs or agents. 58 | //! 59 | //! # Integrations 60 | //! ## Model Providers 61 | //! Rig natively supports the following completion and embedding model provider integrations: 62 | //! - OpenAI 63 | //! - Cohere 64 | //! - Anthropic 65 | //! - Perplexity 66 | //! - Gemini 67 | //! 68 | //! You can also implement your own model provider integration by defining types that 69 | //! implement the [CompletionModel](crate::completion::CompletionModel) and [EmbeddingModel](crate::embeddings::EmbeddingModel) traits. 70 | //! 71 | //! ## Vector Stores 72 | //! Rig currently supports the following vector store integrations via companion crates: 73 | //! - `rig-mongodb`: Vector store implementation for MongoDB 74 | //! - `rig-lancedb`: Vector store implementation for LanceDB 75 | //! - `rig-neo4j`: Vector store implementation for Neo4j 76 | //! - `rig-qdrant`: Vector store implementation for Qdrant 77 | //! 78 | //! You can also implement your own vector store integration by defining types that 79 | //! implement the [VectorStoreIndex](crate::vector_store::VectorStoreIndex) trait. 80 | 81 | pub mod agent; 82 | pub mod cli_chatbot; 83 | pub mod completion; 84 | pub mod embeddings; 85 | pub mod extractor; 86 | pub(crate) mod json_utils; 87 | pub mod loaders; 88 | pub mod one_or_many; 89 | pub mod providers; 90 | pub mod tool; 91 | pub mod vector_store; 92 | 93 | // Re-export commonly used types and traits 94 | pub use embeddings::Embed; 95 | pub use one_or_many::{EmptyListError, OneOrMany}; 96 | 97 | #[cfg(feature = "derive")] 98 | pub use rig_derive::Embed; 99 | -------------------------------------------------------------------------------- /rig-core/rig-core-derive/src/custom.rs: -------------------------------------------------------------------------------- 1 | use quote::ToTokens; 2 | use syn::{meta::ParseNestedMeta, ExprPath}; 3 | 4 | use crate::EMBED; 5 | 6 | const EMBED_WITH: &str = "embed_with"; 7 | 8 | /// Finds and returns fields with #[embed(embed_with = "...")] attribute tags only. 9 | /// Also returns the "..." part of the tag (ie. the custom function). 10 | pub(crate) fn custom_embed_fields( 11 | data_struct: &syn::DataStruct, 12 | ) -> syn::Result<Vec<(&syn::Field, syn::ExprPath)>> { 13 | data_struct 14 | .fields 15 | .iter() 16 | .filter_map(|field| { 17 | field 18 | .attrs 19 | .iter() 20 | .filter_map(|attribute| match attribute.is_custom() { 21 | Ok(true) => match attribute.expand_tag() { 22 | Ok(path) => Some(Ok((field, path))), 23 | Err(e) => Some(Err(e)), 24 | }, 25 | Ok(false) => None, 26 | Err(e) => Some(Err(e)), 27 | }) 28 | .next() 29 | }) 30 | .collect::<Result<Vec<_>, _>>() 31 | } 32 | 33 | trait CustomAttributeParser { 34 | // Determine if field is tagged with an #[embed(embed_with = "...")] attribute. 35 | fn is_custom(&self) -> syn::Result<bool>; 36 | 37 | // Get the "..." part of the #[embed(embed_with = "...")] attribute. 38 | // Ex: If attribute is tagged with #[embed(embed_with = "my_embed")], returns "my_embed".
39 | fn expand_tag(&self) -> syn::Result<syn::ExprPath>; 40 | } 41 | 42 | impl CustomAttributeParser for syn::Attribute { 43 | fn is_custom(&self) -> syn::Result<bool> { 44 | // Check that the attribute is a list. 45 | match &self.meta { 46 | syn::Meta::List(meta) => { 47 | if meta.tokens.is_empty() { 48 | return Ok(false); 49 | } 50 | } 51 | _ => return Ok(false), 52 | }; 53 | 54 | // Check the first attribute tag (the first "embed") 55 | if !self.path().is_ident(EMBED) { 56 | return Ok(false); 57 | } 58 | 59 | self.parse_nested_meta(|meta| { 60 | // Parse the meta attribute as an expression. Need this to compile. 61 | meta.value()?.parse::<syn::Expr>()?; 62 | 63 | if meta.path.is_ident(EMBED_WITH) { 64 | Ok(()) 65 | } else { 66 | let path = meta.path.to_token_stream().to_string().replace(' ', ""); 67 | Err(syn::Error::new_spanned( 68 | meta.path, 69 | format_args!("unknown embedding field attribute `{}`", path), 70 | )) 71 | } 72 | })?; 73 | 74 | Ok(true) 75 | } 76 | 77 | fn expand_tag(&self) -> syn::Result<syn::ExprPath> { 78 | fn function_path(meta: &ParseNestedMeta<'_>) -> syn::Result<ExprPath> { 79 | // #[embed(embed_with = "...")] 80 | let expr = meta.value()?.parse::<syn::Expr>().unwrap(); 81 | let mut value = &expr; 82 | while let syn::Expr::Group(e) = value { 83 | value = &e.expr; 84 | } 85 | let string = if let syn::Expr::Lit(syn::ExprLit { 86 | lit: syn::Lit::Str(lit_str), 87 | .. 88 | }) = value 89 | { 90 | let suffix = lit_str.suffix(); 91 | if !suffix.is_empty() { 92 | return Err(syn::Error::new_spanned( 93 | lit_str, 94 | format!("unexpected suffix `{}` on string literal", suffix), 95 | )); 96 | } 97 | lit_str.clone() 98 | } else { 99 | return Err(syn::Error::new_spanned( 100 | value, 101 | format!( 102 | "expected {} attribute to be a string: `{} = \"...\"`", 103 | EMBED_WITH, EMBED_WITH 104 | ), 105 | )); 106 | }; 107 | 108 | string.parse() 109 | } 110 | 111 | let mut custom_func_path = None; 112 | 113 | self.parse_nested_meta(|meta| match function_path(&meta) { 114 | Ok(path) => { 115 | custom_func_path = Some(path); 116 | Ok(()) 117 | } 118 | Err(e) => Err(e), 119 | })?; 120 | 121 | Ok(custom_func_path.unwrap()) 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /img/rig_logo.svg: -------------------------------------------------------------------------------- 1 | <?xml version="1.0" encoding="UTF-8"?> 2 | <svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 1899 830"> 3 | <defs> 4 | <style> 5 | .cls-1 { 6 | fill: black; 7 | stroke-width: 10px; 8 | } 9 | </style> 10 | </defs> 11 | <g> 12 | <path class="cls-1" 13 | d="M613.5,698h-154.5c-3.2,0-6.3-1.6-8.1-4.3l-93.7-135.7c-1.8-2.6-4.7-4.2-7.9-4.3l-22.9-.6c-5.3-.1-9.5-4.4-9.6-9.7l-1.8-110.3c0-5.5,4.3-10,9.8-10l57.8-.4c.4,0,.8,0,1.2,0,6.9-.9,71.8-11.5,72.7-76.5,1.1-76.8-68.7-85.6-74.9-86.2-.3,0-.6,0-.9,0l-57.5-.7c-5.4,0-9.7-4.4-9.7-9.8l-.4-98c0-11.5,9.3-20.8,20.8-20.8h80.3c.2,0,.4,0,.6,0,2.2.1,4.4.3,6.7.4,87.9,5.7,203.3,72.7,203.3,209.4s-74.2,164.3-100.9,179.5c-5,2.8-6.4,9.3-3.2,14.1l101.1,148.6c4.5,6.5-.2,15.4-8.2,15.4Z" /> 14 | <path class="cls-1" 15 | d="M952.7,150.3v527.8c0,10.9-8.9,19.8-19.8,19.8h-110.8c-11.5,0-20.8-9.3-20.8-20.8V150.3c0-10.9,8.9-19.8,19.8-19.8h111.8c10.9,0,19.8,8.9,19.8,19.8Z" /> 16 | <rect class="cls-1" x="98.7" y="130.6" width="151.3" height="567.4" rx="20.8" ry="20.8" /> 17 | </g> 18 | <path class="cls-1" 19 | 
d="M1821.1,424.7c0-6-1.4-11.6-3.8-16.6-5.5-11.8-16.6-19.9-29.5-19.9h-130.7c-.7,0-1.5.1-2.1.2-.7,0-1.4-.2-2.1-.2h-138.6c-11,0-19.9,8.9-19.9,19.9v68.4c0,11,8.9,19.9,19.9,19.9h101.5c-7.3,13.2-16.5,25.7-27.7,37-47,47.7-115.2,61.7-174.8,41.9-23.9-8-46.3-21.3-65.5-40.2-63.3-62.4-67.4-162.3-11.5-229.5,3-3.7,6.3-7.2,9.7-10.7,65.8-66.9,173.4-67.7,240.3-1.8,7.7,7.7,14.6,15.9,20.6,24.6,3.7,5.4,9.9,8.6,16.5,8.6h99.2c14.5,0,24.4-14.8,18.7-28.2-5.5-13-11.9-25.6-19.3-37.8,0,0,0,0,0,0l43.2-43.9c8.4-8.6,8.3-22.4-.2-30.9l-63.1-62.2-13.4-13.2c-8.6-8.4-22.4-8.3-30.9.2l-12.8,13-31.2,31.7c-22.5-12.8-46.3-22.4-70.8-28.8h0s0-2.9,0-2.9l-.5-62c0-12-10-21.8-22-21.7l-107.3.8c-12,0-21.8,10-21.7,22l.5,60.9v2.7c-26.1,6.8-51.5,17.1-75.3,31.1l-34.4-33.8-9.2-9.1c-13.1-12.9-34.3-12.7-47.2.3l-8.6,8.8-31.3,31.8-19.2,19.5c-12.9,13.1-12.7,34.3.3,47.2l42.9,42.3c-11.3,19.3-20.2,39.5-26.8,60.3h0c-8.2-2.7-17.7-.9-24.2,5.7-2.7,2.7-4.5,6-5.5,9.4l-30.3.3c-18.4.1-33.3,15.2-33.1,33.6l.6,84.3c.1,18.4,15.3,33.3,33.6,33.1l55.1-.4c6.3,24,15.6,47.3,28.1,69.5h0c-3.8,6-4.5,13.2-2.4,19.7l-28.4,28.9c-12.9,13.1-12.8,34.3.3,47.2l52.4,51.6,7.7,7.5c13.1,12.9,34.3,12.7,47.2-.3l7.1-7.2,31.9-32.3s0,0,0,0c18.7,10.3,38.1,18.5,58.1,24.4,5,1.5,10,2.9,15,4.1v3.8s.5,61.6.5,61.6c0,12,10,21.7,22,21.7l107.4-.8c12,0,21.7-10,21.7-22l-.4-60.4v-6.2c25.6-7.1,50.4-17.7,73.8-31.9h0l38.6,38.1,7.9,7.7c8.6,8.4,22.4,8.3,30.9-.2l7.4-7.5,67.9-69c8.4-8.5,8.3-22.4-.2-30.9l-47.3-46.6c11.8-20.6,20.7-42.3,27.1-64.6h2.5s32.4-.3,32.4-.3c18.2-.1,33.1-16.8,33-37.1l-.3-34ZM1101.1,453.1l-.6-72.9c-.1-14.5,11.7-26.5,26.2-26.6l26.6-.3c1.1,3.5,3,6.8,5.8,9.6.9.9,1.9,1.7,3,2.4l-49.1,109.5c-7.1-4.7-11.8-12.6-11.9-21.7ZM1178,479.9h-1.3c0,0-49-.6-49-.6-1.7,0-3.4-.1-5-.5l49.2-109.6c2,.3,4.1.3,6.2.1v110.6ZM1308.4,504.7l2.7,37.4c-1.5.2-2.9.6-4.3,1.1l-62.9-108.9c3-3.9,4.7-8.5,4.9-13.2h36c1,29,8.8,57.7,23.6,83.6ZM1268.4,307.8c1.1,1.1,2.4,2.1,3.7,3l-40.6,87c-1.9-.5-3.8-.8-5.7-.8l-4.5-123.8,41.5,10.7c-2.8,8.2-.9,17.5,5.6,24ZM1279.5,315.9v96.3h-32c-1.1-3.2-2.9-6.1-5.5-8.6-.8-.8-1.6-1.4-2.5-2l40-85.6ZM1237,440.3l62.1,107.5c-.4.3-.8.6-1.1.9-3.4,3.5-5.5,7.8-6.3,12.2l-43.5,1.9c-1.1-2.6-2.7-5.1-4.9-7.2-3.4-3.4-7.6-5.4-12-6.2l-3.8-106.3c3.3-.2,6.6-1.2,9.6-2.9ZM1288.5,377.9v-63.9c4.7-.8,9.2-2.9,12.8-6.6,3.4-3.5,5.5-7.7,6.3-12.2l18.5,4.8c-19.2,23.1-31.8,49.9-37.6,77.9ZM1206.2,244.9c-.3-.3-.5-.5-.6-.8h0s-25.6-25.2-25.6-25.2c-10.4-10.2-10.5-27-.3-37.3l16.2-16.5,1.3-1.3,50,49.3s0,0,0,0l9.6,9.4s0,0,0,0l9.8,9.7,58.9,56.9-18.8-4.8c-1.1-3.5-3-6.9-5.8-9.7-9-8.9-23.4-8.9-32.4-.1l-43.6-11.2-18.8-18.3ZM1189.3,481.2v-116.3c.9-.7,1.9-1.5,2.7-2.3,8.3-8.3,8.8-21.3,1.8-30.4,4.8-15.9,11.2-31.2,18.9-46.2l4.1,112.5c-2.8,1.2-5.4,2.9-7.7,5.2-9,9.1-8.9,23.9.2,32.9,2.6,2.6,5.7,4.3,9,5.4l3.9,107.5c-2,.4-4,1.1-5.8,2-12.3-22.4-21.3-46.1-27.1-70.4ZM1190.2,614l22.9-23.2c2.6,1.9,5.4,3.3,8.4,4,0,0,.1,0,.2,0,1.3.3,2.7.5,4,.6.6,0,1.2,0,1.7,0,1.1,0,2.2-.1,3.3-.3h0c1,0,2-.3,3.1-.7,6.6-2.1,9.7-6.3,12.7-9.5,2.6-4,3.8-8.6,3.7-13.2l41.6-1.8c.9,4.3,3,8.4,6.4,11.7,1.7,1.7,3.6,2.9,5.6,3.9l-13.7,42.4c-6.8-.7-13.9,1.4-19,6.7-2.1,2.2-3.7,4.7-4.8,7.3l-83.4-7.1c-.6-7.4,1.8-15.1,7.4-20.8ZM1248.7,707.1c-2.3-1.2-4.4-2.7-6.3-4.5l-52-51.2c-2.2-2.2-3.9-4.7-5.2-7.3l79.1,6.8c0,6.1,2.2,12.1,6.9,16.7,1.5,1.4,3.1,2.6,4.8,3.5l-12.6,38.9c-5,.5-10.2-.5-14.7-2.9ZM1313.6,669.3l-33.8,33c-1.8,1.8-3.7,3.2-5.8,4.4l10.6-32.8c7,.9,14.3-1.3,19.6-6.7,9-9.1,8.9-23.9-.2-32.9-1.6-1.6-3.4-2.8-5.3-3.8l13.8-42.5c.7,0,1.3.2,2,.2l5.4,75.5c-1.8,1.7-3.9,3.5-6.3,5.5h0ZM1386,687.7c-.6-.2-1.1-.3-1.6-.5-17.4-5.2-34.5-12.2-50.9-20.8l62.7-19.7-10.2,41ZM1398.8,636.6l-70.3,22.1-5.2-72c2.9-1.2,5.5
-2.9,7.8-5.2,9-9.1,8.9-23.9-.2-32.9-3.1-3-6.8-5-10.8-5.9l-1.6-22c5.9,8.2,12.5,16,20,23.4,20.9,20.6,45.6,35.1,71.7,43.4l.8.3-12.3,49Z" /> 20 | </svg> 21 | -------------------------------------------------------------------------------- /img/rig_logo_dark.svg: -------------------------------------------------------------------------------- 1 | <?xml version="1.0" encoding="UTF-8"?> 2 | <svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 1899 830"> 3 | <defs> 4 | <style> 5 | .cls-1 { 6 | fill: white; 7 | stroke-width: 0px; 8 | } 9 | </style> 10 | </defs> 11 | <g> 12 | <path class="cls-1" 13 | d="M613.5,698h-154.5c-3.2,0-6.3-1.6-8.1-4.3l-93.7-135.7c-1.8-2.6-4.7-4.2-7.9-4.3l-22.9-.6c-5.3-.1-9.5-4.4-9.6-9.7l-1.8-110.3c0-5.5,4.3-10,9.8-10l57.8-.4c.4,0,.8,0,1.2,0,6.9-.9,71.8-11.5,72.7-76.5,1.1-76.8-68.7-85.6-74.9-86.2-.3,0-.6,0-.9,0l-57.5-.7c-5.4,0-9.7-4.4-9.7-9.8l-.4-98c0-11.5,9.3-20.8,20.8-20.8h80.3c.2,0,.4,0,.6,0,2.2.1,4.4.3,6.7.4,87.9,5.7,203.3,72.7,203.3,209.4s-74.2,164.3-100.9,179.5c-5,2.8-6.4,9.3-3.2,14.1l101.1,148.6c4.5,6.5-.2,15.4-8.2,15.4Z" /> 14 | <path class="cls-1" 15 | d="M952.7,150.3v527.8c0,10.9-8.9,19.8-19.8,19.8h-110.8c-11.5,0-20.8-9.3-20.8-20.8V150.3c0-10.9,8.9-19.8,19.8-19.8h111.8c10.9,0,19.8,8.9,19.8,19.8Z" /> 16 | <rect class="cls-1" x="98.7" y="130.6" width="151.3" height="567.4" rx="20.8" ry="20.8" /> 17 | </g> 18 | <path class="cls-1" 19 | d="M1821.1,424.7c0-6-1.4-11.6-3.8-16.6-5.5-11.8-16.6-19.9-29.5-19.9h-130.7c-.7,0-1.5.1-2.1.2-.7,0-1.4-.2-2.1-.2h-138.6c-11,0-19.9,8.9-19.9,19.9v68.4c0,11,8.9,19.9,19.9,19.9h101.5c-7.3,13.2-16.5,25.7-27.7,37-47,47.7-115.2,61.7-174.8,41.9-23.9-8-46.3-21.3-65.5-40.2-63.3-62.4-67.4-162.3-11.5-229.5,3-3.7,6.3-7.2,9.7-10.7,65.8-66.9,173.4-67.7,240.3-1.8,7.7,7.7,14.6,15.9,20.6,24.6,3.7,5.4,9.9,8.6,16.5,8.6h99.2c14.5,0,24.4-14.8,18.7-28.2-5.5-13-11.9-25.6-19.3-37.8,0,0,0,0,0,0l43.2-43.9c8.4-8.6,8.3-22.4-.2-30.9l-63.1-62.2-13.4-13.2c-8.6-8.4-22.4-8.3-30.9.2l-12.8,13-31.2,31.7c-22.5-12.8-46.3-22.4-70.8-28.8h0s0-2.9,0-2.9l-.5-62c0-12-10-21.8-22-21.7l-107.3.8c-12,0-21.8,10-21.7,22l.5,60.9v2.7c-26.1,6.8-51.5,17.1-75.3,31.1l-34.4-33.8-9.2-9.1c-13.1-12.9-34.3-12.7-47.2.3l-8.6,8.8-31.3,31.8-19.2,19.5c-12.9,13.1-12.7,34.3.3,47.2l42.9,42.3c-11.3,19.3-20.2,39.5-26.8,60.3h0c-8.2-2.7-17.7-.9-24.2,5.7-2.7,2.7-4.5,6-5.5,9.4l-30.3.3c-18.4.1-33.3,15.2-33.1,33.6l.6,84.3c.1,18.4,15.3,33.3,33.6,33.1l55.1-.4c6.3,24,15.6,47.3,28.1,69.5h0c-3.8,6-4.5,13.2-2.4,19.7l-28.4,28.9c-12.9,13.1-12.8,34.3.3,47.2l52.4,51.6,7.7,7.5c13.1,12.9,34.3,12.7,47.2-.3l7.1-7.2,31.9-32.3s0,0,0,0c18.7,10.3,38.1,18.5,58.1,24.4,5,1.5,10,2.9,15,4.1v3.8s.5,61.6.5,61.6c0,12,10,21.7,22,21.7l107.4-.8c12,0,21.7-10,21.7-22l-.4-60.4v-6.2c25.6-7.1,50.4-17.7,73.8-31.9h0l38.6,38.1,7.9,7.7c8.6,8.4,22.4,8.3,30.9-.2l7.4-7.5,67.9-69c8.4-8.5,8.3-22.4-.2-30.9l-47.3-46.6c11.8-20.6,20.7-42.3,27.1-64.6h2.5s32.4-.3,32.4-.3c18.2-.1,33.1-16.8,33-37.1l-.3-34ZM1101.1,453.1l-.6-72.9c-.1-14.5,11.7-26.5,26.2-26.6l26.6-.3c1.1,3.5,3,6.8,5.8,9.6.9.9,1.9,1.7,3,2.4l-49.1,109.5c-7.1-4.7-11.8-12.6-11.9-21.7ZM1178,479.9h-1.3c0,0-49-.6-49-.6-1.7,0-3.4-.1-5-.5l49.2-109.6c2,.3,4.1.3,6.2.1v110.6ZM1308.4,504.7l2.7,37.4c-1.5.2-2.9.6-4.3,1.1l-62.9-108.9c3-3.9,4.7-8.5,4.9-13.2h36c1,29,8.8,57.7,23.6,83.6ZM1268.4,307.8c1.1,1.1,2.4,2.1,3.7,3l-40.6,87c-1.9-.5-3.8-.8-5.7-.8l-4.5-123.8,41.5,10.7c-2.8,8.2-.9,17.5,5.6,24ZM1279.5,315.9v96.3h-32c-1.1-3.2-2.9-6.1-5.5-8.6-.8-.8-1.6-1.4-2.5-2l40-85.6ZM1237,440.3l62.1,107.5c-.4.3-.8.6-1.1.9-3.4,3.5-5.5,7.8-6.3,12.2l-43.5,1.9c-1.1-2.6-2.7-5.1-4.9-7.2-3.4-3
.4-7.6-5.4-12-6.2l-3.8-106.3c3.3-.2,6.6-1.2,9.6-2.9ZM1288.5,377.9v-63.9c4.7-.8,9.2-2.9,12.8-6.6,3.4-3.5,5.5-7.7,6.3-12.2l18.5,4.8c-19.2,23.1-31.8,49.9-37.6,77.9ZM1206.2,244.9c-.3-.3-.5-.5-.6-.8h0s-25.6-25.2-25.6-25.2c-10.4-10.2-10.5-27-.3-37.3l16.2-16.5,1.3-1.3,50,49.3s0,0,0,0l9.6,9.4s0,0,0,0l9.8,9.7,58.9,56.9-18.8-4.8c-1.1-3.5-3-6.9-5.8-9.7-9-8.9-23.4-8.9-32.4-.1l-43.6-11.2-18.8-18.3ZM1189.3,481.2v-116.3c.9-.7,1.9-1.5,2.7-2.3,8.3-8.3,8.8-21.3,1.8-30.4,4.8-15.9,11.2-31.2,18.9-46.2l4.1,112.5c-2.8,1.2-5.4,2.9-7.7,5.2-9,9.1-8.9,23.9.2,32.9,2.6,2.6,5.7,4.3,9,5.4l3.9,107.5c-2,.4-4,1.1-5.8,2-12.3-22.4-21.3-46.1-27.1-70.4ZM1190.2,614l22.9-23.2c2.6,1.9,5.4,3.3,8.4,4,0,0,.1,0,.2,0,1.3.3,2.7.5,4,.6.6,0,1.2,0,1.7,0,1.1,0,2.2-.1,3.3-.3h0c1,0,2-.3,3.1-.7,6.6-2.1,9.7-6.3,12.7-9.5,2.6-4,3.8-8.6,3.7-13.2l41.6-1.8c.9,4.3,3,8.4,6.4,11.7,1.7,1.7,3.6,2.9,5.6,3.9l-13.7,42.4c-6.8-.7-13.9,1.4-19,6.7-2.1,2.2-3.7,4.7-4.8,7.3l-83.4-7.1c-.6-7.4,1.8-15.1,7.4-20.8ZM1248.7,707.1c-2.3-1.2-4.4-2.7-6.3-4.5l-52-51.2c-2.2-2.2-3.9-4.7-5.2-7.3l79.1,6.8c0,6.1,2.2,12.1,6.9,16.7,1.5,1.4,3.1,2.6,4.8,3.5l-12.6,38.9c-5,.5-10.2-.5-14.7-2.9ZM1313.6,669.3l-33.8,33c-1.8,1.8-3.7,3.2-5.8,4.4l10.6-32.8c7,.9,14.3-1.3,19.6-6.7,9-9.1,8.9-23.9-.2-32.9-1.6-1.6-3.4-2.8-5.3-3.8l13.8-42.5c.7,0,1.3.2,2,.2l5.4,75.5c-1.8,1.7-3.9,3.5-6.3,5.5h0ZM1386,687.7c-.6-.2-1.1-.3-1.6-.5-17.4-5.2-34.5-12.2-50.9-20.8l62.7-19.7-10.2,41ZM1398.8,636.6l-70.3,22.1-5.2-72c2.9-1.2,5.5-2.9,7.8-5.2,9-9.1,8.9-23.9-.2-32.9-3.1-3-6.8-5-10.8-5.9l-1.6-22c5.9,8.2,12.5,16,20,23.4,20.9,20.6,45.6,35.1,71.7,43.4l.8.3-12.3,49Z" /> 20 | </svg> 21 | -------------------------------------------------------------------------------- /rig-qdrant/tests/integration_tests.rs: -------------------------------------------------------------------------------- 1 | use testcontainers::{ 2 | core::{IntoContainerPort, WaitFor}, 3 | runners::AsyncRunner, 4 | GenericImage, 5 | }; 6 | 7 | use qdrant_client::{ 8 | qdrant::{ 9 | CreateCollectionBuilder, Distance, PointStruct, QueryPointsBuilder, UpsertPointsBuilder, 10 | VectorParamsBuilder, 11 | }, 12 | Payload, Qdrant, 13 | }; 14 | use rig::{ 15 | embeddings::EmbeddingsBuilder, providers::openai, vector_store::VectorStoreIndex, Embed, 16 | }; 17 | use rig_qdrant::QdrantVectorStore; 18 | 19 | const QDRANT_PORT: u16 = 6333; 20 | const QDRANT_PORT_SECONDARY: u16 = 6334; 21 | const COLLECTION_NAME: &str = "rig-collection"; 22 | 23 | #[derive(Embed, Clone, serde::Deserialize, serde::Serialize, Debug)] 24 | struct Word { 25 | id: String, 26 | #[embed] 27 | definition: String, 28 | } 29 | 30 | #[tokio::test] 31 | async fn vector_search_test() { 32 | // Setup a local qdrant container for testing. NOTE: docker service must be running. 
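// Qdrant serves its REST API on port 6333 and its gRPC API on 6334; the Rust `qdrant_client` used below speaks gRPC, which is why the secondary port is the one resolved for the connection URL.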
33 | let container = GenericImage::new("qdrant/qdrant", "latest") 34 | .with_wait_for(WaitFor::Duration { 35 | length: std::time::Duration::from_secs(5), 36 | }) 37 | .with_exposed_port(QDRANT_PORT.tcp()) 38 | .with_exposed_port(QDRANT_PORT_SECONDARY.tcp()) 39 | .start() 40 | .await 41 | .expect("Failed to start Qdrant container"); 42 | 43 | let port = container 44 | .get_host_port_ipv4(QDRANT_PORT_SECONDARY) 45 | .await 46 | .unwrap(); 47 | let host = container.get_host().await.unwrap().to_string(); 48 | 49 | let client = Qdrant::from_url(&format!("http://{host}:{port}")) 50 | .build() 51 | .unwrap(); 52 | 53 | // Create a collection with 1536 dimensions if it doesn't exist 54 | // Note: Make sure the dimensions match the size of the embeddings returned by the 55 | // model you are using 56 | if !client.collection_exists(COLLECTION_NAME).await.unwrap() { 57 | client 58 | .create_collection( 59 | CreateCollectionBuilder::new(COLLECTION_NAME) 60 | .vectors_config(VectorParamsBuilder::new(1536, Distance::Cosine)), 61 | ) 62 | .await 63 | .unwrap(); 64 | } 65 | 66 | // Initialize OpenAI client. 67 | let openai_client = openai::Client::from_env(); 68 | 69 | let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); 70 | 71 | let points = create_points(model.clone()).await; 72 | 73 | client 74 | .upsert_points(UpsertPointsBuilder::new(COLLECTION_NAME, points)) 75 | .await 76 | .unwrap(); 77 | 78 | let query_params = QueryPointsBuilder::new(COLLECTION_NAME).with_payload(true); 79 | let vector_store = QdrantVectorStore::new(client, model, query_params.build()); 80 | 81 | let results = vector_store 82 | .top_n::<serde_json::Value>("What is a linglingdong?", 1) 83 | .await 84 | .unwrap(); 85 | 86 | let (_, _, value) = &results.first().unwrap(); 87 | 88 | assert_eq!( 89 | value, 90 | &serde_json::json!({ 91 | "definition": "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.", 92 | "id": "f9e17d59-32e5-440c-be02-b2759a654824" 93 | }) 94 | ) 95 | } 96 | 97 | async fn create_points(model: openai::EmbeddingModel) -> Vec<PointStruct> { 98 | let words = vec![ 99 | Word { 100 | id: "0981d983-a5f8-49eb-89ea-f7d3b2196d2e".to_string(), 101 | definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), 102 | }, 103 | Word { 104 | id: "62a36d43-80b6-4fd6-990c-f75bb02287d1".to_string(), 105 | definition: "Definition of a *glarb-glarb*: A glarb-glarb is an ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), 106 | }, 107 | Word { 108 | id: "f9e17d59-32e5-440c-be02-b2759a654824".to_string(), 109 | definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), 110 | } 111 | ]; 112 | 113 | let documents = EmbeddingsBuilder::new(model) 114 | .documents(words) 115 | .unwrap() 116 | .build() 117 | .await 118 | .unwrap(); 119 | 120 | documents 121 | .into_iter() 122 | .map(|(d, embeddings)| { 123 | let vec: Vec<f32> = embeddings.first().vec.iter().map(|&x| x as f32).collect(); 124 | PointStruct::new( 125 | d.id.clone(), 126 | vec, 127 | Payload::try_from(serde_json::to_value(&d).unwrap()).unwrap(), 128 | ) 129 | }) 130 | .collect() 131 | } 132 | -------------------------------------------------------------------------------- /rig-qdrant/src/lib.rs: -------------------------------------------------------------------------------- 1 | use qdrant_client::{ 2 |
qdrant::{point_id::PointIdOptions, PointId, Query, QueryPoints}, 3 | Qdrant, 4 | }; 5 | use rig::{ 6 | embeddings::EmbeddingModel, 7 | vector_store::{VectorStoreError, VectorStoreIndex}, 8 | }; 9 | use serde::Deserialize; 10 | 11 | /// Represents a vector store implementation using Qdrant (<https://qdrant.tech/>) as the backend. 12 | pub struct QdrantVectorStore<M: EmbeddingModel> { 13 | /// Model used to generate embeddings for the vector store 14 | model: M, 15 | /// Client instance for Qdrant server communication 16 | client: Qdrant, 17 | /// Default search parameters 18 | query_params: QueryPoints, 19 | } 20 | 21 | impl<M: EmbeddingModel> QdrantVectorStore<M> { 22 | /// Creates a new instance of `QdrantVectorStore`. 23 | /// 24 | /// # Arguments 25 | /// * `client` - Qdrant client instance 26 | /// * `model` - Embedding model instance 27 | /// * `query_params` - Search parameters for vector queries 28 | /// Reference: <https://api.qdrant.tech/v-1-12-x/api-reference/search/query-points> 29 | pub fn new(client: Qdrant, model: M, query_params: QueryPoints) -> Self { 30 | Self { 31 | client, 32 | model, 33 | query_params, 34 | } 35 | } 36 | 37 | /// Embed the query with the store's embedding model and convert the resulting vector to the `f32` format Qdrant expects. 38 | async fn generate_query_vector(&self, query: &str) -> Result<Vec<f32>, VectorStoreError> { 39 | let embedding = self.model.embed_text(query).await?; 40 | Ok(embedding.vec.iter().map(|&x| x as f32).collect()) 41 | } 42 | 43 | /// Fill in query parameters with the given query and limit. 44 | fn prepare_query_params(&self, query: Option<Query>, limit: usize) -> QueryPoints { 45 | let mut params = self.query_params.clone(); 46 | params.query = query; 47 | params.limit = Some(limit as u64); 48 | params 49 | } 50 | } 51 | 52 | /// Converts a `PointId` to its string representation. 53 | fn stringify_id(id: PointId) -> Result<String, VectorStoreError> { 54 | match id.point_id_options { 55 | Some(PointIdOptions::Num(num)) => Ok(num.to_string()), 56 | Some(PointIdOptions::Uuid(uuid)) => Ok(uuid.to_string()), 57 | None => Err(VectorStoreError::DatastoreError( 58 | "Invalid point ID format".into(), 59 | )), 60 | } 61 | } 62 | 63 | impl<M: EmbeddingModel + Sync + Send> VectorStoreIndex for QdrantVectorStore<M> { 64 | /// Search for the top `n` nearest neighbors to the given query within the Qdrant vector store. 65 | /// Returns a vector of tuples containing the score, ID, and payload of the nearest neighbors. 66 | async fn top_n<T: for<'a> Deserialize<'a> + Send>( 67 | &self, 68 | query: &str, 69 | n: usize, 70 | ) -> Result<Vec<(f64, String, T)>, VectorStoreError> { 71 | let query = match self.query_params.query { 72 | Some(ref q) => Some(q.clone()), 73 | None => Some(Query::new_nearest(self.generate_query_vector(query).await?)), 74 | }; 75 | 76 | let params = self.prepare_query_params(query, n); 77 | let result = self 78 | .client 79 | .query(params) 80 | .await 81 | .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; 82 | 83 | result 84 | .result 85 | .into_iter() 86 | .map(|item| { 87 | let id = 88 | stringify_id(item.id.ok_or_else(|| { 89 | VectorStoreError::DatastoreError("Missing point ID".into()) 90 | })?)?; 91 | let score = item.score as f64; 92 | let payload = serde_json::from_value(serde_json::to_value(item.payload)?)?; 93 | Ok((score, id, payload)) 94 | }) 95 | .collect() 96 | } 97 | 98 | /// Search for the top `n` nearest neighbors to the given query within the Qdrant vector store.
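/// If `query_params` already contains a query, it is reused as-is; otherwise the query string is embedded with the store's embedding model.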
99 | /// Returns a vector of tuples containing the score and ID of the nearest neighbors. 100 | async fn top_n_ids( 101 | &self, 102 | query: &str, 103 | n: usize, 104 | ) -> Result<Vec<(f64, String)>, VectorStoreError> { 105 | let query = match self.query_params.query { 106 | Some(ref q) => Some(q.clone()), 107 | None => Some(Query::new_nearest(self.generate_query_vector(query).await?)), 108 | }; 109 | 110 | let params = self.prepare_query_params(query, n); 111 | let points = self 112 | .client 113 | .query(params) 114 | .await 115 | .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))? 116 | .result; 117 | 118 | points 119 | .into_iter() 120 | .map(|point| { 121 | let id = 122 | stringify_id(point.id.ok_or_else(|| { 123 | VectorStoreError::DatastoreError("Missing point ID".into()) 124 | })?)?; 125 | Ok((point.score as f64, id)) 126 | }) 127 | .collect() 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /rig-core/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 7 | 8 | ## [Unreleased] 9 | 10 | ## [0.5.0](https://github.com/0xPlaygrounds/rig/compare/rig-core-v0.4.1...rig-core-v0.5.0) - 2024-12-03 11 | 12 | ### Added 13 | 14 | - Improve `InMemoryVectorStore` API ([#130](https://github.com/0xPlaygrounds/rig/pull/130)) 15 | - embeddings API overhaul ([#120](https://github.com/0xPlaygrounds/rig/pull/120)) 16 | - *(provider)* xAI (grok) integration ([#106](https://github.com/0xPlaygrounds/rig/pull/106)) 17 | 18 | ### Fixed 19 | 20 | - *(rig-lancedb)* rag embedding filtering ([#104](https://github.com/0xPlaygrounds/rig/pull/104)) 21 | 22 | ## [0.4.1](https://github.com/0xPlaygrounds/rig/compare/rig-core-v0.4.0...rig-core-v0.4.1) - 2024-11-13 23 | 24 | ### Other 25 | 26 | - Inefficient context documents serialization ([#100](https://github.com/0xPlaygrounds/rig/pull/100)) 27 | 28 | ## [0.4.0](https://github.com/0xPlaygrounds/rig/compare/rig-core-v0.3.0...rig-core-v0.4.0) - 2024-11-07 29 | 30 | ### Added 31 | 32 | - *(gemini)* move system prompt to correct request field 33 | - *(provider-gemini)* add support for gemini specific completion parameters 34 | - *(provider-gemini)* add agent support in client 35 | - *(provider-gemini)* add gemini embedding support 36 | - *(provider-gemini)* add gemini support for basic completion 37 | - *(provider-gemini)* add gemini API client 38 | 39 | ### Fixed 40 | 41 | - *(gemini)* issue when additional param is empty 42 | - docs imports and refs 43 | - *(gemini)* missing param to be marked as optional in completion res 44 | 45 | ### Other 46 | 47 | - Cargo fmt 48 | - Add module level docs for the `tool` module 49 | - Fix loaders module docs references 50 | - Add docstrings to loaders module 51 | - Improve main lib docs 52 | - Add `all` feature flag to rig-core 53 | - *(gemini)* add utility config docstring 54 | - *(gemini)* remove try_from and use serde deserialization 55 | - Merge branch 'main' into feat/model-provider/16-add-gemini-completion-embedding-models 56 | - *(gemini)* separate gemini api types module, fix pr comments 57 | - add debug trait to embedding struct 58 | - *(gemini)* add additional types from the official documentation, add embeddings example 59 | - *(provider-gemini)*
test pre-commits 60 | - *(provider-gemini)* Update readme entries, add gemini agent example 61 | 62 | ## [0.3.0](https://github.com/0xPlaygrounds/rig/compare/rig-core-v0.2.1...rig-core-v0.3.0) - 2024-10-24 63 | 64 | ### Added 65 | 66 | - Generalize `EmbeddingModel::embed_documents` with `IntoIterator` 67 | - Add `from_env` constructor to Cohere and Anthropic clients 68 | - Small optimization to serde_json object merging 69 | - Add better error handling for provider clients 70 | 71 | ### Fixed 72 | 73 | - Bad Anthropic request/response handling 74 | - *(vector-index)* In memory vector store index incorrect search 75 | 76 | ### Other 77 | 78 | - Made internal `json_utils` module private 79 | - Update lib docs 80 | - Made CompletionRequest helper method private to crate 81 | - lint + fmt 82 | - Simplify `agent_with_tools` example 83 | - Fix docstring links 84 | - Add nextest test runner to CI 85 | - Merge pull request [#42](https://github.com/0xPlaygrounds/rig/pull/42) from 0xPlaygrounds/refactor(vector-store)/update-vector-store-index-trait 86 | 87 | ## [0.2.1](https://github.com/0xPlaygrounds/rig/compare/rig-core-v0.2.0...rig-core-v0.2.1) - 2024-10-01 88 | 89 | ### Fixed 90 | 91 | - *(docs)* Docs still referring to old types 92 | 93 | ### Other 94 | 95 | - Merge pull request [#45](https://github.com/0xPlaygrounds/rig/pull/45) from 0xPlaygrounds/fix/docs 96 | 97 | ## [0.2.0](https://github.com/0xPlaygrounds/rig/compare/rig-core-v0.1.0...rig-core-v0.2.0) - 2024-10-01 98 | 99 | ### Added 100 | 101 | - anthropic models 102 | 103 | ### Fixed 104 | 105 | - *(context)* displaying documents should be deterministic (sorted by alpha) 106 | - *(context)* spin out helper method + add tests 107 | - move context documents to user prompt message 108 | - adjust version const naming 109 | - implement review suggestions + renaming 110 | - add `completion_request.documents` to `chat_history` 111 | - adjust API to be cleaner + add docstrings 112 | 113 | ### Other 114 | 115 | - Merge pull request [#43](https://github.com/0xPlaygrounds/rig/pull/43) from 0xPlaygrounds/fix/context-documents 116 | - Merge pull request [#27](https://github.com/0xPlaygrounds/rig/pull/27) from 0xPlaygrounds/feat/anthropic 117 | - Fix docstrings 118 | - Deprecate RagAgent and Model in favor of versatile Agent 119 | - Make RagAgent VectorStoreIndex dynamic trait objects 120 | 121 | ## [0.1.0](https://github.com/0xPlaygrounds/rig/compare/rig-core-v0.0.7...rig-core-v0.1.0) - 2024-09-16 122 | 123 | ### Added 124 | 125 | - add o1-preview and o1-mini 126 | 127 | ### Fixed 128 | 129 | - *(perplexity)* fix preamble and context in completion request 130 | - clippy warnings 131 | 132 | ### Other 133 | 134 | - Merge pull request [#18](https://github.com/0xPlaygrounds/rig/pull/18) from 0xPlaygrounds/feat/perplexity-support 135 | - Add logging of http errors 136 | - fmt code 137 | -------------------------------------------------------------------------------- /rig-core/src/extractor.rs: -------------------------------------------------------------------------------- 1 | //! This module provides high-level abstractions for extracting structured data from text using LLMs. 2 | //! 3 | //! Note: The target structure must implement the `serde::Deserialize`, `serde::Serialize`, 4 | //! and `schemars::JsonSchema` traits. Those can be easily derived using the `derive` macro. 5 | //! 6 | //! # Example 7 | //! ``` 8 | //! use rig::providers::openai; 9 | //! 10 | //! // Initialize the OpenAI client 11 | //! 
let openai = openai::Client::new("your-open-ai-api-key"); 12 | //! 13 | //! // Define the structure of the data you want to extract 14 | //! #[derive(serde::Deserialize, serde::Serialize, schemars::JsonSchema)] 15 | //! struct Person { 16 | //! name: Option<String>, 17 | //! age: Option<u8>, 18 | //! profession: Option<String>, 19 | //! } 20 | //! 21 | //! // Create the extractor 22 | //! let extractor = openai.extractor::<Person>(openai::GPT_4O) 23 | //! .build(); 24 | //! 25 | //! // Extract structured data from text 26 | //! let person = extractor.extract("John Doe is a 30 year old doctor.") 27 | //! .await 28 | //! .expect("Failed to extract data from text"); 29 | //! ``` 30 | 31 | use std::marker::PhantomData; 32 | 33 | use schemars::{schema_for, JsonSchema}; 34 | use serde::{Deserialize, Serialize}; 35 | use serde_json::json; 36 | 37 | use crate::{ 38 | agent::{Agent, AgentBuilder}, 39 | completion::{CompletionModel, Prompt, PromptError, ToolDefinition}, 40 | tool::Tool, 41 | }; 42 | 43 | #[derive(Debug, thiserror::Error)] 44 | pub enum ExtractionError { 45 | #[error("No data extracted")] 46 | NoData, 47 | 48 | #[error("Failed to deserialize the extracted data: {0}")] 49 | DeserializationError(#[from] serde_json::Error), 50 | 51 | #[error("PromptError: {0}")] 52 | PromptError(#[from] PromptError), 53 | } 54 | 55 | /// Extractor for structured data from text 56 | pub struct Extractor<M: CompletionModel, T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> { 57 | agent: Agent<M>, 58 | _t: PhantomData<T>, 59 | } 60 | 61 | impl<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, M: CompletionModel> Extractor<M, T> 62 | where 63 | M: Sync, 64 | { 65 | pub async fn extract(&self, text: &str) -> Result<T, ExtractionError> { 66 | let summary = self.agent.prompt(text).await?; 67 | 68 | if summary.is_empty() { 69 | return Err(ExtractionError::NoData); 70 | } 71 | 72 | Ok(serde_json::from_str(&summary)?) 73 | } 74 | } 75 | 76 | /// Builder for the Extractor 77 | pub struct ExtractorBuilder< 78 | T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync + 'static, 79 | M: CompletionModel, 80 | > { 81 | agent_builder: AgentBuilder<M>, 82 | _t: PhantomData<T>, 83 | } 84 | 85 | impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync, M: CompletionModel> 86 | ExtractorBuilder<T, M> 87 | { 88 | pub fn new(model: M) -> Self { 89 | Self { 90 | agent_builder: AgentBuilder::new(model) 91 | .preamble("\ 92 | You are an AI assistant whose purpose is to extract structured data from the provided text.\n\ 93 | You will have access to a `submit` function that defines the structure of the data to extract from the provided text.\n\ 94 | Use the `submit` function to submit the structured data.\n\ 95 | Be sure to fill out every field and ALWAYS CALL THE `submit` function, even with default values!
96 | ") 97 | .tool(SubmitTool::<T> {_t: PhantomData}), 98 | _t: PhantomData, 99 | } 100 | } 101 | 102 | /// Add additional preamble to the extractor 103 | pub fn preamble(mut self, preamble: &str) -> Self { 104 | self.agent_builder = self.agent_builder.append_preamble(&format!( 105 | "\n=============== ADDITIONAL INSTRUCTIONS ===============\n{preamble}" 106 | )); 107 | self 108 | } 109 | 110 | /// Add a context document to the extractor 111 | pub fn context(mut self, doc: &str) -> Self { 112 | self.agent_builder = self.agent_builder.context(doc); 113 | self 114 | } 115 | 116 | /// Build the Extractor 117 | pub fn build(self) -> Extractor<M, T> { 118 | Extractor { 119 | agent: self.agent_builder.build(), 120 | _t: PhantomData, 121 | } 122 | } 123 | } 124 | 125 | #[derive(Deserialize, Serialize)] 126 | struct SubmitTool<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> { 127 | _t: PhantomData<T>, 128 | } 129 | 130 | #[derive(Debug, thiserror::Error)] 131 | #[error("SubmitError")] 132 | struct SubmitError; 133 | 134 | impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync> Tool for SubmitTool<T> { 135 | const NAME: &'static str = "submit"; 136 | type Error = SubmitError; 137 | type Args = T; 138 | type Output = T; 139 | 140 | async fn definition(&self, _prompt: String) -> ToolDefinition { 141 | ToolDefinition { 142 | name: Self::NAME.to_string(), 143 | description: "Submit the structured data you extracted from the provided text." 144 | .to_string(), 145 | parameters: json!(schema_for!(T)), 146 | } 147 | } 148 | 149 | async fn call(&self, data: Self::Args) -> Result<Self::Output, Self::Error> { 150 | Ok(data) 151 | } 152 | } 153 | -------------------------------------------------------------------------------- /rig-core/examples/rag_dynamic_tools.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use rig::{ 3 | completion::{Prompt, ToolDefinition}, 4 | embeddings::EmbeddingsBuilder, 5 | providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, 6 | tool::{Tool, ToolEmbedding, ToolSet}, 7 | vector_store::in_memory_store::InMemoryVectorStore, 8 | }; 9 | use serde::{Deserialize, Serialize}; 10 | use serde_json::json; 11 | use std::env; 12 | 13 | #[derive(Deserialize)] 14 | struct OperationArgs { 15 | x: i32, 16 | y: i32, 17 | } 18 | 19 | #[derive(Debug, thiserror::Error)] 20 | #[error("Math error")] 21 | struct MathError; 22 | 23 | #[derive(Debug, thiserror::Error)] 24 | #[error("Math error")] 25 | struct InitError; 26 | 27 | #[derive(Deserialize, Serialize)] 28 | struct Add; 29 | 30 | impl Tool for Add { 31 | const NAME: &'static str = "add"; 32 | 33 | type Error = MathError; 34 | type Args = OperationArgs; 35 | type Output = i32; 36 | 37 | async fn definition(&self, _prompt: String) -> ToolDefinition { 38 | serde_json::from_value(json!({ 39 | "name": "add", 40 | "description": "Add x and y together", 41 | "parameters": { 42 | "type": "object", 43 | "properties": { 44 | "x": { 45 | "type": "number", 46 | "description": "The first number to add" 47 | }, 48 | "y": { 49 | "type": "number", 50 | "description": "The second number to add" 51 | } 52 | } 53 | } 54 | })) 55 | .expect("Tool Definition") 56 | } 57 | 58 | async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> { 59 | let result = args.x + args.y; 60 | Ok(result) 61 | } 62 | } 63 | 64 | impl ToolEmbedding for Add { 65 | type InitError = InitError; 66 | type Context = (); 67 | type State = (); 68 | 69 | fn init(_state: Self::State, _context: 
Self::Context) -> Result<Self, Self::InitError> { 70 | Ok(Add) 71 | } 72 | 73 | fn embedding_docs(&self) -> Vec<String> { 74 | vec!["Add x and y together".into()] 75 | } 76 | 77 | fn context(&self) -> Self::Context {} 78 | } 79 | 80 | #[derive(Deserialize, Serialize)] 81 | struct Subtract; 82 | 83 | impl Tool for Subtract { 84 | const NAME: &'static str = "subtract"; 85 | 86 | type Error = MathError; 87 | type Args = OperationArgs; 88 | type Output = i32; 89 | 90 | async fn definition(&self, _prompt: String) -> ToolDefinition { 91 | serde_json::from_value(json!({ 92 | "name": "subtract", 93 | "description": "Subtract y from x (i.e.: x - y)", 94 | "parameters": { 95 | "type": "object", 96 | "properties": { 97 | "x": { 98 | "type": "number", 99 | "description": "The number to subtract from" 100 | }, 101 | "y": { 102 | "type": "number", 103 | "description": "The number to subtract" 104 | } 105 | } 106 | } 107 | })) 108 | .expect("Tool Definition") 109 | } 110 | 111 | async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> { 112 | let result = args.x - args.y; 113 | Ok(result) 114 | } 115 | } 116 | 117 | impl ToolEmbedding for Subtract { 118 | type InitError = InitError; 119 | type Context = (); 120 | type State = (); 121 | 122 | fn init(_state: Self::State, _context: Self::Context) -> Result<Self, Self::InitError> { 123 | Ok(Subtract) 124 | } 125 | 126 | fn context(&self) -> Self::Context {} 127 | 128 | fn embedding_docs(&self) -> Vec<String> { 129 | vec!["Subtract y from x (i.e.: x - y)".into()] 130 | } 131 | } 132 | 133 | #[tokio::main] 134 | async fn main() -> Result<(), anyhow::Error> { 135 | // Initialize the tracing subscriber for logging 136 | tracing_subscriber::fmt() 137 | .with_max_level(tracing::Level::INFO) 138 | // disable printing the name of the module in every log line. 139 | .with_target(false) 140 | .init(); 141 | 142 | // Create OpenAI client 143 | let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); 144 | let openai_client = Client::new(&openai_api_key); 145 | 146 | let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); 147 | 148 | let toolset = ToolSet::builder() 149 | .dynamic_tool(Add) 150 | .dynamic_tool(Subtract) 151 | .build(); 152 | 153 | let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) 154 | .documents(toolset.schemas()?)? 155 | .build() 156 | .await?; 157 | 158 | // Create vector store with the embeddings 159 | let vector_store = 160 | InMemoryVectorStore::from_documents_with_id_f(embeddings, |tool| tool.name.clone()); 161 | 162 | // Create vector store index 163 | let index = vector_store.index(embedding_model); 164 | 165 | // Create RAG agent with a single context prompt and a dynamic tool source 166 | let calculator_rag = openai_client 167 | .agent("gpt-4") 168 | .preamble("You are a calculator here to help the user perform arithmetic operations.") 169 | // Add a dynamic tool source with a sample rate of 1 (i.e.: only 170 | // 1 additional tool will be added to prompts) 171 | .dynamic_tools(1, index, toolset) 172 | .build(); 173 | 174 | // Prompt the agent and print the response 175 | let response = calculator_rag.prompt("Calculate 3 - 7").await?; 176 | println!("{}", response); 177 | 178 | Ok(()) 179 | } 180 | -------------------------------------------------------------------------------- /rig-core/src/providers/anthropic/client.rs: -------------------------------------------------------------------------------- 1 | //!
Anthropic client API implementation 2 | 3 | use crate::{agent::AgentBuilder, extractor::ExtractorBuilder}; 4 | 5 | use schemars::JsonSchema; 6 | use serde::{Deserialize, Serialize}; 7 | 8 | use super::completion::{CompletionModel, ANTHROPIC_VERSION_LATEST}; 9 | 10 | // ================================================================ 11 | // Main Anthropic Client 12 | // ================================================================ 13 | const ANTHROPIC_API_BASE_URL: &str = "https://api.anthropic.com"; 14 | 15 | #[derive(Clone)] 16 | pub struct ClientBuilder<'a> { 17 | api_key: &'a str, 18 | base_url: &'a str, 19 | anthropic_version: &'a str, 20 | anthropic_betas: Option<Vec<&'a str>>, 21 | } 22 | 23 | /// Create a new Anthropic client using the builder. 24 | /// 25 | /// # Example 26 | /// ``` 27 | /// use rig::providers::anthropic::{ClientBuilder, self}; 28 | /// 29 | /// // Initialize the Anthropic client 30 | /// let anthropic_client = ClientBuilder::new("your-claude-api-key") 31 | /// .anthropic_version(anthropic::ANTHROPIC_VERSION_LATEST) 32 | /// .anthropic_beta("prompt-caching-2024-07-31") 33 | /// .build(); 34 | /// ``` 35 | impl<'a> ClientBuilder<'a> { 36 | pub fn new(api_key: &'a str) -> Self { 37 | Self { 38 | api_key, 39 | base_url: ANTHROPIC_API_BASE_URL, 40 | anthropic_version: ANTHROPIC_VERSION_LATEST, 41 | anthropic_betas: None, 42 | } 43 | } 44 | 45 | pub fn base_url(mut self, base_url: &'a str) -> Self { 46 | self.base_url = base_url; 47 | self 48 | } 49 | 50 | pub fn anthropic_version(mut self, anthropic_version: &'a str) -> Self { 51 | self.anthropic_version = anthropic_version; 52 | self 53 | } 54 | 55 | pub fn anthropic_beta(mut self, anthropic_beta: &'a str) -> Self { 56 | if let Some(mut betas) = self.anthropic_betas { 57 | betas.push(anthropic_beta); 58 | self.anthropic_betas = Some(betas); 59 | } else { 60 | self.anthropic_betas = Some(vec![anthropic_beta]); 61 | } 62 | self 63 | } 64 | 65 | pub fn build(self) -> Client { 66 | Client::new( 67 | self.api_key, 68 | self.base_url, 69 | self.anthropic_betas, 70 | self.anthropic_version, 71 | ) 72 | } 73 | } 74 | 75 | #[derive(Clone)] 76 | pub struct Client { 77 | base_url: String, 78 | http_client: reqwest::Client, 79 | } 80 | 81 | impl Client { 82 | /// Create a new Anthropic client with the given API key, base URL, betas, and version. 83 | /// Note: you probably want to use the `ClientBuilder` instead. 84 | /// 85 | /// Panics: 86 | /// - If the API key, version, or betas cannot be parsed as HTTP header values. 87 | /// - This should really never happen. 88 | /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). 89 | pub fn new(api_key: &str, base_url: &str, betas: Option<Vec<&str>>, version: &str) -> Self { 90 | Self { 91 | base_url: base_url.to_string(), 92 | http_client: reqwest::Client::builder() 93 | .default_headers({ 94 | let mut headers = reqwest::header::HeaderMap::new(); 95 | headers.insert("x-api-key", api_key.parse().expect("API key should parse")); 96 | headers.insert( 97 | "anthropic-version", 98 | version.parse().expect("Anthropic version should parse"), 99 | ); 100 | if let Some(betas) = betas { 101 | headers.insert( 102 | "anthropic-beta", 103 | betas 104 | .join(",") 105 | .parse() 106 | .expect("Anthropic betas should parse"), 107 | ); 108 | } 109 | headers 110 | }) 111 | .build() 112 | .expect("Anthropic reqwest client should build"), 113 | } 114 | } 115 | 116 | /// Create a new Anthropic client from the `ANTHROPIC_API_KEY` environment variable.
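/// This is a convenience wrapper around `ClientBuilder::new(&api_key).build()`, so the default base URL and latest Anthropic version are used, with no beta headers.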
117 | /// Panics if the environment variable is not set. 118 | pub fn from_env() -> Self { 119 | let api_key = std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"); 120 | ClientBuilder::new(&api_key).build() 121 | } 122 | 123 | pub fn post(&self, path: &str) -> reqwest::RequestBuilder { 124 | let url = format!("{}/{}", self.base_url.trim_end_matches('/'), path.trim_start_matches('/')); // join without a blanket replace("//", "/"), which would mangle the "https://" scheme 125 | self.http_client.post(url) 126 | } 127 | 128 | pub fn completion_model(&self, model: &str) -> CompletionModel { 129 | CompletionModel::new(self.clone(), model) 130 | } 131 | 132 | /// Create an agent builder with the given completion model. 133 | /// 134 | /// # Example 135 | /// ``` 136 | /// use rig::providers::anthropic::{ClientBuilder, self}; 137 | /// 138 | /// // Initialize the Anthropic client 139 | /// let anthropic = ClientBuilder::new("your-claude-api-key").build(); 140 | /// 141 | /// let agent = anthropic.agent(anthropic::CLAUDE_3_5_SONNET) 142 | /// .preamble("You are a comedian AI with a mission to make people laugh.") 143 | /// .temperature(0.0) 144 | /// .build(); 145 | /// ``` 146 | pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> { 147 | AgentBuilder::new(self.completion_model(model)) 148 | } 149 | 150 | pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>( 151 | &self, 152 | model: &str, 153 | ) -> ExtractorBuilder<T, CompletionModel> { 154 | ExtractorBuilder::new(self.completion_model(model)) 155 | } 156 | } 157 | -------------------------------------------------------------------------------- /rig-neo4j/examples/vector_search_simple.rs: -------------------------------------------------------------------------------- 1 | //! Simple end-to-end example of the vector search capabilities of the `rig-neo4j` crate. 2 | //! This example expects a running Neo4j instance. 3 | //! It: 4 | //! 1. Generates embeddings for a set of 3 "documents" 5 | //! 2. Adds the documents to the Neo4j DB 6 | //! 3. Creates a vector index on the embeddings 7 | //! 4. Queries the vector index 8 | //! 5. Returns the results 9 | use std::env; 10 | 11 | use futures::{StreamExt, TryStreamExt}; 12 | use rig::{ 13 | embeddings::EmbeddingsBuilder, 14 | providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, 15 | vector_store::VectorStoreIndex as _, 16 | Embed, 17 | }; 18 | use rig_neo4j::{vector_index::SearchParams, Neo4jClient, ToBoltType}; 19 | 20 | #[derive(Embed, Clone, Debug)] 21 | pub struct Word { 22 | pub id: String, 23 | #[embed] 24 | pub definition: String, 25 | } 26 | 27 | #[tokio::main] 28 | async fn main() -> Result<(), anyhow::Error> { 29 | // Initialize OpenAI client 30 | let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); 31 | let openai_client = Client::new(&openai_api_key); 32 | 33 | // Initialize Neo4j client 34 | let neo4j_uri = env::var("NEO4J_URI").expect("NEO4J_URI not set"); 35 | let neo4j_username = env::var("NEO4J_USERNAME").expect("NEO4J_USERNAME not set"); 36 | let neo4j_password = env::var("NEO4J_PASSWORD").expect("NEO4J_PASSWORD not set"); 37 | 38 | let neo4j_client = Neo4jClient::connect(&neo4j_uri, &neo4j_username, &neo4j_password).await?; 39 | 40 | // Select the embedding model and generate our embeddings 41 | let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); 42 | 43 | let embeddings = EmbeddingsBuilder::new(model.clone()) 44 | .document(Word { 45 | id: "doc0".to_string(), 46 | definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), 47 | })?
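// More documents can be chained here; they are all embedded together when `build()` is awaited below.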
48 | .document(Word { 49 | id: "doc1".to_string(), 50 | definition: "Definition of a *glarb-glarb*: A glarb-glarb is an ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), 51 | })? 52 | .document(Word { 53 | id: "doc2".to_string(), 54 | definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), 55 | })? 56 | .build() 57 | .await?; 58 | 59 | futures::stream::iter(embeddings) 60 | .map(|(doc, embeddings)| { 61 | neo4j_client.graph.run( 62 | neo4rs::query( 63 | " 64 | CREATE 65 | (document:DocumentEmbeddings { 66 | id: $id, 67 | document: $document, 68 | embedding: $embedding}) 69 | RETURN document", 70 | ) 71 | .param("id", doc.id) 72 | // Here we use the first embedding but we could use any of them. 73 | // Neo4j only takes primitive types or arrays as properties. 74 | .param("embedding", embeddings.first().vec.clone()) 75 | .param("document", doc.definition.to_bolt_type()), 76 | ) 77 | }) 78 | .buffer_unordered(3) 79 | .try_collect::<Vec<_>>() 80 | .await 81 | .unwrap(); 82 | 83 | // Create a vector index on our vector store 84 | println!("Creating vector index..."); 85 | neo4j_client 86 | .graph 87 | .run(neo4rs::query( 88 | "CREATE VECTOR INDEX vector_index IF NOT EXISTS 89 | FOR (m:DocumentEmbeddings) 90 | ON m.embedding 91 | OPTIONS { indexConfig: { 92 | `vector.dimensions`: 1536, 93 | `vector.similarity_function`: 'cosine' 94 | }}", 95 | )) 96 | .await?; 97 | 98 | // ℹ️ The index name must be unique among both indexes and constraints. 99 | // A newly created index is not immediately available but is created in the background. 100 | 101 | // Check if the index exists with db.awaitIndex(), the call times out if the index is not ready 102 | let index_exists = neo4j_client 103 | .graph 104 | .run(neo4rs::query("CALL db.awaitIndex('vector_index')")) 105 | .await; 106 | if index_exists.is_err() { 107 | println!("Index not ready, waiting for index..."); 108 | std::thread::sleep(std::time::Duration::from_secs(5)); 109 | } 110 | 111 | println!("Index exists: {:?}", index_exists); 112 | 113 | // Get a handle to the vector index 114 | // IMPORTANT: Reuse the same model that was used to generate the embeddings 115 | let index = neo4j_client 116 | .get_index(model, "vector_index", SearchParams::default()) 117 | .await?; 118 | 119 | // The struct that will represent a node in the database. Used to deserialize the results of the query (passed to the `top_n` methods) 120 | // ❗IMPORTANT: The field names must match the property names in the database 121 | #[derive(serde::Deserialize)] 122 | struct Document { 123 | #[allow(dead_code)] 124 | id: String, 125 | document: String, 126 | } 127 | 128 | // Query the index 129 | let results = index 130 | .top_n::<Document>("What is a glarb?", 1) 131 | .await? 132 | .into_iter() 133 | .map(|(score, id, doc)| (score, id, doc.document)) 134 | .collect::<Vec<_>>(); 135 | 136 | println!("Results: {:?}", results); 137 | 138 | let id_results = index 139 | .top_n_ids("What is a linglingdong?", 1) 140 | .await?
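// Unlike `top_n`, `top_n_ids` returns only (score, id) pairs and does not deserialize the node payload.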
141 | .into_iter() 142 | .map(|(score, id)| (score, id)) 143 | .collect::<Vec<_>>(); 144 | 145 | println!("ID results: {:?}", id_results); 146 | 147 | Ok(()) 148 | } 149 | -------------------------------------------------------------------------------- /rig-neo4j/tests/integration_tests.rs: -------------------------------------------------------------------------------- 1 | use testcontainers::{ 2 | core::{IntoContainerPort, Mount, WaitFor}, 3 | runners::AsyncRunner, 4 | GenericImage, ImageExt, 5 | }; 6 | 7 | use futures::{StreamExt, TryStreamExt}; 8 | use rig::vector_store::VectorStoreIndex; 9 | use rig::{ 10 | embeddings::{Embedding, EmbeddingsBuilder}, 11 | providers::openai, 12 | Embed, OneOrMany, 13 | }; 14 | use rig_neo4j::{vector_index::SearchParams, Neo4jClient, ToBoltType}; 15 | 16 | const BOLT_PORT: u16 = 7687; 17 | const HTTP_PORT: u16 = 7474; 18 | 19 | #[derive(Embed, Clone, serde::Deserialize, Debug)] 20 | struct Word { 21 | id: String, 22 | #[embed] 23 | definition: String, 24 | } 25 | 26 | #[tokio::test] 27 | async fn vector_search_test() { 28 | let mount = Mount::volume_mount("data", std::env::var("GITHUB_WORKSPACE").unwrap()); 29 | // Set up a local Neo4j container for testing. NOTE: docker service must be running. 30 | let container = GenericImage::new("neo4j", "latest") 31 | .with_wait_for(WaitFor::Duration { 32 | length: std::time::Duration::from_secs(5), 33 | }) 34 | .with_exposed_port(BOLT_PORT.tcp()) 35 | .with_exposed_port(HTTP_PORT.tcp()) 36 | .with_mount(mount) 37 | .with_env_var("NEO4J_AUTH", "none") 38 | .start() 39 | .await 40 | .expect("Failed to start Neo4j container"); 41 | 42 | let port = container.get_host_port_ipv4(BOLT_PORT).await.unwrap(); 43 | let host = container.get_host().await.unwrap().to_string(); 44 | 45 | let neo4j_client = Neo4jClient::connect(&format!("neo4j://{host}:{port}"), "", "") 46 | .await 47 | .unwrap(); 48 | 49 | // Initialize OpenAI client 50 | let openai_client = openai::Client::from_env(); 51 | 52 | // Select the embedding model and generate our embeddings 53 | let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); 54 | 55 | let embeddings = create_embeddings(model.clone()).await; 56 | 57 | futures::stream::iter(embeddings) 58 | .map(|(doc, embeddings)| { 59 | neo4j_client.graph.run( 60 | neo4rs::query( 61 | " 62 | CREATE 63 | (document:DocumentEmbeddings { 64 | id: $id, 65 | document: $document, 66 | embedding: $embedding}) 67 | RETURN document", 68 | ) 69 | .param("id", doc.id) 70 | // Here we use the first embedding but we could use any of them. 71 | // Neo4j only takes primitive types or arrays as properties. 72 | .param("embedding", embeddings.first().vec.clone()) 73 | .param("document", doc.definition.to_bolt_type()), 74 | ) 75 | }) 76 | .buffer_unordered(3) 77 | .try_collect::<Vec<_>>() 78 | .await 79 | .unwrap(); 80 | 81 | // Create a vector index on our vector store 82 | println!("Creating vector index..."); 83 | neo4j_client 84 | .graph 85 | .run(neo4rs::query( 86 | "CREATE VECTOR INDEX vector_index IF NOT EXISTS 87 | FOR (m:DocumentEmbeddings) 88 | ON m.embedding 89 | OPTIONS { indexConfig: { 90 | `vector.dimensions`: 1536, 91 | `vector.similarity_function`: 'cosine' 92 | }}", 93 | )) 94 | .await 95 | .unwrap(); 96 | 97 | // ℹ️ The index name must be unique among both indexes and constraints. 98 | // A newly created index is not immediately available but is created in the background.
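// The readiness check below is best-effort: if `db.awaitIndex` errors because the index is still populating, the test simply sleeps briefly before querying.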
99 | 100 | // Check if the index exists with db.awaitIndex(), the call times out if the index is not ready 101 | let index_exists = neo4j_client 102 | .graph 103 | .run(neo4rs::query("CALL db.awaitIndex('vector_index')")) 104 | .await; 105 | if index_exists.is_err() { 106 | println!("Index not ready, waiting for index..."); 107 | std::thread::sleep(std::time::Duration::from_secs(5)); 108 | } 109 | 110 | println!("Index exists: {:?}", index_exists); 111 | 112 | // Get a handle to the vector index 113 | // IMPORTANT: Reuse the same model that was used to generate the embeddings 114 | let index = neo4j_client 115 | .get_index(model, "vector_index", SearchParams::default()) 116 | .await 117 | .unwrap(); 118 | 119 | // Query the index 120 | let results = index 121 | .top_n::<serde_json::Value>("What is a glarb?", 1) 122 | .await 123 | .unwrap(); 124 | 125 | let (_, _, value) = &results.first().unwrap(); 126 | 127 | assert_eq!( 128 | value, 129 | &serde_json::json!({ 130 | "id": "doc1", 131 | "document": "Definition of a *glarb-glarb*: A glarb-glarb is an ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.", 132 | "embedding": serde_json::Value::Null 133 | }) 134 | ) 135 | } 136 | 137 | async fn create_embeddings(model: openai::EmbeddingModel) -> Vec<(Word, OneOrMany<Embedding>)> { 138 | let words = vec![ 139 | Word { 140 | id: "doc0".to_string(), 141 | definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), 142 | }, 143 | Word { 144 | id: "doc1".to_string(), 145 | definition: "Definition of a *glarb-glarb*: A glarb-glarb is an ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), 146 | }, 147 | Word { 148 | id: "doc2".to_string(), 149 | definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), 150 | } 151 | ]; 152 | 153 | EmbeddingsBuilder::new(model) 154 | .documents(words) 155 | .unwrap() 156 | .build() 157 | .await 158 | .unwrap() 159 | } 160 | --------------------------------------------------------------------------------