├── src ├── tools │ ├── scraper │ │ ├── mod.rs │ │ └── scraper.rs │ ├── serpapi │ │ └── mod.rs │ ├── wolfram │ │ └── mod.rs │ ├── sql │ │ ├── postgres │ │ │ └── mod.rs │ │ └── mod.rs │ ├── text2speech │ │ ├── openai │ │ │ └── mod.rs │ │ ├── mod.rs │ │ └── speech_storage.rs │ ├── duckduckgo │ │ └── mod.rs │ ├── command_executor │ │ └── mod.rs │ ├── mod.rs │ └── tool.rs ├── document_loaders │ ├── csv_loader │ │ └── mod.rs │ ├── html_loader │ │ └── mod.rs │ ├── text_loader │ │ └── mod.rs │ ├── pandoc_loader │ │ └── mod.rs │ ├── git_commit_loader │ │ ├── mod.rs │ │ └── git_commit_loader.rs │ ├── html_to_markdown_loader │ │ └── mod.rs │ ├── test_data │ │ ├── sample.docx │ │ ├── sample.pdf │ │ ├── example.rs │ │ ├── test.csv │ │ └── example.html │ ├── pdf_loader │ │ └── mod.rs │ ├── source_code_loader │ │ └── mod.rs │ ├── mod.rs │ ├── error.rs │ └── document_loader.rs ├── embedding │ ├── ollama │ │ ├── mod.rs │ │ └── ollama_embedder.rs │ ├── openai │ │ ├── mod.rs │ │ └── openai_embedder.rs │ ├── mistralai │ │ ├── mod.rs │ │ └── mistralai_embedder.rs │ ├── fastembed │ │ ├── mod.rs │ │ └── fastembed.rs │ ├── embedder_trait.rs │ ├── mod.rs │ └── error.rs ├── llm │ ├── ollama │ │ ├── mod.rs │ │ └── openai.rs │ ├── qwen │ │ ├── mod.rs │ │ ├── error.rs │ │ └── models.rs │ ├── claude │ │ ├── mod.rs │ │ ├── error.rs │ │ └── models.rs │ ├── deepseek │ │ ├── mod.rs │ │ └── error.rs │ ├── test_data │ │ └── example.jpg │ └── mod.rs ├── chain │ ├── sequential │ │ ├── mod.rs │ │ └── builder.rs │ ├── stuff_documents │ │ ├── mod.rs │ │ └── builder.rs │ ├── conversational_retrieval_qa │ │ └── mod.rs │ ├── conversational │ │ └── prompt.rs │ ├── sql_datbase │ │ ├── mod.rs │ │ ├── prompt.rs │ │ └── builder.rs │ ├── mod.rs │ └── error.rs ├── vectorstore │ ├── qdrant │ │ └── mod.rs │ ├── surrealdb │ │ └── mod.rs │ ├── opensearch │ │ ├── mod.rs │ │ └── builder.rs │ ├── sqlite_vec │ │ ├── mod.rs │ │ └── builder.rs │ ├── sqlite_vss │ │ ├── mod.rs │ │ └── builder.rs │ ├── mod.rs │ ├── pgvector │ │ └── mod.rs │ ├── options.rs │ └── vectorstore.rs ├── agent │ ├── open_ai_tools │ │ ├── mod.rs │ │ ├── prompt.rs │ │ └── builder.rs │ ├── chat │ │ ├── mod.rs │ │ ├── builder.rs │ │ ├── prompt.rs │ │ └── output_parser.rs │ ├── mod.rs │ ├── agent.rs │ └── error.rs ├── semantic_router │ ├── index │ │ ├── mod.rs │ │ ├── error.rs │ │ ├── index.rs │ │ └── memory_index.rs │ ├── route_layer │ │ ├── mod.rs │ │ └── error.rs │ ├── mod.rs │ ├── utils.rs │ └── router.rs ├── memory │ ├── mod.rs │ ├── dummy_memory.rs │ ├── simple_memory.rs │ └── window_buffer.rs ├── output_parsers │ ├── mod.rs │ ├── error.rs │ ├── output_parser.rs │ ├── simple_parser.rs │ └── markdown_parser.rs ├── text_splitter │ ├── mod.rs │ ├── error.rs │ ├── markdown_splitter.rs │ ├── token_splitter.rs │ ├── text_splitter.rs │ ├── plain_text_splitter.rs │ └── options.rs ├── lib.rs ├── prompt │ ├── error.rs │ └── mod.rs ├── schemas │ ├── retrievers.rs │ ├── mod.rs │ ├── stream.rs │ ├── prompt.rs │ ├── agent.rs │ ├── response_format_openai_like.rs │ ├── memory.rs │ ├── convert.rs │ ├── document.rs │ └── tools_openai_like.rs └── language_models │ ├── error.rs │ ├── llm.rs │ └── mod.rs ├── .gitignore ├── .github ├── ISSUE_TEMPLATE │ ├── custom.md │ ├── feature_request.md │ └── bug_report.md └── workflows │ └── ci.yml ├── examples ├── llm_anthropic_claude.rs ├── wolfram_tool.rs ├── embedding_openai.rs ├── vector_store_surrealdb │ ├── Cargo.toml │ └── src │ │ └── main.rs ├── llm_azure_open_ai.rs ├── llm_ollama.rs ├── embedding_azure_open_ai.rs ├── llm_alibaba_qwen.rs ├── 
embedding_mistralai.rs ├── embedding_ollama.rs ├── llm_deepseek.rs ├── llm_openai.rs ├── speech2text_openai.rs ├── qa_chain.rs ├── semantic_routes.rs ├── agent.rs ├── simple_chain.rs ├── embedding_fastembed.rs ├── streaming_from_chain.rs ├── vision_llm_chain.rs ├── dynamic_semantic_routes.rs ├── sequential_chain.rs ├── rcommiter.rs ├── llm_chain_qwen.rs ├── open_ai_tools_agent.rs ├── llm_chain_deepseek.rs ├── llm_qwen_advanced.rs ├── vector_store_qdrant.rs ├── conversational_chain.rs ├── vector_store_postgres.rs ├── vector_store_sqlite_vec.rs ├── sql_chain.rs ├── vector_store_sqlite_vss.rs ├── git_commits.rs ├── conversational_retriever_simple_chain.rs └── llm_chain.rs ├── scripts └── run-pgvector ├── renovate.json ├── LICENSE └── CONTRIBUTING.md /src/tools/scraper/mod.rs: -------------------------------------------------------------------------------- 1 | mod scraper; 2 | pub use scraper::*; 3 | -------------------------------------------------------------------------------- /src/tools/serpapi/mod.rs: -------------------------------------------------------------------------------- 1 | mod serpapi; 2 | pub use serpapi::*; 3 | -------------------------------------------------------------------------------- /src/tools/wolfram/mod.rs: -------------------------------------------------------------------------------- 1 | mod wolfram; 2 | pub use wolfram::*; 3 | -------------------------------------------------------------------------------- /src/tools/sql/postgres/mod.rs: -------------------------------------------------------------------------------- 1 | mod postgres; 2 | 3 | pub use postgres::*; 4 | -------------------------------------------------------------------------------- /src/tools/text2speech/openai/mod.rs: -------------------------------------------------------------------------------- 1 | mod client; 2 | pub use client::*; 3 | -------------------------------------------------------------------------------- /src/document_loaders/csv_loader/mod.rs: -------------------------------------------------------------------------------- 1 | mod csv_loader; 2 | pub use csv_loader::*; 3 | -------------------------------------------------------------------------------- /src/document_loaders/html_loader/mod.rs: -------------------------------------------------------------------------------- 1 | mod html_loader; 2 | pub use html_loader::*; 3 | -------------------------------------------------------------------------------- /src/document_loaders/text_loader/mod.rs: -------------------------------------------------------------------------------- 1 | mod text_loader; 2 | pub use text_loader::*; 3 | -------------------------------------------------------------------------------- /src/embedding/ollama/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod ollama_embedder; 2 | pub use ollama_embedder::*; 3 | -------------------------------------------------------------------------------- /src/embedding/openai/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod openai_embedder; 2 | pub use openai_embedder::*; 3 | -------------------------------------------------------------------------------- /src/tools/duckduckgo/mod.rs: -------------------------------------------------------------------------------- 1 | mod duckduckgo_search; 2 | pub use duckduckgo_search::*; 3 | -------------------------------------------------------------------------------- /src/tools/command_executor/mod.rs: 
-------------------------------------------------------------------------------- 1 | mod command_executor; 2 | pub use command_executor::*; 3 | -------------------------------------------------------------------------------- /src/document_loaders/pandoc_loader/mod.rs: -------------------------------------------------------------------------------- 1 | mod pandoc_loader; 2 | pub use pandoc_loader::*; 3 | -------------------------------------------------------------------------------- /src/embedding/mistralai/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod mistralai_embedder; 2 | pub use mistralai_embedder::*; 3 | -------------------------------------------------------------------------------- /src/llm/ollama/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "ollama")] 2 | pub mod client; 3 | 4 | pub mod openai; 5 | -------------------------------------------------------------------------------- /src/document_loaders/git_commit_loader/mod.rs: -------------------------------------------------------------------------------- 1 | mod git_commit_loader; 2 | pub use git_commit_loader::*; 3 | -------------------------------------------------------------------------------- /src/chain/sequential/mod.rs: -------------------------------------------------------------------------------- 1 | mod builder; 2 | mod chain; 3 | 4 | pub use builder::*; 5 | pub use chain::*; 6 | -------------------------------------------------------------------------------- /src/llm/qwen/mod.rs: -------------------------------------------------------------------------------- 1 | mod client; 2 | mod models; 3 | pub use client::*; 4 | mod error; 5 | pub use error::*; 6 | -------------------------------------------------------------------------------- /src/tools/sql/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "postgres")] 2 | pub mod postgres; 3 | mod sql; 4 | 5 | pub use sql::*; 6 | -------------------------------------------------------------------------------- /src/chain/stuff_documents/mod.rs: -------------------------------------------------------------------------------- 1 | mod chain; 2 | pub use chain::*; 3 | 4 | mod builder; 5 | pub use builder::*; 6 | -------------------------------------------------------------------------------- /src/vectorstore/qdrant/mod.rs: -------------------------------------------------------------------------------- 1 | mod builder; 2 | mod qdrant; 3 | 4 | pub use builder::*; 5 | pub use qdrant::*; 6 | -------------------------------------------------------------------------------- /src/llm/claude/mod.rs: -------------------------------------------------------------------------------- 1 | mod models; 2 | 3 | mod client; 4 | pub use client::*; 5 | 6 | mod error; 7 | pub use error::*; 8 | -------------------------------------------------------------------------------- /src/llm/deepseek/mod.rs: -------------------------------------------------------------------------------- 1 | mod client; 2 | mod models; 3 | pub use client::*; 4 | 5 | mod error; 6 | pub use error::*; 7 | -------------------------------------------------------------------------------- /src/llm/test_data/example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Abraxas-365/langchain-rust/HEAD/src/llm/test_data/example.jpg 
-------------------------------------------------------------------------------- /src/vectorstore/surrealdb/mod.rs: -------------------------------------------------------------------------------- 1 | mod builder; 2 | mod surrealdb; 3 | 4 | pub use builder::*; 5 | pub use surrealdb::*; 6 | -------------------------------------------------------------------------------- /src/document_loaders/html_to_markdown_loader/mod.rs: -------------------------------------------------------------------------------- 1 | mod html_to_markdown_loader; 2 | pub use html_to_markdown_loader::*; 3 | -------------------------------------------------------------------------------- /src/tools/text2speech/mod.rs: -------------------------------------------------------------------------------- 1 | mod openai; 2 | pub use openai::*; 3 | 4 | mod speech_storage; 5 | pub use speech_storage::*; 6 | -------------------------------------------------------------------------------- /src/vectorstore/opensearch/mod.rs: -------------------------------------------------------------------------------- 1 | mod builder; 2 | mod opensearch; 3 | 4 | pub use builder::*; 5 | pub use opensearch::*; 6 | -------------------------------------------------------------------------------- /src/vectorstore/sqlite_vec/mod.rs: -------------------------------------------------------------------------------- 1 | mod builder; 2 | mod sqlite_vec; 3 | 4 | pub use builder::*; 5 | pub use sqlite_vec::*; 6 | -------------------------------------------------------------------------------- /src/vectorstore/sqlite_vss/mod.rs: -------------------------------------------------------------------------------- 1 | mod builder; 2 | mod sqlite_vss; 3 | 4 | pub use builder::*; 5 | pub use sqlite_vss::*; 6 | -------------------------------------------------------------------------------- /src/agent/open_ai_tools/mod.rs: -------------------------------------------------------------------------------- 1 | mod builder; 2 | pub use builder::*; 3 | 4 | mod agent; 5 | pub use agent::*; 6 | 7 | mod prompt; 8 | -------------------------------------------------------------------------------- /src/document_loaders/test_data/sample.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Abraxas-365/langchain-rust/HEAD/src/document_loaders/test_data/sample.docx -------------------------------------------------------------------------------- /src/document_loaders/test_data/sample.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Abraxas-365/langchain-rust/HEAD/src/document_loaders/test_data/sample.pdf -------------------------------------------------------------------------------- /src/agent/chat/mod.rs: -------------------------------------------------------------------------------- 1 | mod builder; 2 | mod chat_agent; 3 | mod output_parser; 4 | mod prompt; 5 | 6 | pub use builder::*; 7 | pub use chat_agent::*; 8 | -------------------------------------------------------------------------------- /src/document_loaders/pdf_loader/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "lopdf")] 2 | pub mod lo_loader; 3 | 4 | #[cfg(feature = "pdf-extract")] 5 | pub mod pdf_extract_loader; 6 | -------------------------------------------------------------------------------- /src/semantic_router/index/mod.rs: -------------------------------------------------------------------------------- 1 | mod 
index; 2 | pub use index::*; 3 | 4 | mod memory_index; 5 | pub use memory_index::*; 6 | 7 | mod error; 8 | pub use error::*; 9 | -------------------------------------------------------------------------------- /src/chain/conversational_retrieval_qa/mod.rs: -------------------------------------------------------------------------------- 1 | mod builder; 2 | pub use builder::*; 3 | 4 | mod conversational_retrieval_qa; 5 | pub use conversational_retrieval_qa::*; 6 | -------------------------------------------------------------------------------- /src/document_loaders/source_code_loader/mod.rs: -------------------------------------------------------------------------------- 1 | mod source_code_loader; 2 | pub use source_code_loader::*; 3 | 4 | mod language_parsers; 5 | pub use language_parsers::*; 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | target/ 2 | Cargo.lock 3 | .DS_Store 4 | .fastembed_cache 5 | .vscode 6 | .idea 7 | 8 | # Ignore files generated by text_to_speech example 9 | *.mp3 10 | -------------------------------------------------------------------------------- /src/semantic_router/route_layer/mod.rs: -------------------------------------------------------------------------------- 1 | mod builder; 2 | pub use builder::*; 3 | 4 | mod route_layer; 5 | pub use route_layer::*; 6 | 7 | mod error; 8 | pub use error::*; 9 | -------------------------------------------------------------------------------- /src/memory/mod.rs: -------------------------------------------------------------------------------- 1 | mod dummy_memory; 2 | mod simple_memory; 3 | mod window_buffer; 4 | 5 | pub use dummy_memory::*; 6 | pub use simple_memory::*; 7 | pub use window_buffer::*; 8 | -------------------------------------------------------------------------------- /src/embedding/fastembed/mod.rs: -------------------------------------------------------------------------------- 1 | mod fastembed; 2 | pub use fastembed::*; 3 | 4 | extern crate fastembed as ext_fastembed; 5 | pub use ext_fastembed::{EmbeddingModel, InitOptions, TextEmbedding}; 6 | -------------------------------------------------------------------------------- /src/semantic_router/mod.rs: -------------------------------------------------------------------------------- 1 | mod router; 2 | pub use router::*; 3 | 4 | mod route_layer; 5 | pub use route_layer::*; 6 | 7 | mod index; 8 | pub use index::*; 9 | 10 | pub mod utils; 11 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/custom.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Custom issue template 3 | about: Describe this issue template's purpose here. 
4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | 11 | -------------------------------------------------------------------------------- /src/output_parsers/mod.rs: -------------------------------------------------------------------------------- 1 | mod output_parser; 2 | pub use output_parser::*; 3 | 4 | mod markdown_parser; 5 | pub use markdown_parser::*; 6 | 7 | mod simple_parser; 8 | pub use simple_parser::*; 9 | 10 | mod error; 11 | pub use error::*; 12 | -------------------------------------------------------------------------------- /src/agent/mod.rs: -------------------------------------------------------------------------------- 1 | mod agent; 2 | pub use agent::*; 3 | 4 | mod executor; 5 | pub use executor::*; 6 | 7 | mod chat; 8 | pub use chat::*; 9 | 10 | mod open_ai_tools; 11 | pub use open_ai_tools::*; 12 | 13 | mod error; 14 | pub use error::*; 15 | -------------------------------------------------------------------------------- /src/llm/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod openai; 2 | pub use openai::*; 3 | 4 | pub mod claude; 5 | pub use claude::*; 6 | 7 | pub mod ollama; 8 | pub use ollama::*; 9 | 10 | pub mod qwen; 11 | pub use qwen::*; 12 | 13 | pub mod deepseek; 14 | pub use deepseek::*; 15 | -------------------------------------------------------------------------------- /src/tools/text2speech/speech_storage.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | 3 | use async_trait::async_trait; 4 | 5 | #[async_trait] 6 | pub trait SpeechStorage: Send + Sync { 7 | async fn save(&self, key: &str, data: &[u8]) -> Result<String, Box<dyn Error>>; 8 | } 9 | -------------------------------------------------------------------------------- /examples/llm_anthropic_claude.rs: -------------------------------------------------------------------------------- 1 | use langchain_rust::{language_models::llm::LLM, llm::Claude}; 2 | 3 | #[tokio::main] 4 | async fn main() { 5 | let claude = Claude::default().with_model("claude-3-opus-20240229"); 6 | let response = claude.invoke("hola").await.unwrap(); 7 | println!("{}", response); 8 | } 9 | -------------------------------------------------------------------------------- /src/output_parsers/error.rs: -------------------------------------------------------------------------------- 1 | use regex::Error as RegexError; 2 | use thiserror::Error; 3 | 4 | #[derive(Error, Debug)] 5 | pub enum OutputParserError { 6 | #[error("Regex error: {0}")] 7 | RegexError(#[from] RegexError), 8 | 9 | #[error("Parsing error: {0}")] 10 | ParsingError(String), 11 | } 12 | -------------------------------------------------------------------------------- /examples/wolfram_tool.rs: -------------------------------------------------------------------------------- 1 | use langchain_rust::tools::{Tool, Wolfram}; 2 | 3 | #[tokio::main] 4 | async fn main() { 5 | let wolfram = Wolfram::default().with_excludes(&["Plot"]); 6 | let input = "Solve x^2 - 2x + 1 = 0"; 7 | let result = wolfram.call(input).await; 8 | 9 | println!("{}", result.unwrap()); 10 | } 11 | -------------------------------------------------------------------------------- /examples/embedding_openai.rs: -------------------------------------------------------------------------------- 1 | use langchain_rust::embedding::{embedder_trait::Embedder, openai::OpenAiEmbedder}; 2 | 3 | #[tokio::main] 4 | async fn main() { 5 | let openai = OpenAiEmbedder::default(); 6 | 7 | let response =
openai.embed_query("Why is the sky blue?").await.unwrap(); 8 | 9 | println!("{:?}", response); 10 | } 11 | -------------------------------------------------------------------------------- /src/text_splitter/mod.rs: -------------------------------------------------------------------------------- 1 | mod error; 2 | mod markdown_splitter; 3 | mod options; 4 | mod plain_text_splitter; 5 | mod text_splitter; 6 | mod token_splitter; 7 | 8 | pub use error::*; 9 | pub use markdown_splitter::*; 10 | pub use options::*; 11 | pub use plain_text_splitter::*; 12 | pub use text_splitter::*; 13 | pub use token_splitter::*; 14 | -------------------------------------------------------------------------------- /src/semantic_router/index/error.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | #[derive(Error, Debug)] 4 | pub enum IndexError { 5 | #[error("No Embedding on Route: {0}")] 6 | MissingEmbedding(String), 7 | 8 | #[error("Error: {0}")] 9 | OtherError(String), 10 | 11 | #[error("No Route found: {0}")] 12 | RouterNotFound(String), 13 | } 14 | -------------------------------------------------------------------------------- /scripts/run-pgvector: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | docker run -d \ 3 | -e POSTGRES_DB=langchain-rust \ 4 | -e POSTGRES_USER=username \ 5 | -e POSTGRES_PASSWORD=password \ 6 | -e PGDATA=/var/lib/postgresql/data/pgdata \ 7 | -v pgvolume-langchain-rust:/var/lib/postgresql/data \ 8 | -p 5432:5432 \ 9 | --name pgvector \ 10 | pgvector/pgvector:pg16 11 | -------------------------------------------------------------------------------- /src/embedding/embedder_trait.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | 3 | use super::EmbedderError; 4 | 5 | #[async_trait] 6 | pub trait Embedder: Send + Sync { 7 | async fn embed_documents(&self, documents: &[String]) -> Result<Vec<Vec<f64>>, EmbedderError>; 8 | async fn embed_query(&self, text: &str) -> Result<Vec<f64>, EmbedderError>; 9 | } 10 | --------------------------------------------------------------------------------
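A minimal sketch of a custom implementation of the Embedder trait above; the ConstantEmbedder type is illustrative and not part of the crate, and the import paths assume the re-exports declared in src/embedding/mod.rs elsewhere in this dump:

use async_trait::async_trait;
use langchain_rust::embedding::{Embedder, EmbedderError};

// Illustrative embedder: returns a fixed-size zero vector for any input,
// useful as a stand-in when testing vector-store plumbing offline.
struct ConstantEmbedder {
    dimensions: usize,
}

#[async_trait]
impl Embedder for ConstantEmbedder {
    async fn embed_documents(&self, documents: &[String]) -> Result<Vec<Vec<f64>>, EmbedderError> {
        // One zero vector per input document.
        Ok(documents.iter().map(|_| vec![0.0; self.dimensions]).collect())
    }

    async fn embed_query(&self, _text: &str) -> Result<Vec<f64>, EmbedderError> {
        Ok(vec![0.0; self.dimensions])
    }
}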
/src/chain/conversational/prompt.rs: -------------------------------------------------------------------------------- 1 | pub const DEFAULT_TEMPLATE: &str = r#"The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know. 2 | 3 | Current conversation: 4 | {history} 5 | Human: {input} 6 | AI: 7 | "#; 8 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code)] 2 | pub mod agent; 3 | pub mod chain; 4 | pub mod document_loaders; 5 | pub mod embedding; 6 | pub mod language_models; 7 | pub mod llm; 8 | pub mod memory; 9 | pub mod output_parsers; 10 | pub mod prompt; 11 | pub mod schemas; 12 | pub mod semantic_router; 13 | pub mod text_splitter; 14 | pub mod tools; 15 | pub mod vectorstore; 16 | 17 | pub use url; 18 | -------------------------------------------------------------------------------- /src/tools/mod.rs: -------------------------------------------------------------------------------- 1 | mod tool; 2 | pub use tool::*; 3 | 4 | pub use wolfram::*; 5 | mod wolfram; 6 | 7 | mod scraper; 8 | pub use scraper::*; 9 | 10 | mod sql; 11 | pub use sql::*; 12 | 13 | mod duckduckgo; 14 | pub use duckduckgo::*; 15 | 16 | mod serpapi; 17 | pub use serpapi::*; 18 | 19 | mod command_executor; 20 | pub use command_executor::*; 21 | 22 | mod text2speech; 23 | pub use text2speech::*; 24 | -------------------------------------------------------------------------------- /src/prompt/error.rs: -------------------------------------------------------------------------------- 1 | use serde_json::Error as SerdeJsonError; 2 | use thiserror::Error; 3 | 4 | #[derive(Error, Debug)] 5 | pub enum PromptError { 6 | #[error("Variable {0} is missing from input variables")] 7 | MissingVariable(String), 8 | 9 | #[error("Serialization error: {0}")] 10 | SerializationError(#[from] SerdeJsonError), 11 | 12 | #[error("Error: {0}")] 13 | OtherError(String), 14 | } 15 | -------------------------------------------------------------------------------- /src/chain/sql_datbase/mod.rs: -------------------------------------------------------------------------------- 1 | mod builder; 2 | mod chain; 3 | mod prompt; 4 | 5 | pub use builder::*; 6 | pub use chain::*; 7 | pub use prompt::*; 8 | 9 | const STOP_WORD: &str = "\nSQLResult:"; 10 | const SQL_CHAIN_DEFAULT_INPUT_KEY_QUERY: &str = "query"; 11 | const SQL_CHAIN_DEFAULT_INPUT_KEY_TABLE_NAMES: &str = "table_names_to_use"; 12 | const SQL_CHAIN_DEFAULT_OUTPUT_KEY: &str = "result"; 13 | const QUERY_PREFIX_WITH: &str = "\nSQLQuery:"; 14 | -------------------------------------------------------------------------------- /src/output_parsers/output_parser.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | 3 | use super::OutputParserError; 4 | 5 | #[async_trait] 6 | pub trait OutputParser: Send + Sync { 7 | async fn parse(&self, output: &str) -> Result<String, OutputParserError>; 8 | } 9 | 10 | impl<P> From<P> for Box<dyn OutputParser> 11 | where 12 | P: OutputParser + 'static, 13 | { 14 | fn from(parser: P) -> Self { 15 | Box::new(parser) 16 | } 17 | } 18 | --------------------------------------------------------------------------------
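A minimal sketch of a custom parser against the OutputParser trait above; UppercaseParser is illustrative, not part of the crate:

use async_trait::async_trait;
use langchain_rust::output_parsers::{OutputParser, OutputParserError};

// Illustrative parser: trims whitespace and upper-cases the model output.
struct UppercaseParser;

#[async_trait]
impl OutputParser for UppercaseParser {
    async fn parse(&self, output: &str) -> Result<String, OutputParserError> {
        Ok(output.trim().to_uppercase())
    }
}

// The blanket `From` impl above then gives the boxed conversion for free:
// let parser: Box<dyn OutputParser> = UppercaseParser.into();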
/src/document_loaders/test_data/example.rs: -------------------------------------------------------------------------------- 1 | fn main() { 2 | println!("Hello, world!"); 3 | } 4 | 5 | pub struct Person { 6 | name: String, 7 | age: i32, 8 | } 9 | 10 | impl Person { 11 | pub fn new(name: String, age: i32) -> Self { 12 | Self { name, age } 13 | } 14 | 15 | pub fn get_name(&self) -> &str { 16 | &self.name 17 | } 18 | 19 | pub fn get_age(&self) -> i32 { 20 | self.age 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /examples/vector_store_surrealdb/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "vector_store_surrealdb" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | anyhow = "1.0.81" 10 | langchain-rust = { path = "../..", features = ["surrealdb"] } 11 | surrealdb = { version = "2.0.2", features = ["kv-mem"] } 12 | tokio = { version = "1.36.0", features = ["full"] } 13 | -------------------------------------------------------------------------------- /src/schemas/retrievers.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | 3 | use async_trait::async_trait; 4 | 5 | use super::Document; 6 | 7 | #[async_trait] 8 | pub trait Retriever: Sync + Send { 9 | async fn get_relevant_documents(&self, query: &str) -> Result<Vec<Document>, Box<dyn Error>>; 10 | } 11 | 12 | impl<R> From<R> for Box<dyn Retriever> 13 | where 14 | R: Retriever + 'static, 15 | { 16 | fn from(retriever: R) -> Self { 17 | Box::new(retriever) 18 | } 19 | } 20 | --------------------------------------------------------------------------------
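A minimal sketch of a custom implementation of the Retriever trait above; the StaticRetriever type is illustrative, and it assumes the schemas::Document type implements Clone:

use std::error::Error;

use async_trait::async_trait;
use langchain_rust::schemas::{Document, Retriever};

// Illustrative retriever: always returns the same fixed set of documents,
// regardless of the query. Handy for wiring up retrieval QA chains in tests.
struct StaticRetriever {
    docs: Vec<Document>,
}

#[async_trait]
impl Retriever for StaticRetriever {
    async fn get_relevant_documents(&self, _query: &str) -> Result<Vec<Document>, Box<dyn Error>> {
        Ok(self.docs.clone())
    }
}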
/src/vectorstore/mod.rs: -------------------------------------------------------------------------------- 1 | mod options; 2 | 3 | #[cfg(feature = "postgres")] 4 | pub mod pgvector; 5 | 6 | #[cfg(feature = "sqlite-vss")] 7 | pub mod sqlite_vss; 8 | 9 | #[cfg(feature = "sqlite-vec")] 10 | pub mod sqlite_vec; 11 | 12 | #[cfg(feature = "surrealdb")] 13 | pub mod surrealdb; 14 | 15 | #[cfg(feature = "opensearch")] 16 | pub mod opensearch; 17 | 18 | #[cfg(feature = "qdrant")] 19 | pub mod qdrant; 20 | 21 | mod vectorstore; 22 | 23 | pub use options::*; 24 | pub use vectorstore::*; 25 | -------------------------------------------------------------------------------- /src/embedding/mod.rs: -------------------------------------------------------------------------------- 1 | mod error; 2 | 3 | pub mod embedder_trait; 4 | pub use embedder_trait::*; 5 | 6 | #[cfg(feature = "ollama")] 7 | pub mod ollama; 8 | #[cfg(feature = "ollama")] 9 | pub use ollama::*; 10 | 11 | pub mod openai; 12 | pub use error::*; 13 | 14 | #[cfg(feature = "fastembed")] 15 | mod fastembed; 16 | #[cfg(feature = "fastembed")] 17 | pub use fastembed::*; 18 | 19 | #[cfg(feature = "mistralai")] 20 | pub mod mistralai; 21 | #[cfg(feature = "mistralai")] 22 | pub use mistralai::*; 23 | -------------------------------------------------------------------------------- /src/agent/agent.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use async_trait::async_trait; 4 | 5 | use crate::{ 6 |
prompt::PromptArgs, 7 | schemas::agent::{AgentAction, AgentEvent}, 8 | tools::Tool, 9 | }; 10 | 11 | use super::AgentError; 12 | 13 | #[async_trait] 14 | pub trait Agent: Send + Sync { 15 | async fn plan( 16 | &self, 17 | intermediate_steps: &[(AgentAction, String)], 18 | inputs: PromptArgs, 19 | ) -> Result<AgentEvent, AgentError>; 20 | 21 | fn get_tools(&self) -> Vec<Arc<dyn Tool>>; 22 | } 23 | -------------------------------------------------------------------------------- /src/schemas/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod agent; 2 | pub use agent::*; 3 | 4 | pub mod memory; 5 | pub use memory::*; 6 | 7 | pub mod messages; 8 | pub use messages::*; 9 | 10 | pub mod prompt; 11 | pub use prompt::*; 12 | 13 | pub mod document; 14 | pub use document::*; 15 | 16 | mod retrievers; 17 | pub use retrievers::*; 18 | 19 | mod tools_openai_like; 20 | pub use tools_openai_like::*; 21 | 22 | pub mod response_format_openai_like; 23 | pub use response_format_openai_like::*; 24 | 25 | pub mod convert; 26 | mod stream; 27 | 28 | pub use stream::*; 29 | -------------------------------------------------------------------------------- /src/chain/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod chain_trait; 2 | pub use chain_trait::*; 3 | 4 | pub mod conversational; 5 | pub use conversational::*; 6 | 7 | pub use llm_chain::*; 8 | pub mod llm_chain; 9 | 10 | mod sequential; 11 | pub use sequential::*; 12 | 13 | pub mod sql_datbase; 14 | pub use sql_datbase::*; 15 | 16 | mod stuff_documents; 17 | pub use stuff_documents::*; 18 | 19 | mod question_answering; 20 | pub use question_answering::*; 21 | 22 | mod conversational_retrieval_qa; 23 | pub use conversational_retrieval_qa::*; 24 | 25 | mod error; 26 | pub use error::*; 27 | 28 | pub mod options; 29 | -------------------------------------------------------------------------------- /renovate.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://docs.renovatebot.com/renovate-schema.json", 3 | "extends": [ 4 | "config:recommended" 5 | ], 6 | "stabilityDays": "3", 7 | "timezone": "America/Los_Angeles", 8 | "schedule": "after 1am every 3 weeks on Saturday", 9 | "packageRules": [ 10 | { 11 | "matchPackagePatterns": ["tokio", "tokio-test"], 12 | "groupName": "tokio" 13 | }, 14 | { 15 | "matchPackagePatterns": ["pgvector", "sqlx"], 16 | "groupName": "sqlx" 17 | }, 18 | { 19 | "matchPackagePatterns": ["tree-sitter.*"], 20 | "groupName": "tree-sitter" 21 | } 22 | ] 23 | } 24 | -------------------------------------------------------------------------------- /examples/llm_azure_open_ai.rs: -------------------------------------------------------------------------------- 1 | use langchain_rust::{ 2 | language_models::llm::LLM, 3 | llm::openai::{AzureConfig, OpenAI}, 4 | }; 5 | 6 | #[tokio::main] 7 | async fn main() { 8 | let azure_config = AzureConfig::default() 9 | .with_api_key("REPLACE_ME_WITH_YOUR_API_KEY") 10 | .with_api_base("https://REPLACE_ME.openai.azure.com") 11 | .with_api_version("2024-02-15-preview") 12 | .with_deployment_id("chatGPT_GPT35-turbo-0301"); 13 | 14 | let open_ai = OpenAI::new(azure_config); 15 | let response = open_ai.invoke("Why is the sky blue?").await.unwrap(); 16 | println!("{}", response); 17 | } 18 | -------------------------------------------------------------------------------- /examples/llm_ollama.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature =
"ollama")] 2 | use langchain_rust::{language_models::llm::LLM, llm::ollama::client::Ollama}; 3 | 4 | #[cfg(feature = "ollama")] 5 | #[tokio::main] 6 | async fn main() { 7 | let ollama = Ollama::default().with_model("llama3.2"); 8 | 9 | let response = ollama.invoke("Hi").await.unwrap(); 10 | println!("{}", response); 11 | } 12 | 13 | #[cfg(not(feature = "ollama"))] 14 | fn main() { 15 | println!("This example requires the 'ollama' feature to be enabled."); 16 | println!("Please run the command as follows:"); 17 | println!("cargo run --example llm_ollama --features=ollama"); 18 | } 19 | -------------------------------------------------------------------------------- /examples/embedding_azure_open_ai.rs: -------------------------------------------------------------------------------- 1 | use langchain_rust::embedding::{ 2 | embedder_trait::Embedder, 3 | openai::openai_embedder::{AzureConfig, OpenAiEmbedder}, 4 | }; 5 | 6 | #[tokio::main] 7 | async fn main() { 8 | let azure_config = AzureConfig::default() 9 | .with_api_key("REPLACE_ME_WITH_YOUR_API_KEY") 10 | .with_api_base("https://REPLACE_ME.openai.azure.com") 11 | .with_api_version("2023-05-15") 12 | .with_deployment_id("text-embedding-ada-002"); 13 | 14 | let embedder = OpenAiEmbedder::new(azure_config); 15 | let result = embedder.embed_query("Why is the sky blue?").await.unwrap(); 16 | println!("{:?}", result); 17 | } 18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /examples/llm_alibaba_qwen.rs: -------------------------------------------------------------------------------- 1 | use langchain_rust::language_models::llm::LLM; 2 | use langchain_rust::llm::Qwen; 3 | use langchain_rust::schemas::Message; 4 | 5 | #[tokio::main] 6 | async fn main() { 7 | // Initialize the Qwen client 8 | // Requires QWEN_API_KEY environment variable to be set 9 | let qwen = Qwen::new() 10 | .with_api_key("your_api_key") 11 | .with_model("qwen-turbo"); // Can use enum: QwenModel::QwenTurbo.to_string() 12 | 13 | // Generate a response 14 | let response = qwen 15 | .generate(&[Message::new_human_message("Introduce the Great Wall")]) 16 | .await 17 | .unwrap(); 18 | 19 | println!("Response: {}", response.generation); 20 | } 21 | -------------------------------------------------------------------------------- /examples/embedding_mistralai.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "mistralai")] 2 | use langchain_rust::embedding::{embedder_trait::Embedder, mistralai::MistralAIEmbedder}; 3 | 4 | #[cfg(feature = "mistralai")] 5 | #[tokio::main] 6 | async fn main() { 7 | let mistralai = MistralAIEmbedder::try_new().unwrap(); 8 | 9 | let embedding = mistralai.embed_query("Why is the sky blue?").await.unwrap(); 10 | 11 | println!("{:?}", embedding); 12 | } 13 | 14 | #[cfg(not(feature = "mistralai"))] 15 | fn main() { 16 | println!("This example requires the 'mistralai' feature to be enabled."); 17 | println!("Please run the command as follows:"); 18 | println!("cargo run --example embedding_mistralai --features=mistralai"); 19 | } 20 | -------------------------------------------------------------------------------- /examples/embedding_ollama.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "ollama")] 2 | use langchain_rust::embedding::{ 3 | embedder_trait::Embedder, ollama::ollama_embedder::OllamaEmbedder, 4 | }; 5 | 6 | #[cfg(feature = "ollama")] 7 | #[tokio::main] 8 | async fn main() { 9 | let ollama = OllamaEmbedder::default().with_model("nomic-embed-text"); 10 | 11 | let response = ollama.embed_query("Why is the sky blue?").await.unwrap(); 12 | 13 | println!("{:?}", response); 14 | } 15 | 16 | #[cfg(not(feature = "ollama"))] 17 | fn main() { 18 | println!("This example requires the 'ollama' feature to be enabled."); 19 | println!("Please run the command as follows:"); 20 | println!("cargo run --example embedding_ollama --features=ollama"); 21 | } 22 | -------------------------------------------------------------------------------- /examples/llm_deepseek.rs: -------------------------------------------------------------------------------- 1 | use langchain_rust::language_models::llm::LLM; 2 | use langchain_rust::llm::Deepseek; 3 | use langchain_rust::schemas::Message; 4 | 5 | #[tokio::main] 6 | async fn main() { 7 | // Initialize the Deepseek client 8 | // Requires DEEPSEEK_API_KEY environment variable to be set 9 | let deepseek = Deepseek::new() 10 | .with_api_key("your_api_key") 11 | .with_model("deepseek-chat"); // Can use enum: DeepseekModel::DeepseekChat.to_string() 12 | 13 | // Generate a response 14 | let response = deepseek 15 | .generate(&[Message::new_human_message("Introduce the Great Wall")]) 16 | .await 17 | .unwrap(); 18 | 19 | println!("Response: {}", response.generation); 20 | } 21 | -------------------------------------------------------------------------------- 
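The comment in the example above notes that DEEPSEEK_API_KEY can come from the environment; a minimal sketch of the same call that reads the key from the environment instead of hardcoding it (assumes with_api_key accepts any Into<String>, as the builder style above suggests):

use langchain_rust::{language_models::llm::LLM, llm::Deepseek};

#[tokio::main]
async fn main() {
    // Read the key from the environment rather than hardcoding it in source.
    let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY must be set");

    let deepseek = Deepseek::new().with_api_key(api_key);

    let response = deepseek.invoke("Introduce the Great Wall").await.unwrap();
    println!("{}", response);
}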
/examples/llm_openai.rs: -------------------------------------------------------------------------------- 1 | use langchain_rust::llm::OpenAIConfig; 2 | 3 | use langchain_rust::{language_models::llm::LLM, llm::openai::OpenAI}; 4 | 5 | #[tokio::main] 6 | async fn main() { 7 | // OpenAI example 8 | let open_ai = OpenAI::default(); 9 | let response = open_ai.invoke("hola").await.unwrap(); 10 | println!("{}", response); 11 | 12 | // or we can set the config explicitly 13 | let open_ai = OpenAI::default().with_config( 14 | OpenAIConfig::default() 15 | .with_api_base("xxx") // if you want to specify a base URL 16 | .with_api_key(""), // if you want to set your OpenAI key 17 | ); 18 | 19 | let response = open_ai.invoke("hola").await.unwrap(); 20 | println!("{}", response); 21 | } 22 | -------------------------------------------------------------------------------- /src/schemas/stream.rs: -------------------------------------------------------------------------------- 1 | use serde_json::Value; 2 | use std::io::{self, Write}; 3 | 4 | use crate::language_models::TokenUsage; 5 | 6 | #[derive(Debug, Clone)] 7 | pub struct StreamData { 8 | pub value: Value, 9 | pub tokens: Option<TokenUsage>, 10 | pub content: String, 11 | } 12 | 13 | impl StreamData { 14 | pub fn new<S: Into<String>>(value: Value, tokens: Option<TokenUsage>, content: S) -> Self { 15 | Self { 16 | value, 17 | tokens, 18 | content: content.into(), 19 | } 20 | } 21 | 22 | pub fn to_stdout(&self) -> io::Result<()> { 23 | let stdout = io::stdout(); 24 | let mut handle = stdout.lock(); 25 | write!(handle, "{}", self.content)?; 26 | handle.flush() 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /examples/speech2text_openai.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | use langchain_rust::tools::{SpeechStorage, Text2SpeechOpenAI, Tool}; 3 | 4 | #[allow(dead_code)] 5 | struct XStorage {} 6 | 7 | // You can save the result to S3 or other storage by implementing SpeechStorage; a file-backed sketch follows the CSV test data below. 8 | 9 | #[async_trait] 10 | impl SpeechStorage for XStorage { 11 | async fn save(&self, path: &str, _data: &[u8]) -> Result<String, Box<dyn std::error::Error>> { 12 | println!("Saving to: {}", path); 13 | Ok(path.to_string()) 14 | } 15 | } 16 | 17 | #[tokio::main] 18 | async fn main() { 19 | let openai = Text2SpeechOpenAI::default().with_path("./data/audio.mp3"); 20 | // .with_storage(XStorage {}); 21 | 22 | let path = openai.call("Hi, My name is Luis").await.unwrap(); 23 | println!("Path: {}", path); 24 | } 25 | -------------------------------------------------------------------------------- /src/document_loaders/test_data/test.csv: -------------------------------------------------------------------------------- 1 | name,age,city,country 2 | John Doe,25,New York,United States 3 | Jane Smith,32,London,United Kingdom 4 | Alex Johnson,42,Sydney,Australia 5 | Emma Davis,29,Paris,France 6 | Michael Lee,37,Toronto,Canada 7 | Sophia Wilson,22,Berlin,Germany 8 | Matthew Brown,31,Tokyo,Japan 9 | Olivia Taylor,27,Rome,Italy 10 | David Anderson,35,Moscow,Russia 11 | Emily Thomas,30,Madrid,Spain 12 | Daniel Martinez,26,Mexico City,Mexico 13 | Isabella Lopez,28,Seoul,South Korea 14 | Christopher Harris,33,Cairo,Egypt 15 | Ava Clark,24,Sao Paulo,Brazil 16 | James Wright,39,Amsterdam,Netherlands 17 | Mia Hall,23,Stockholm,Sweden 18 | Andrew Allen,34,Beijing,China 19 | Abigail Baker,36,New Delhi,India 20 | Ryan Scott,41,Johannesburg,South Africa 21 | Grace Green,38,Zurich,Switzerland 22 | --------------------------------------------------------------------------------
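As referenced in the speech2text example above, a minimal sketch of a file-backed SpeechStorage implementation; the FileStorage type and its directory layout are illustrative, not part of the crate:

use async_trait::async_trait;
use langchain_rust::tools::SpeechStorage;

// Illustrative file-backed storage: writes the synthesized audio bytes
// under a base directory and returns the resulting path.
struct FileStorage {
    base_dir: std::path::PathBuf,
}

#[async_trait]
impl SpeechStorage for FileStorage {
    async fn save(&self, key: &str, data: &[u8]) -> Result<String, Box<dyn std::error::Error>> {
        let path = self.base_dir.join(key);
        tokio::fs::write(&path, data).await?;
        Ok(path.to_string_lossy().into_owned())
    }
}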
/src/llm/claude/error.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | #[derive(Error, Debug)] 4 | pub enum AnthropicError { 5 | #[error("Anthropic API error: Invalid request - {0}")] 6 | InvalidRequestError(String), 7 | 8 | #[error("Anthropic API error: Authentication failed - {0}")] 9 | AuthenticationError(String), 10 | 11 | #[error("Anthropic API error: Permission denied - {0}")] 12 | PermissionError(String), 13 | 14 | #[error("Anthropic API error: Not found - {0}")] 15 | NotFoundError(String), 16 | 17 | #[error("Anthropic API error: Rate limit exceeded - {0}")] 18 | RateLimitError(String), 19 | 20 | #[error("Anthropic API error: Internal error - {0}")] 21 | ApiError(String), 22 | 23 | #[error("Anthropic API error: Overloaded - {0}")] 24 | OverloadedError(String), 25 | } 26 | -------------------------------------------------------------------------------- /src/memory/dummy_memory.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use tokio::sync::Mutex; 4 | 5 | use crate::schemas::{memory::BaseMemory, messages::Message}; 6 | 7 | pub struct DummyMemory {} 8 | 9 | impl DummyMemory { 10 | pub fn new() -> Self { 11 | Self {} 12 | } 13 | } 14 | 15 | impl Into<Arc<dyn BaseMemory>> for DummyMemory { 16 | fn into(self) -> Arc<dyn BaseMemory> { 17 | Arc::new(self) 18 | } 19 | } 20 | 21 | impl Into<Arc<Mutex<dyn BaseMemory>>> for DummyMemory { 22 | fn into(self) -> Arc<Mutex<dyn BaseMemory>> { 23 | Arc::new(Mutex::new(self)) 24 | } 25 | } 26 | 27 | impl BaseMemory for DummyMemory { 28 | fn messages(&self) -> Vec<Message> { 29 | vec![] 30 | } 31 | fn add_message(&mut self, _message: Message) {} 32 | fn clear(&mut self) {} 33 | } 34 | -------------------------------------------------------------------------------- /src/output_parsers/simple_parser.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | 3 | use super::{OutputParser, OutputParserError}; 4 | 5 | pub struct SimpleParser { 6 | trim: bool, 7 | } 8 | impl SimpleParser { 9 | pub fn new() -> Self { 10 | Self { trim: false } 11 | } 12 | pub fn with_trim(mut self, trim: bool) -> Self { 13 | self.trim = trim; 14 | self 15 | } 16 | } 17 | impl Default for SimpleParser { 18 | fn default() -> Self { 19 | Self::new() 20 | } 21 | } 22 | 23 | #[async_trait] 24 | impl OutputParser for SimpleParser { 25 | async fn parse(&self, output: &str) -> Result<String, OutputParserError> { 26 | if self.trim { 27 | Ok(output.trim().to_string()) 28 | } else { 29 | Ok(output.to_string()) 30 | } 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /src/agent/error.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | use crate::{chain::ChainError, language_models::LLMError, prompt::PromptError}; 4 | 5 | #[derive(Error, Debug)] 6 | pub enum AgentError { 7 | #[error("LLM error: {0}")] 8 | LLMError(#[from] LLMError), 9 | 10 | #[error("Chain error: {0}")] 11 | ChainError(#[from] ChainError), 12 | 13 | #[error("Prompt error: {0}")] 14 | PromptError(#[from] PromptError), 15 | 16 | #[error("Tool error: {0}")] 17 | ToolError(String), 18 | 19 | #[error("Missing Object On Builder: {0}")] 20 | MissingObject(String), 21 | 22 | #[error("Missing input variable: {0}")] 23 | MissingInputVariable(String), 24 | 25 | #[error("Serde json error: {0}")] 26 | SerdeJsonError(#[from] serde_json::Error), 27 | 28 | #[error("Error: {0}")] 29 | OtherError(String), 30 | } 31 |
-------------------------------------------------------------------------------- /src/schemas/prompt.rs: -------------------------------------------------------------------------------- 1 | use super::messages::Message; 2 | 3 | #[derive(Debug, Clone)] 4 | pub struct PromptValue { 5 | messages: Vec<Message>, 6 | } 7 | impl PromptValue { 8 | pub fn from_string(text: &str) -> Self { 9 | let message = Message::new_human_message(text); 10 | Self { 11 | messages: vec![message], 12 | } 13 | } 14 | pub fn from_messages(messages: Vec<Message>) -> Self { 15 | Self { messages } 16 | } 17 | 18 | pub fn to_string(&self) -> String { 19 | self.messages 20 | .iter() 21 | .map(|m| format!("{}: {}", m.message_type.to_string(), m.content)) 22 | .collect::<Vec<String>>() 23 | .join("\n") 24 | } 25 | 26 | pub fn to_chat_messages(&self) -> Vec<Message> { 27 | self.messages.clone() 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/llm/deepseek/error.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | #[derive(Error, Debug)] 4 | pub enum DeepseekError { 5 | #[error("Deepseek API error: Invalid Format - {0}")] 6 | InvalidFormatError(String), 7 | 8 | #[error("Deepseek API error: Authentication Failed - {0}")] 9 | AuthenticationError(String), 10 | 11 | #[error("Deepseek API error: Insufficient Balance - {0}")] 12 | InsufficientBalanceError(String), 13 | 14 | #[error("Deepseek API error: Invalid Parameters - {0}")] 15 | InvalidParametersError(String), 16 | 17 | #[error("Deepseek API error: Rate Limit Reached - {0}")] 18 | RateLimitError(String), 19 | 20 | #[error("Deepseek API error: Server Error - {0}")] 21 | ServerError(String), 22 | 23 | #[error("Deepseek API error: Server Overloaded - {0}")] 24 | ServerOverloadedError(String), 25 | } 26 | -------------------------------------------------------------------------------- /src/text_splitter/error.rs: -------------------------------------------------------------------------------- 1 | use text_splitter::ChunkConfigError; 2 | use thiserror::Error; 3 | 4 | #[derive(Error, Debug)] 5 | pub enum TextSplitterError { 6 | #[error("Empty input text")] 7 | EmptyInputText, 8 | 9 | #[error("Mismatch metadata and text")] 10 | MetadataTextMismatch, 11 | 12 | #[error("Tokenizer not found")] 13 | TokenizerNotFound, 14 | 15 | #[error("Tokenizer creation failed due to invalid tokenizer")] 16 | InvalidTokenizer, 17 | 18 | #[error("Tokenizer creation failed due to invalid model")] 19 | InvalidModel, 20 | 21 | #[error("Invalid chunk overlap and size")] 22 | InvalidSplitterOptions, 23 | 24 | #[error("Error: {0}")] 25 | OtherError(String), 26 | } 27 | 28 | impl From<ChunkConfigError> for TextSplitterError { 29 | fn from(_: ChunkConfigError) -> Self { 30 | Self::InvalidSplitterOptions 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /src/semantic_router/index/index.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | 3 | use crate::semantic_router::{IndexError, Router}; 4 | 5 | #[async_trait] 6 | pub trait Index { 7 | async fn add(&mut self, router: &[Router]) -> Result<(), IndexError>; 8 | 9 | async fn delete(&mut self, route_name: &str) -> Result<(), IndexError>; 10 | 11 | /// Query the index with a vector and return the top_k most similar routes. 12 | /// Returns a list of tuples with the route name and the similarity score, 13 | /// i.e. Result<Vec<(String, f64)>>. 14 | async fn query(&self, vector: &[f64], top_k: usize) -> Result<Vec<(String, f64)>, IndexError>; 15 | 16 | async fn get_routers(&self) -> Result<Vec<Router>, IndexError>; 17 | 18 | async fn get_router(&self, route_name: &str) -> Result<Router, IndexError>; 19 | 20 | async fn delete_index(&mut self) -> Result<(), IndexError>; 21 | } 22 | -------------------------------------------------------------------------------- /src/vectorstore/pgvector/mod.rs: -------------------------------------------------------------------------------- 1 | mod builder; 2 | mod pgvector; 3 | 4 | pub use builder::*; 5 | pub use pgvector::*; 6 | 7 | // PG_LOCK_ID_EMBEDDING_TABLE is used for an advisory lock to fix an issue arising from concurrent 8 | // creation of the embedding table. The same value represents the same lock. 9 | const PG_LOCK_ID_EMBEDDING_TABLE: i64 = 1573678846307946494; 10 | // PG_LOCK_ID_COLLECTION_TABLE is used for an advisory lock to fix an issue arising from concurrent 11 | // creation of the collection table. The same value represents the same lock. 12 | const PG_LOCK_ID_COLLECTION_TABLE: i64 = 1573678846307946495; 13 | // PG_LOCKID_EXTENSION is used for an advisory lock to fix an issue arising from concurrent creation 14 | // of the vector extension. The value is deliberately set to the same as python langchain 15 | // https://github.com/langchain-ai/langchain/blob/v0.0.340/libs/langchain/langchain/vectorstores/pgvector.py#L167 16 | const PG_LOCKID_EXTENSION: i64 = 1573678846307946496; 17 | --------------------------------------------------------------------------------
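The advisory-lock comments above describe the technique; a minimal sketch, assuming sqlx with the postgres feature, of how such a transaction-scoped lock is typically taken (illustrative only — the crate's builder performs the equivalent internally):

use sqlx::PgPool;

// Illustrative: serialize concurrent `CREATE EXTENSION vector` calls with a
// transaction-scoped advisory lock, using the PG_LOCKID_EXTENSION value above.
async fn create_vector_extension(pool: &PgPool) -> Result<(), sqlx::Error> {
    let mut tx = pool.begin().await?;
    // pg_advisory_xact_lock blocks until the lock is free and releases it
    // automatically when the transaction commits or rolls back.
    sqlx::query("SELECT pg_advisory_xact_lock($1)")
        .bind(1573678846307946496_i64)
        .execute(&mut *tx)
        .await?;
    sqlx::query("CREATE EXTENSION IF NOT EXISTS vector")
        .execute(&mut *tx)
        .await?;
    tx.commit().await?;
    Ok(())
}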
/src/document_loaders/mod.rs: -------------------------------------------------------------------------------- 1 | mod document_loader; 2 | pub use document_loader::*; 3 | 4 | mod text_loader; 5 | pub use text_loader::*; 6 | 7 | mod csv_loader; 8 | pub use csv_loader::*; 9 | 10 | #[cfg(feature = "git")] 11 | mod git_commit_loader; 12 | #[cfg(feature = "git")] 13 | pub use git_commit_loader::*; 14 | 15 | mod pandoc_loader; 16 | pub use pandoc_loader::*; 17 | 18 | #[cfg(any(feature = "lopdf", feature = "pdf-extract"))] 19 | mod pdf_loader; 20 | #[cfg(any(feature = "lopdf", feature = "pdf-extract"))] 21 | pub use pdf_loader::*; 22 | 23 | mod html_loader; 24 | pub use html_loader::*; 25 | 26 | #[cfg(feature = "html-to-markdown")] 27 | mod html_to_markdown_loader; 28 | #[cfg(feature = "html-to-markdown")] 29 | pub use html_to_markdown_loader::*; 30 | 31 | mod error; 32 | pub use error::*; 33 | 34 | mod dir_loader; 35 | pub use dir_loader::*; 36 | 37 | #[cfg(feature = "tree-sitter")] 38 | mod source_code_loader; 39 | #[cfg(feature = "tree-sitter")] 40 | pub use source_code_loader::*; 41 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 39 | -------------------------------------------------------------------------------- /src/memory/simple_memory.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use tokio::sync::Mutex; 4 | 5 | use crate::schemas::{memory::BaseMemory, messages::Message}; 6 | 7 | pub struct SimpleMemory { 8 | messages: Vec<Message>, 9 | } 10 | 11 | impl SimpleMemory { 12 | pub fn new() -> Self { 13 | Self { 14 | messages: Vec::new(), 15 | } 16 | } 17 | } 18 | 19 | impl Into<Arc<dyn BaseMemory>> for SimpleMemory { 20 | fn into(self) -> Arc<dyn BaseMemory> { 21 | Arc::new(self) 22 | } 23 | } 24 | 25 | impl Into<Arc<Mutex<dyn BaseMemory>>> for SimpleMemory { 26 | fn into(self) -> Arc<Mutex<dyn BaseMemory>> { 27 | Arc::new(Mutex::new(self)) 28 | } 29 | } 30 | 31 | impl BaseMemory for SimpleMemory { 32 | fn messages(&self) -> Vec<Message> { 33 | self.messages.clone() 34 | } 35 | fn add_message(&mut self, message: Message) { 36 | self.messages.push(message); 37 | } 38 | fn clear(&mut self) { 39 | self.messages.clear(); 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /examples/qa_chain.rs: -------------------------------------------------------------------------------- 1 | use langchain_rust::{ 2 | chain::{Chain, StuffDocumentBuilder}, 3 | llm::openai::OpenAI, 4 | prompt_args, 5 | schemas::Document, 6 | }; 7 | 8 | #[tokio::main] 9 | async fn main() { 10 | let llm = OpenAI::default(); 11 | 12 | let chain = StuffDocumentBuilder::new() 13 | .llm(llm) 14 | // .prompt() you can add a custom prompt if you want 15 | .build() 16 | .unwrap(); 17 | let input = prompt_args! { 18 | "input_documents"=>vec![ 19 | Document::new(format!( 20 | "\nQuestion: {}\nAnswer: {}\n", 21 | "Which is the favorite text editor of luis", "Nvim" 22 | )), 23 | Document::new(format!( 24 | "\nQuestion: {}\nAnswer: {}\n", 25 | "How old is Luis", "24" 26 | )), 27 | ], 28 | "question"=>"How old is luis and whats his favorite text editor" 29 | }; 30 | 31 | let output = chain.invoke(input).await.unwrap(); 32 | 33 | println!("{}", output); 34 | } 35 | -------------------------------------------------------------------------------- /src/document_loaders/test_data/example.html: -------------------------------------------------------------------------------- 1 | <html> 2 | <head> 3 | <title>Chew dad's slippers</title> 4 | </head> 5 | <body> 6 | <h1> 7 | Instead of drinking water from the cat bowl, make sure to steal water from 8 | the toilet 9 | </h1> 10 | <h2>Chase the red dot</h2> 11 | <p> 12 | Munch, munch, chomp, chomp hate dogs. Spill litter box, scratch at owner, 13 | destroy all furniture, especially couch get scared by sudden appearance of 14 | cucumber cat is love, cat is life fat baby cat best buddy little guy for 15 | catch eat throw up catch eat throw up bad birds jump on fridge. Purr like 16 | a car engine oh yes, there is my human woman she does best pats ever that 17 | all i like about her hiss meow . 18 | </p> 19 | <p> 20 | Dead stare with ears cocked when owners are asleep, cry for no apparent 21 | reason meow all night. Plop down in the middle where everybody walks favor 22 | packaging over toy. Sit on the laptop kitty pounce, trip, faceplant. 23 | </p> 24 | </body> 25 | </html> 26 | -------------------------------------------------------------------------------- /src/schemas/agent.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use serde::{Deserialize, Serialize}; 4 | use tokio::sync::mpsc; 5 | 6 | pub enum ToolInput { 7 | //Will implement this in the future 8 | StrInput(String), 9 | DictInput(HashMap<String, String>), 10 | } 11 | 12 | #[derive(Clone, Debug, Deserialize, Serialize)] 13 | pub struct AgentAction { 14 | pub tool: String, 15 | pub tool_input: String, //this should be ToolInput in the future 16 | pub log: String, 17 | } 18 | 19 | /// LogTools is a struct used by the openai-like agents 20 | #[derive(Clone, Debug, Deserialize, Serialize)] 21 | pub struct LogTools { 22 | pub tool_id: String, 23 | pub tools: String, 24 | } 25 | 26 | #[derive(Clone, Debug, Deserialize, Serialize)] 27 | pub struct AgentFinish { 28 | pub output: String, 29 | } 30 | 31 | #[derive(Debug)] 32 | pub enum AgentEvent { 33 | Action(Vec<AgentAction>), 34 | Finish(AgentFinish), 35 | } 36 | 37 | pub enum AgentPlan { 38 | Text(AgentEvent), 39 | Stream(mpsc::Receiver>), 40 | } 41 | -------------------------------------------------------------------------------- /src/chain/sql_datbase/prompt.rs: -------------------------------------------------------------------------------- 1 | pub const DEFAULT_SQLTEMPLATE: &str = r#"Given an input question, first create a syntactically correct {{dialect}} query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {{top_k}} results. You can order the results by a relevant column to return the most interesting examples in the database. 2 | 3 | Never query for all the columns from a specific table, only ask for the few relevant columns given the question. 4 | 5 | Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. 6 | 7 | Use the following format: 8 | 9 | Question: Question here 10 | SQLQuery: SQL Query to run 11 | SQLResult: Result of the SQLQuery 12 | Answer: Final answer here 13 | 14 | "#; 15 | 16 | pub const DEFAULT_SQLSUFFIX: &str = r#"Only use the following tables: 17 | {{table_info}} 18 | 19 | Question: {{input}}"#; 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) Ryo Kanazawa 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /examples/semantic_routes.rs: -------------------------------------------------------------------------------- 1 | use langchain_rust::{ 2 | embedding::openai::OpenAiEmbedder, 3 | semantic_router::{AggregationMethod, RouteLayerBuilder, Router}, 4 | }; 5 | 6 | #[tokio::main] 7 | async fn main() { 8 | let capital_route = Router::new( 9 | "capital", 10 | &[ 11 | "Capital of France is Paris.", 12 | "What is the capital of France?", 13 | ], 14 | ); 15 | let weather_route = Router::new( 16 | "temperature", 17 | &[ 18 | "What is the temperature?", 19 | "Is it raining?", 20 | "Is it cloudy?", 21 | ], 22 | ); 23 | let router_layer = RouteLayerBuilder::default() 24 | .embedder(OpenAiEmbedder::default()) 25 | .add_route(capital_route) 26 | .add_route(weather_route) 27 | .aggregation_method(AggregationMethod::Sum) 28 | .threshold(0.82) 29 | .build() 30 | .await 31 | .unwrap(); 32 | 33 | let routes = router_layer 34 | .call("What is the temperature in Peru?") 35 | .await 36 | .unwrap(); 37 | 38 | println!("{:?}", routes); 39 | } 40 | -------------------------------------------------------------------------------- /examples/agent.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use langchain_rust::{ 4 | agent::{AgentExecutor, ConversationalAgentBuilder}, 5 | chain::{options::ChainCallOptions, Chain}, 6 | llm::openai::{OpenAI, OpenAIModel}, 7 | memory::SimpleMemory, 8 | prompt_args, 9 | tools::CommandExecutor, 10 | }; 11 | 12 | #[tokio::main] 13 | async fn main() { 14 | let llm = OpenAI::default().with_model(OpenAIModel::Gpt4Turbo); 15 | let memory = SimpleMemory::new(); 16 | let command_executor = CommandExecutor::default(); 17 | let agent = ConversationalAgentBuilder::new() 18 | .tools(&[Arc::new(command_executor)]) 19 | .options(ChainCallOptions::new().with_max_tokens(1000)) 20 | .build(llm) 21 | .unwrap(); 22 | 23 | let executor = AgentExecutor::from_agent(agent).with_memory(memory.into()); 24 | 25 | let input_variables = prompt_args! { 26 | "input" => "What is the name of the current dir", 27 | }; 28 | 29 | match executor.invoke(input_variables).await { 30 | Ok(result) => { 31 | println!("Result: {:?}", result); 32 | } 33 | Err(e) => panic!("Error invoking LLMChain: {:?}", e), 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/agent/open_ai_tools/prompt.rs: -------------------------------------------------------------------------------- 1 | pub const PREFIX: &str = r#" 2 | 3 | Assistant is designed to be able to assist with a wide range of tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. As a language model, Assistant is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand. 4 | 5 | Assistant is constantly learning and improving, and its capabilities are constantly evolving. 
It is able to process and understand large amounts of text, and can use this knowledge to provide accurate and informative responses to a wide range of questions. Additionally, Assistant is able to generate its own text based on the input it receives, allowing it to engage in discussions and provide explanations and descriptions on a wide range of topics. 6 | 7 | Overall, Assistant is a powerful system that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics. Whether you need help with a specific question or just want to have a conversation about a particular topic, Assistant is here to assist."#; 8 | -------------------------------------------------------------------------------- /examples/simple_chain.rs: -------------------------------------------------------------------------------- 1 | use langchain_rust::{ 2 | chain::{Chain, LLMChainBuilder}, 3 | llm::openai::{OpenAI, OpenAIModel}, 4 | prompt::HumanMessagePromptTemplate, 5 | prompt_args, template_jinja2, 6 | }; 7 | use std::io::{self, Write}; // Include io Library for terminal input 8 | 9 | #[tokio::main] 10 | async fn main() { 11 | let prompt = HumanMessagePromptTemplate::new(template_jinja2!( 12 | "Give me a creative name for a store that sells: {{producto}}", 13 | "producto" 14 | )); 15 | 16 | let llm = OpenAI::default().with_model(OpenAIModel::Gpt35); 17 | let chain = LLMChainBuilder::new() 18 | .prompt(prompt) 19 | .llm(llm) 20 | .build() 21 | .unwrap(); 22 | 23 | print!("Please enter a product: "); 24 | io::stdout().flush().unwrap(); // Display prompt to terminal 25 | 26 | let mut product = String::new(); 27 | io::stdin().read_line(&mut product).unwrap(); // Get product from terminal input 28 | 29 | let product = product.trim(); 30 | 31 | let output = chain 32 | .invoke(prompt_args!["producto" => product]) // Use product input here 33 | .await 34 | .unwrap(); 35 | 36 | println!("Output: {}", output); 37 | } 38 | -------------------------------------------------------------------------------- /src/chain/error.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | use crate::{language_models::LLMError, output_parsers::OutputParserError, prompt::PromptError}; 4 | 5 | #[derive(Error, Debug)] 6 | pub enum ChainError { 7 | #[error("LLM error: {0}")] 8 | LLMError(#[from] LLMError), 9 | 10 | #[error("Retriever error: {0}")] 11 | RetrieverError(String), 12 | 13 | #[error("OutputParser error: {0}")] 14 | OutputParser(#[from] OutputParserError), 15 | 16 | #[error("Prompt error: {0}")] 17 | PromptError(#[from] PromptError), 18 | 19 | #[error("Missing Object On Builder: {0}")] 20 | MissingObject(String), 21 | 22 | #[error("Missing input variable: {0}")] 23 | MissingInputVariable(String), 24 | 25 | #[error("Serde json error: {0}")] 26 | SerdeJsonError(#[from] serde_json::Error), 27 | 28 | #[error("Incorrect input variable: expected type {expected_type}, {source}")] 29 | IncorrectInputVariable { 30 | source: serde_json::Error, 31 | expected_type: String, 32 | }, 33 | 34 | #[error("Error: {0}")] 35 | OtherError(String), 36 | 37 | #[error("Database error: {0}")] 38 | DatabaseError(String), 39 | 40 | #[error("Agent error: {0}")] 41 | AgentError(String), 42 | } 43 | -------------------------------------------------------------------------------- /src/schemas/response_format_openai_like.rs: -------------------------------------------------------------------------------- 1 | use crate::schemas::convert::OpenAIFromLangchain; 2 | 3 
| #[derive(Clone, Debug)] 4 | pub enum ResponseFormat { 5 | Text, 6 | JsonObject, 7 | JsonSchema { 8 | description: Option<String>, 9 | name: String, 10 | schema: Option<serde_json::Value>, 11 | strict: Option<bool>, 12 | }, 13 | } 14 | 15 | impl OpenAIFromLangchain<ResponseFormat> for async_openai::types::ResponseFormat { 16 | fn from_langchain(langchain: ResponseFormat) -> Self { 17 | match langchain { 18 | ResponseFormat::Text => async_openai::types::ResponseFormat::Text, 19 | ResponseFormat::JsonObject => async_openai::types::ResponseFormat::JsonObject, 20 | ResponseFormat::JsonSchema { 21 | name, 22 | description, 23 | schema, 24 | strict, 25 | } => async_openai::types::ResponseFormat::JsonSchema { 26 | json_schema: async_openai::types::ResponseFormatJsonSchema { 27 | name, 28 | description, 29 | schema, 30 | strict, 31 | }, 32 | }, 33 | } 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/embedding/error.rs: -------------------------------------------------------------------------------- 1 | use async_openai::error::OpenAIError; 2 | #[cfg(feature = "mistralai")] 3 | use mistralai_client::v1::error::{ApiError, ClientError}; 4 | #[cfg(feature = "ollama")] 5 | use ollama_rs::error::OllamaError; 6 | use reqwest::{Error as ReqwestError, StatusCode}; 7 | use thiserror::Error; 8 | 9 | #[derive(Error, Debug)] 10 | pub enum EmbedderError { 11 | #[error("Network request failed: {0}")] 12 | RequestError(#[from] ReqwestError), 13 | 14 | #[error("OpenAI error: {0}")] 15 | OpenAIError(#[from] OpenAIError), 16 | 17 | #[error("URL parsing error: {0}")] 18 | UrlParseError(#[from] url::ParseError), 19 | 20 | #[error("HTTP error: {status_code} {error_message}")] 21 | HttpError { 22 | status_code: StatusCode, 23 | error_message: String, 24 | }, 25 | 26 | #[error("FastEmbed error: {0}")] 27 | FastEmbedError(String), 28 | 29 | #[cfg(feature = "ollama")] 30 | #[error("Ollama error: {0}")] 31 | OllamaError(#[from] OllamaError), 32 | 33 | #[cfg(feature = "mistralai")] 34 | #[error("MistralAI Client error: {0}")] 35 | MistralAIClientError(#[from] ClientError), 36 | 37 | #[cfg(feature = "mistralai")] 38 | #[error("MistralAI API error: {0}")] 39 | MistralAIApiError(#[from] ApiError), 40 | } 41 | -------------------------------------------------------------------------------- /src/schemas/memory.rs: -------------------------------------------------------------------------------- 1 | use super::messages::Message; 2 | 3 | pub trait BaseMemory: Send + Sync { 4 | fn messages(&self) -> Vec<Message>; 5 | 6 | // Use a trait object for Display instead of a generic type 7 | fn add_user_message(&mut self, message: &dyn std::fmt::Display) { 8 | // Convert the Display trait object to a String and pass it to the constructor 9 | self.add_message(Message::new_human_message(message.to_string())); 10 | } 11 | 12 | // Use a trait object for Display instead of a generic type 13 | fn add_ai_message(&mut self, message: &dyn std::fmt::Display) { 14 | // Convert the Display trait object to a String and pass it to the constructor 15 | self.add_message(Message::new_ai_message(message.to_string())); 16 | } 17 | 18 | fn add_message(&mut self, message: Message); 19 | 20 | fn clear(&mut self); 21 | 22 | fn to_string(&self) -> String { 23 | self.messages() 24 | .iter() 25 | .map(|msg| format!("{}: {}", msg.message_type.to_string(), msg.content)) 26 | .collect::<Vec<String>>() 27 | .join("\n") 28 | } 29 | } 30 | 31 | impl<M> From<M> for Box<dyn BaseMemory> 32 | where 33 | M: BaseMemory + 'static, 34 | { 35 | fn from(memory: M) -> Self { 36 | Box::new(memory) 37 | } 38 | } 39 |
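Since `BaseMemory` requires only `messages`, `add_message`, and `clear`, a minimal sketch of a custom implementor could look like the following (the `VecMemory` name and its single field are illustrative, not part of the crate); `add_user_message`, `add_ai_message`, and `to_string` come from the default methods, and the blanket `From` impl above supplies the `Box<dyn BaseMemory>` conversion:

use crate::schemas::{memory::BaseMemory, messages::Message};

// Hypothetical in-memory store; only the three required methods are defined.
struct VecMemory {
    messages: Vec<Message>,
}

impl BaseMemory for VecMemory {
    fn messages(&self) -> Vec<Message> {
        self.messages.clone()
    }

    fn add_message(&mut self, message: Message) {
        self.messages.push(message);
    }

    fn clear(&mut self) {
        self.messages.clear();
    }
}

// The blanket impl above then makes boxing implicit:
// let memory: Box<dyn BaseMemory> = VecMemory { messages: Vec::new() }.into();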
-------------------------------------------------------------------------------- /examples/embedding_fastembed.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "fastembed")] 2 | use langchain_rust::embedding::{Embedder, EmbeddingModel, FastEmbed, InitOptions, TextEmbedding}; 3 | 4 | #[cfg(feature = "fastembed")] 5 | #[tokio::main] 6 | async fn main() { 7 | // With default model 8 | let fastembed = FastEmbed::try_new().unwrap(); 9 | let embeddings = fastembed 10 | .embed_documents(&["hello world".to_string(), "foo bar".to_string()]) 11 | .await 12 | .unwrap(); 13 | 14 | println!("Len: {}", embeddings.len()); 15 | println!("Embeddings: {:?}", embeddings); 16 | 17 | // With custom model 18 | let model = TextEmbedding::try_new( 19 | InitOptions::new(EmbeddingModel::AllMiniLML6V2).with_show_download_progress(true), 20 | ) 21 | .unwrap(); 22 | 23 | let fastembed = FastEmbed::from(model); 24 | 25 | let embeddings = fastembed 26 | .embed_documents(&["hello world".to_string(), "foo bar".to_string()]) 27 | .await 28 | .unwrap(); 29 | 30 | // Print the length of the embeddings from the custom model, not the earlier result 31 | println!("Len: {}", embeddings.len()); 32 | } 33 | 34 | #[cfg(not(feature = "fastembed"))] 35 | fn main() { 36 | println!("This example requires the 'fastembed' feature to be enabled."); 37 | println!("Please run the command as follows:"); 38 | println!("cargo run --example embedding_fastembed --features=fastembed"); 39 | } 40 | -------------------------------------------------------------------------------- /src/memory/window_buffer.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use tokio::sync::Mutex; 4 | 5 | use crate::schemas::{memory::BaseMemory, messages::Message}; 6 | 7 | pub struct WindowBufferMemory { 8 | window_size: usize, 9 | messages: Vec<Message>, 10 | } 11 | 12 | impl Default for WindowBufferMemory { 13 | fn default() -> Self { 14 | Self::new(10) 15 | } 16 | } 17 | 18 | impl WindowBufferMemory { 19 | pub fn new(window_size: usize) -> Self { 20 | Self { 21 | messages: Vec::new(), 22 | window_size, 23 | } 24 | } 25 | } 26 | 27 | impl Into<Arc<dyn BaseMemory>> for WindowBufferMemory { 28 | fn into(self) -> Arc<dyn BaseMemory> { 29 | Arc::new(self) 30 | } 31 | } 32 | 33 | impl Into<Arc<Mutex<dyn BaseMemory>>> for WindowBufferMemory { 34 | fn into(self) -> Arc<Mutex<dyn BaseMemory>> { 35 | Arc::new(Mutex::new(self)) 36 | } 37 | } 38 | 39 | impl BaseMemory for WindowBufferMemory { 40 | fn messages(&self) -> Vec<Message> { 41 | self.messages.clone() 42 | } 43 | fn add_message(&mut self, message: Message) { 44 | if self.messages.len() >= self.window_size { 45 | self.messages.remove(0); 46 | } 47 | self.messages.push(message); 48 | } 49 | fn clear(&mut self) { 50 | self.messages.clear(); 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /examples/streaming_from_chain.rs: -------------------------------------------------------------------------------- 1 | use futures::StreamExt; 2 | use langchain_rust::{ 3 | chain::{Chain, LLMChainBuilder}, 4 | fmt_message, fmt_template, 5 | llm::openai::OpenAI, 6 | message_formatter, 7 | prompt::HumanMessagePromptTemplate, 8 | prompt_args, 9 | schemas::messages::Message, 10 | template_fstring, 11 | }; 12 | 13 | #[tokio::main] 14 | async fn main() { 15 | let open_ai = OpenAI::default(); 16 | 17 | let prompt = message_formatter![ 18 | fmt_message!(Message::new_system_message( 19 | "You are world class technical documentation writer."
20 | )), 21 | fmt_template!(HumanMessagePromptTemplate::new(template_fstring!( 22 | "{input}", "input" 23 | ))) 24 | ]; 25 | 26 | let chain = LLMChainBuilder::new() 27 | .prompt(prompt) 28 | .llm(open_ai.clone()) 29 | .build() 30 | .unwrap(); 31 | 32 | let mut stream = chain 33 | .stream(prompt_args! { 34 | "input" => "Who is the writer of 20,000 Leagues Under the Sea?", 35 | }) 36 | .await 37 | .unwrap(); 38 | 39 | while let Some(result) = stream.next().await { 40 | match result { 41 | Ok(value) => value.to_stdout().unwrap(), 42 | Err(e) => panic!("Error invoking LLMChain: {:?}", e), 43 | } 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /src/document_loaders/error.rs: -------------------------------------------------------------------------------- 1 | use std::{io, string::FromUtf8Error}; 2 | 3 | use thiserror::Error; 4 | 5 | use crate::text_splitter::TextSplitterError; 6 | 7 | #[derive(Error, Debug)] 8 | pub enum LoaderError { 9 | #[error("Error loading document: {0}")] 10 | LoadDocumentError(String), 11 | 12 | #[error("{0}")] 13 | TextSplitterError(#[from] TextSplitterError), 14 | 15 | #[error(transparent)] 16 | IOError(#[from] io::Error), 17 | 18 | #[error(transparent)] 19 | FromUtf8Error(#[from] FromUtf8Error), 20 | 21 | #[error(transparent)] 22 | CSVError(#[from] csv::Error), 23 | 24 | #[cfg(feature = "lopdf")] 25 | #[cfg(not(feature = "pdf-extract"))] 26 | #[error(transparent)] 27 | LoPdfError(#[from] lopdf::Error), 28 | 29 | #[cfg(feature = "pdf-extract")] 30 | #[error(transparent)] 31 | PdfExtractError(#[from] pdf_extract::Error), 32 | 33 | #[cfg(feature = "pdf-extract")] 34 | #[error(transparent)] 35 | PdfExtractOutputError(#[from] pdf_extract::OutputError), 36 | 37 | #[error(transparent)] 38 | ReadabilityError(#[from] readability::error::Error), 39 | 40 | #[error(transparent)] 41 | JoinError(#[from] tokio::task::JoinError), 42 | 43 | #[cfg(feature = "git")] 44 | #[error(transparent)] 45 | DiscoveryError(#[from] gix::discover::Error), 46 | 47 | #[error("Error: {0}")] 48 | OtherError(String), 49 | } 50 | -------------------------------------------------------------------------------- /src/chain/sequential/builder.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashSet; 2 | 3 | use crate::chain::Chain; 4 | 5 | use super::SequentialChain; 6 | 7 | pub struct SequentialChainBuilder { 8 | chains: Vec<Box<dyn Chain>>, 9 | } 10 | 11 | impl SequentialChainBuilder { 12 | pub fn new() -> Self { 13 | Self { chains: Vec::new() } 14 | } 15 | 16 | pub fn add_chain<C: Chain + 'static>(mut self, chain: C) -> Self { 17 | self.chains.push(Box::new(chain)); 18 | self 19 | } 20 | 21 | pub fn build(self) -> SequentialChain { 22 | let outputs: HashSet<String> = self 23 | .chains 24 | .iter() 25 | .flat_map(|c| c.get_output_keys()) 26 | .collect(); 27 | 28 | let input_keys: HashSet<String> = self 29 | .chains 30 | .iter() 31 | .flat_map(|c| c.get_input_keys()) 32 | .collect(); 33 | 34 | SequentialChain { 35 | chains: self.chains, 36 | input_keys, 37 | outputs, 38 | } 39 | } 40 | } 41 | 42 | #[macro_export] 43 | macro_rules! sequential_chain { 44 | ( $( $chain:expr ),* $(,)?
) => { 45 | { 46 | let mut builder = $crate::chain::SequentialChainBuilder::new(); 47 | $( 48 | builder = builder.add_chain($chain); 49 | )* 50 | builder.build() 51 | } 52 | }; 53 | } 54 | -------------------------------------------------------------------------------- /src/semantic_router/utils.rs: -------------------------------------------------------------------------------- 1 | pub fn combine_embeddings(embeddings: &[Vec<f64>]) -> Vec<f64> { 2 | embeddings 3 | .iter() 4 | // Initialize a vector with zeros based on the length of the first embedding vector. 5 | // It's assumed all embeddings have the same dimensions. 6 | .fold( 7 | vec![0f64; embeddings[0].len()], 8 | |mut accumulator, embedding_vec| { 9 | for (i, &value) in embedding_vec.iter().enumerate() { 10 | accumulator[i] += value; 11 | } 12 | accumulator 13 | }, 14 | ) 15 | // Calculate the mean for each element across all embeddings. 16 | .iter() 17 | .map(|&sum| sum / embeddings.len() as f64) 18 | .collect() 19 | } 20 | 21 | pub fn cosine_similarity(vec1: &[f64], vec2: &[f64]) -> f64 { 22 | let dot_product: f64 = vec1.iter().zip(vec2.iter()).map(|(a, b)| a * b).sum(); 23 | let magnitude_vec1: f64 = vec1.iter().map(|x| x.powi(2)).sum::<f64>().sqrt(); 24 | let magnitude_vec2: f64 = vec2.iter().map(|x| x.powi(2)).sum::<f64>().sqrt(); 25 | dot_product / (magnitude_vec1 * magnitude_vec2) 26 | } 27 | 28 | pub fn sum_vectors(vectors: &[Vec<f64>]) -> Vec<f64> { 29 | let mut sum_vec = vec![0.0; vectors[0].len()]; 30 | for vec in vectors { 31 | for (i, &value) in vec.iter().enumerate() { 32 | sum_vec[i] += value; 33 | } 34 | } 35 | sum_vec 36 | } 37 | -------------------------------------------------------------------------------- /src/semantic_router/router.rs: -------------------------------------------------------------------------------- 1 | use std::hash::{Hash, Hasher}; 2 | 3 | #[derive(Debug, Clone)] 4 | pub struct Router { 5 | pub name: String, 6 | pub utterances: Vec<String>, 7 | pub embedding: Option<Vec<Vec<f64>>>, 8 | pub similarity: Option<f64>, 9 | pub tool_description: Option<String>, 10 | } 11 | impl Router { 12 | pub fn new<S: AsRef<str>>(name: &str, utterances: &[S]) -> Self { 13 | Self { 14 | name: name.into(), 15 | utterances: utterances.iter().map(|s| s.as_ref().to_string()).collect(), 16 | embedding: None, 17 | similarity: None, 18 | tool_description: None, 19 | } 20 | } 21 | 22 | pub fn with_embedding(mut self, embedding: Vec<Vec<f64>>) -> Self { 23 | self.embedding = Some(embedding); 24 | self 25 | } 26 | 27 | pub fn with_tool_description<S: Into<String>>(mut self, tool_description: S) -> Self { 28 | self.tool_description = Some(tool_description.into()); 29 | self 30 | } 31 | 32 | pub fn with_similarity(mut self, similarity: f64) -> Self { 33 | self.similarity = Some(similarity); 34 | self 35 | } 36 | } 37 | 38 | impl Eq for Router {} 39 | impl PartialEq for Router { 40 | fn eq(&self, other: &Self) -> bool { 41 | self.name == other.name 42 | } 43 | } 44 | 45 | impl Hash for Router { 46 | fn hash<H: Hasher>(&self, state: &mut H) { 47 | self.name.hash(state); 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /examples/vision_llm_chain.rs: -------------------------------------------------------------------------------- 1 | use base64::prelude::*; 2 | use langchain_rust::chain::{Chain, LLMChainBuilder}; 3 | use langchain_rust::llm::OpenAI; 4 | use langchain_rust::prompt::HumanMessagePromptTemplate; 5 | use langchain_rust::schemas::Message; 6 | use langchain_rust::{fmt_message, fmt_template, message_formatter, prompt_args, template_fstring}; 7 | 8 | #[tokio::main] 9 | async fn
main() { 10 | // Convert image to base64. Can also pass a link to an image instead. 11 | let image = std::fs::read("./src/llm/test_data/example.jpg").unwrap(); 12 | let image_base64 = BASE64_STANDARD.encode(image); 13 | 14 | let prompt = message_formatter![ 15 | fmt_template!(HumanMessagePromptTemplate::new(template_fstring!( 16 | "{input}", "input" 17 | ))), 18 | fmt_message!(Message::new_human_message_with_images(vec![format!( 19 | "data:image/jpeg;base64,{image_base64}" 20 | )])), 21 | ]; 22 | 23 | // let open_ai = OpenAI::new(langchain_rust::llm::ollama::openai::OllamaConfig::default()) 24 | // .with_model("llava"); 25 | let open_ai = OpenAI::default(); 26 | let chain = LLMChainBuilder::new() 27 | .prompt(prompt) 28 | .llm(open_ai) 29 | .build() 30 | .unwrap(); 31 | 32 | match chain 33 | .invoke(prompt_args! { "input" => "Describe this image"}) 34 | .await 35 | { 36 | Ok(result) => { 37 | println!("Result: {:?}", result); 38 | } 39 | Err(e) => panic!("Error invoking LLMChain: {:?}", e), 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /src/language_models/error.rs: -------------------------------------------------------------------------------- 1 | use async_openai::error::OpenAIError; 2 | #[cfg(feature = "ollama")] 3 | use ollama_rs::error::OllamaError; 4 | use reqwest::Error as ReqwestError; 5 | use serde_json::Error as SerdeJsonError; 6 | use thiserror::Error; 7 | use tokio::time::error::Elapsed; 8 | 9 | use crate::llm::{AnthropicError, DeepseekError, QwenError}; 10 | 11 | #[derive(Error, Debug)] 12 | pub enum LLMError { 13 | #[error("OpenAI error: {0}")] 14 | OpenAIError(#[from] OpenAIError), 15 | 16 | #[error("Anthropic error: {0}")] 17 | AnthropicError(#[from] AnthropicError), 18 | 19 | #[error("Qwen error: {0}")] 20 | QwenError(#[from] QwenError), 21 | 22 | #[error("Deepseek error: {0}")] 23 | DeepseekError(#[from] DeepseekError), 24 | 25 | #[cfg(feature = "ollama")] 26 | #[error("Ollama error: {0}")] 27 | OllamaError(#[from] OllamaError), 28 | 29 | #[error("Network request failed: {0}")] 30 | RequestError(#[from] ReqwestError), 31 | 32 | #[error("JSON serialization/deserialization error: {0}")] 33 | SerdeError(#[from] SerdeJsonError), 34 | 35 | #[error("IO error: {0}")] 36 | IoError(#[from] std::io::Error), 37 | 38 | #[error("Operation timed out")] 39 | Timeout(#[from] Elapsed), 40 | 41 | #[error("Invalid URL: {0}")] 42 | InvalidUrl(String), 43 | 44 | #[error("Content not found in response: Expected at {0}")] 45 | ContentNotFound(String), 46 | 47 | #[error("Parsing error: {0}")] 48 | ParsingError(String), 49 | 50 | #[error("Error: {0}")] 51 | OtherError(String), 52 | } 53 | -------------------------------------------------------------------------------- /examples/dynamic_semantic_routes.rs: -------------------------------------------------------------------------------- 1 | use langchain_rust::{ 2 | embedding::openai::OpenAiEmbedder, 3 | semantic_router::{AggregationMethod, RouteLayerBuilder, Router}, 4 | tools::{SerpApi, Tool}, 5 | }; 6 | 7 | #[tokio::main] 8 | async fn main() { 9 | let tool = SerpApi::default(); 10 | let capital_route = Router::new( 11 | "capital", 12 | &[ 13 | "Capital of France is Paris.", 14 | "What is the capital of France?", 15 | ], 16 | ) 17 | .with_tool_description(tool.description()); 18 | let weather_route = Router::new( 19 | "temperature", 20 | &[ 21 | "What is the temperature?", 22 | "Is it raining?", 23 | "Is it cloudy?", 24 | ], 25 | ); 26 | let router_layer = RouteLayerBuilder::default() 27 | 
.embedder(OpenAiEmbedder::default()) 28 | .add_route(capital_route) 29 | .add_route(weather_route) 30 | .aggregation_method(AggregationMethod::Sum) 31 | .threshold(0.82) 32 | .build() 33 | .await 34 | .unwrap(); 35 | 36 | let route = router_layer 37 | .call("What is the capital of the USA") 38 | .await 39 | .unwrap(); 40 | 41 | let route_choice = match route { 42 | Some(route) => route, 43 | None => panic!("No Similar Route"), 44 | }; 45 | 46 | println!("{:?}", &route_choice); 47 | if route_choice.route == "capital" { 48 | let tool_output = tool.run(route_choice.tool_input.unwrap()).await.unwrap(); 49 | println!("{:?}", tool_output); 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /src/text_splitter/markdown_splitter.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | use text_splitter::ChunkConfig; 3 | use tiktoken_rs::tokenizer::Tokenizer; 4 | 5 | use super::{SplitterOptions, TextSplitter, TextSplitterError}; 6 | 7 | pub struct MarkdownSplitter { 8 | splitter_options: SplitterOptions, 9 | } 10 | 11 | impl Default for MarkdownSplitter { 12 | fn default() -> Self { 13 | MarkdownSplitter::new(SplitterOptions::default()) 14 | } 15 | } 16 | 17 | impl MarkdownSplitter { 18 | pub fn new(options: SplitterOptions) -> MarkdownSplitter { 19 | MarkdownSplitter { 20 | splitter_options: options, 21 | } 22 | } 23 | 24 | #[deprecated = "Use `SplitterOptions::get_tokenizer_from_str` instead"] 25 | pub fn get_tokenizer_from_str(&self, s: &str) -> Option<Tokenizer> { 26 | match s.to_lowercase().as_str() { 27 | "cl100k_base" => Some(Tokenizer::Cl100kBase), 28 | "p50k_base" => Some(Tokenizer::P50kBase), 29 | "r50k_base" => Some(Tokenizer::R50kBase), 30 | "p50k_edit" => Some(Tokenizer::P50kEdit), 31 | "gpt2" => Some(Tokenizer::Gpt2), 32 | _ => None, 33 | } 34 | } 35 | } 36 | 37 | #[async_trait] 38 | impl TextSplitter for MarkdownSplitter { 39 | async fn split_text(&self, text: &str) -> Result<Vec<String>, TextSplitterError> { 40 | let chunk_config = ChunkConfig::try_from(&self.splitter_options)?; 41 | Ok(text_splitter::MarkdownSplitter::new(chunk_config) 42 | .chunks(text) 43 | .map(|x| x.to_string()) 44 | .collect()) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/text_splitter/token_splitter.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | use text_splitter::ChunkConfig; 3 | use tiktoken_rs::tokenizer::Tokenizer; 4 | 5 | use super::{SplitterOptions, TextSplitter, TextSplitterError}; 6 | 7 | #[derive(Debug, Clone)] 8 | pub struct TokenSplitter { 9 | splitter_options: SplitterOptions, 10 | } 11 | 12 | impl Default for TokenSplitter { 13 | fn default() -> Self { 14 | TokenSplitter::new(SplitterOptions::default()) 15 | } 16 | } 17 | 18 | impl TokenSplitter { 19 | pub fn new(options: SplitterOptions) -> TokenSplitter { 20 | TokenSplitter { 21 | splitter_options: options, 22 | } 23 | } 24 | 25 | #[deprecated = "Use `SplitterOptions::get_tokenizer_from_str` instead"] 26 | pub fn get_tokenizer_from_str(&self, s: &str) -> Option<Tokenizer> { 27 | match s.to_lowercase().as_str() { 28 | "cl100k_base" => Some(Tokenizer::Cl100kBase), 29 | "p50k_base" => Some(Tokenizer::P50kBase), 30 | "r50k_base" => Some(Tokenizer::R50kBase), 31 | "p50k_edit" => Some(Tokenizer::P50kEdit), 32 | "gpt2" => Some(Tokenizer::Gpt2), 33 | _ => None, 34 | } 35 | } 36 | } 37 | 38 | #[async_trait] 39 | impl TextSplitter for
TokenSplitter { 40 | async fn split_text(&self, text: &str) -> Result<Vec<String>, TextSplitterError> { 41 | let chunk_config = ChunkConfig::try_from(&self.splitter_options)?; 42 | Ok(text_splitter::TextSplitter::new(chunk_config) 43 | .chunks(text) 44 | .map(|x| x.to_string()) 45 | .collect()) 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /src/semantic_router/route_layer/error.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | use crate::{ 4 | chain::ChainError, embedding::EmbedderError, language_models::LLMError, 5 | semantic_router::IndexError, 6 | }; 7 | use serde_json::Error as SerdeJsonError; 8 | 9 | #[derive(Error, Debug)] 10 | pub enum RouterBuilderError { 11 | #[error("Invalid Router configuration: at least one of utterances or embedding must be provided, and utterances cannot be an empty vector.")] 12 | InvalidConfiguration, 13 | } 14 | 15 | #[derive(Error, Debug)] 16 | pub enum RouteLayerBuilderError { 17 | #[error("Route layer should have an embedder")] 18 | MissingEmbedder, 19 | 20 | #[error("Route layer should have an LLM")] 21 | MissingLLM, 22 | 23 | #[error("Missing Index")] 24 | MissingIndex, 25 | 26 | #[error("Route layer error: {0}")] 27 | RouteLayerError(#[from] RouteLayerError), 28 | 29 | #[error("Index error: {0}")] 30 | IndexError(#[from] IndexError), 31 | 32 | #[error("Embedding error: {0}")] 33 | EmbeddingError(#[from] EmbedderError), 34 | 35 | #[error("Chain error: {0}")] 36 | ChainError(#[from] ChainError), 37 | } 38 | 39 | #[derive(Error, Debug)] 40 | pub enum RouteLayerError { 41 | #[error("Embedding error: {0}")] 42 | EmbeddingError(#[from] EmbedderError), 43 | 44 | #[error("Index error: {0}")] 45 | IndexError(#[from] IndexError), 46 | 47 | #[error("LLM error: {0}")] 48 | LLMError(#[from] LLMError), 49 | 50 | #[error("Serialization error: {0}")] 51 | SerializationError(#[from] SerdeJsonError), 52 | 53 | #[error("Chain error: {0}")] 54 | ChainError(#[from] ChainError), 55 | } 56 | -------------------------------------------------------------------------------- /src/prompt/mod.rs: -------------------------------------------------------------------------------- 1 | mod chat; 2 | mod error; 3 | mod prompt; 4 | 5 | use std::collections::HashMap; 6 | 7 | pub use chat::*; 8 | pub use error::*; 9 | pub use prompt::*; 10 | use serde_json::Value; 11 | 12 | use crate::schemas::{messages::Message, prompt::PromptValue}; 13 | 14 | // pub type PromptArgs<'a> = HashMap<&'a str, &'a str>; 15 | pub type PromptArgs = HashMap<String, Value>; 16 | pub trait PromptFromatter: Send + Sync { 17 | fn template(&self) -> String; 18 | fn variables(&self) -> Vec<String>; 19 | fn format(&self, input_variables: PromptArgs) -> Result<String, PromptError>; 20 | } 21 | impl<PA> From<PA> for Box<dyn PromptFromatter> 22 | where 23 | PA: PromptFromatter + 'static, 24 | { 25 | fn from(prompt: PA) -> Self { 26 | Box::new(prompt) 27 | } 28 | } 29 | 30 | /// Represents a generic template for formatting messages. 31 | pub trait MessageFormatter: Send + Sync { 32 | fn format_messages(&self, input_variables: PromptArgs) -> Result<Vec<Message>, PromptError>; 33 | 34 | /// Returns a list of required input variable names for the template.
35 | fn input_variables(&self) -> Vec<String>; 36 | } 37 | impl<MF> From<MF> for Box<dyn MessageFormatter> 38 | where 39 | MF: MessageFormatter + 'static, 40 | { 41 | fn from(prompt: MF) -> Self { 42 | Box::new(prompt) 43 | } 44 | } 45 | 46 | pub trait FormatPrompter: Send + Sync { 47 | fn format_prompt(&self, input_variables: PromptArgs) -> Result<PromptValue, PromptError>; 48 | fn get_input_variables(&self) -> Vec<String>; 49 | } 50 | impl<FP> From<FP> for Box<dyn FormatPrompter> 51 | where 52 | FP: FormatPrompter + 'static, 53 | { 54 | fn from(prompt: FP) -> Self { 55 | Box::new(prompt) 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /src/document_loaders/document_loader.rs: -------------------------------------------------------------------------------- 1 | use std::pin::Pin; 2 | 3 | use async_stream::stream; 4 | use async_trait::async_trait; 5 | use futures::Stream; 6 | use futures_util::{pin_mut, StreamExt}; 7 | 8 | use crate::{schemas::Document, text_splitter::TextSplitter}; 9 | 10 | use super::LoaderError; 11 | 12 | #[async_trait] 13 | pub trait Loader: Send + Sync { 14 | async fn load( 15 | self, 16 | ) -> Result< 17 | Pin<Box<dyn Stream<Item = Result<Document, LoaderError>> + Send + 'static>>, 18 | LoaderError, 19 | >; 20 | async fn load_and_split<TS: TextSplitter + 'static>( 21 | self, 22 | splitter: TS, 23 | ) -> Result< 24 | Pin<Box<dyn Stream<Item = Result<Document, LoaderError>> + Send + 'static>>, 25 | LoaderError, 26 | >; 27 | } 28 | 29 | pub(crate) async fn process_doc_stream<TS: TextSplitter>( 30 | doc_stream: Pin<Box<dyn Stream<Item = Result<Document, LoaderError>> + Send>>, 31 | splitter: TS, 32 | ) -> impl Stream<Item = Result<Document, LoaderError>> { 33 | stream! { 34 | pin_mut!(doc_stream); 35 | while let Some(doc_result) = doc_stream.next().await { 36 | match doc_result { 37 | Ok(doc) => { 38 | match splitter.split_documents(&[doc]).await { 39 | Ok(docs) => { 40 | for doc in docs { 41 | yield Ok(doc); 42 | } 43 | }, 44 | Err(e) => yield Err(LoaderError::TextSplitterError(e)), 45 | } 46 | } 47 | Err(e) => yield Err(e), 48 | } 49 | } 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /src/language_models/llm.rs: -------------------------------------------------------------------------------- 1 | use std::pin::Pin; 2 | 3 | use async_trait::async_trait; 4 | use futures::Stream; 5 | 6 | use crate::schemas::{Message, StreamData}; 7 | 8 | use super::{options::CallOptions, GenerateResult, LLMError}; 9 | 10 | #[async_trait] 11 | pub trait LLM: Sync + Send + LLMClone { 12 | async fn generate(&self, messages: &[Message]) -> Result<GenerateResult, LLMError>; 13 | async fn invoke(&self, prompt: &str) -> Result<String, LLMError> { 14 | self.generate(&[Message::new_human_message(prompt)]) 15 | .await 16 | .map(|res| res.generation) 17 | } 18 | async fn stream( 19 | &self, 20 | _messages: &[Message], 21 | ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, LLMError>> + Send>>, LLMError>; 22 | 23 | /// This is useful when you want to create a chain and override 24 | /// LLM options 25 | fn add_options(&mut self, _options: CallOptions) { 26 | // No action taken 27 | } 28 | // This is useful when using non-chat models 29 | fn messages_to_string(&self, messages: &[Message]) -> String { 30 | messages 31 | .iter() 32 | .map(|m| format!("{:?}: {}", m.message_type, m.content)) 33 | .collect::<Vec<String>>() 34 | .join("\n") 35 | } 36 | } 37 | 38 | pub trait LLMClone { 39 | fn clone_box(&self) -> Box<dyn LLM>; 40 | } 41 | 42 | impl<T> LLMClone for T 43 | where 44 | T: 'static + LLM + Clone, 45 | { 46 | fn clone_box(&self) -> Box<dyn LLM> { 47 | Box::new(self.clone()) 48 | } 49 | } 50 | 51 | impl<L> From<L> for Box<dyn LLM> 52 | where 53 | L: 'static + LLM, 54 | { 55 | fn from(llm: L) -> Self { 56 | Box::new(llm) 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /CONTRIBUTING.md:
-------------------------------------------------------------------------------- 1 | # Contributing to Langchain-rust 2 | 3 | Thank you for your interest in contributing to Langchain-rust! We appreciate any contributions, whether it's bug fixes, new features, or documentation improvements. Here's how you can get started: 4 | 5 | ## Getting Started 6 | 7 | 1. Fork the Langchain-rust repository to your GitHub account. 8 | 2. Clone the forked repository to your local machine. 9 | 3. Create a new branch for your contributions: `git checkout -b my-contribution`. 10 | 4. Make your changes and improvements to the codebase. 11 | 5. Run the tests to ensure that everything is functioning correctly. 12 | 6. Commit your changes: `git commit -m "Add my contribution"`. 13 | 7. Push the changes to your forked repository: `git push origin my-contribution`. 14 | 8. Open a new pull request on the Langchain-rust repository. 15 | 16 | ## Guidelines 17 | 18 | - Make sure your code follows the Rust programming style and conventions. 19 | - Write clear and concise commit messages to describe your changes. 20 | - Include appropriate documentation and tests for the changes you make. 21 | - If your contribution includes new features or significant changes, please update the relevant documentation or README files. 22 | 23 | ## Code of Conduct 24 | 25 | Langchain-rust follows a code of conduct to ensure a welcoming and inclusive environment for all contributors. Please review the CODE_OF_CONDUCT.md file in the repository to familiarize yourself with our expectations and guidelines. 26 | 27 | ## Communication 28 | 29 | If you have any questions or need assistance with your contributions, feel free to reach out to us on our official Discord channel or open an issue on the repository. 30 | 31 | We appreciate your support and look forward to your contributions! 
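As a rough sketch of the documentation-and-tests guideline above, a new helper in this crate would typically ship in the following shape (the `word_count` function below is a made-up placeholder, not part of the codebase):

```rust
/// Counts whitespace-separated words in `input`.
pub fn word_count(input: &str) -> usize {
    input.split_whitespace().count()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_word_count() {
        assert_eq!(word_count("hello world"), 2);
    }
}
```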
32 | -------------------------------------------------------------------------------- /src/text_splitter/text_splitter.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use async_trait::async_trait; 4 | use serde_json::Value; 5 | 6 | use crate::schemas::Document; 7 | 8 | use super::TextSplitterError; 9 | 10 | #[async_trait] 11 | pub trait TextSplitter: Send + Sync { 12 | async fn split_text(&self, text: &str) -> Result<Vec<String>, TextSplitterError>; 13 | 14 | async fn split_documents( 15 | &self, 16 | documents: &[Document], 17 | ) -> Result<Vec<Document>, TextSplitterError> { 18 | let mut texts: Vec<String> = Vec::new(); 19 | let mut metadatas: Vec<HashMap<String, Value>> = Vec::new(); 20 | documents.iter().for_each(|d| { 21 | texts.push(d.page_content.clone()); 22 | metadatas.push(d.metadata.clone()); 23 | }); 24 | 25 | self.create_documents(&texts, &metadatas).await 26 | } 27 | 28 | async fn create_documents( 29 | &self, 30 | text: &[String], 31 | metadatas: &[HashMap<String, Value>], 32 | ) -> Result<Vec<Document>, TextSplitterError> { 33 | let mut metadatas = metadatas.to_vec(); 34 | if metadatas.is_empty() { 35 | metadatas = vec![HashMap::new(); text.len()]; 36 | } 37 | 38 | if text.len() != metadatas.len() { 39 | return Err(TextSplitterError::MetadataTextMismatch); 40 | } 41 | 42 | let mut documents: Vec<Document> = Vec::new(); 43 | for i in 0..text.len() { 44 | let chunks = self.split_text(&text[i]).await?; 45 | for chunk in chunks { 46 | let document = Document::new(chunk).with_metadata(metadatas[i].clone()); 47 | documents.push(document); 48 | } 49 | } 50 | 51 | Ok(documents) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /examples/sequential_chain.rs: -------------------------------------------------------------------------------- 1 | use langchain_rust::{ 2 | chain::{Chain, LLMChainBuilder}, 3 | llm::openai::{OpenAI, OpenAIModel}, 4 | prompt::HumanMessagePromptTemplate, 5 | prompt_args, sequential_chain, template_jinja2, 6 | }; 7 | use std::io::{self, Write}; // Include io library for terminal input 8 | 9 | #[tokio::main] 10 | async fn main() { 11 | let llm = OpenAI::default().with_model(OpenAIModel::Gpt35); 12 | let prompt = HumanMessagePromptTemplate::new(template_jinja2!( 13 | "Give me a creative name for a store that sells: {{producto}}", 14 | "producto" 15 | )); 16 | 17 | let get_name_chain = LLMChainBuilder::new() 18 | .prompt(prompt) 19 | .llm(llm.clone()) 20 | .output_key("name") 21 | .build() 22 | .unwrap(); 23 | 24 | let prompt = HumanMessagePromptTemplate::new(template_jinja2!( 25 | "Give me a slogan for the following name: {{name}}", 26 | "name" 27 | )); 28 | let get_slogan_chain = LLMChainBuilder::new() 29 | .prompt(prompt) 30 | .llm(llm.clone()) 31 | .output_key("slogan") 32 | .build() 33 | .unwrap(); 34 | 35 | let sequential_chain = sequential_chain!(get_name_chain, get_slogan_chain); 36 | 37 | print!("Please enter a product: "); 38 | io::stdout().flush().unwrap(); // Display prompt to terminal 39 | 40 | let mut product = String::new(); 41 | io::stdin().read_line(&mut product).unwrap(); // Get product from terminal input 42 | 43 | let product = product.trim(); 44 | let output = sequential_chain 45 | .execute(prompt_args!
{ 46 | "producto" => product 47 | }) 48 | .await 49 | .unwrap(); 50 | 51 | println!("Name: {}", output["name"]); 52 | println!("Slogan: {}", output["slogan"]); 53 | } 54 | -------------------------------------------------------------------------------- /src/vectorstore/options.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use serde_json::Value; 4 | 5 | use crate::embedding::embedder_trait::Embedder; 6 | 7 | /// The `VecStoreOptions` struct is responsible for determining options when 8 | /// interacting with a Vector Store. The options include `name_space`, `score_threshold`, 9 | /// `filters`, and `embedder`. 10 | /// 11 | /// # Usage 12 | /// ```rust,ignore 13 | /// let options = VecStoreOptions::new() 14 | /// .with_name_space("my_custom_namespace") 15 | /// .with_score_threshold(0.5) 16 | /// .with_filters(json!({"genre": "Sci-Fi"})) 17 | /// .with_embedder(my_embedder); 18 | /// ``` 19 | pub struct VecStoreOptions { 20 | pub name_space: Option, 21 | pub score_threshold: Option, 22 | pub filters: Option, 23 | pub embedder: Option>, 24 | } 25 | 26 | impl Default for VecStoreOptions { 27 | fn default() -> Self { 28 | Self::new() 29 | } 30 | } 31 | 32 | impl VecStoreOptions { 33 | pub fn new() -> Self { 34 | VecStoreOptions { 35 | name_space: None, 36 | score_threshold: None, 37 | filters: None, 38 | embedder: None, 39 | } 40 | } 41 | 42 | pub fn with_name_space>(mut self, name_space: S) -> Self { 43 | self.name_space = Some(name_space.into()); 44 | self 45 | } 46 | 47 | pub fn with_score_threshold(mut self, score_threshold: f32) -> Self { 48 | self.score_threshold = Some(score_threshold); 49 | self 50 | } 51 | 52 | pub fn with_filters(mut self, filters: F) -> Self { 53 | self.filters = Some(filters); 54 | self 55 | } 56 | 57 | pub fn with_embedder(mut self, embedder: E) -> Self { 58 | self.embedder = Some(Arc::new(embedder)); 59 | self 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /examples/rcommiter.rs: -------------------------------------------------------------------------------- 1 | use std::io::{self, BufRead}; 2 | use std::process::{Command, Stdio}; 3 | 4 | use langchain_rust::chain::chain_trait::Chain; 5 | use langchain_rust::chain::llm_chain::LLMChainBuilder; 6 | use langchain_rust::llm::openai::OpenAI; 7 | use langchain_rust::prompt::HumanMessagePromptTemplate; 8 | use langchain_rust::{prompt_args, template_jinja2}; 9 | 10 | //to try this in action , add something to this file stage it an run it 11 | #[tokio::main] 12 | async fn main() -> io::Result<()> { 13 | let llm = OpenAI::default(); 14 | let chain = LLMChainBuilder::new() 15 | .prompt(HumanMessagePromptTemplate::new(template_jinja2!( 16 | r#" 17 | Create a conventional commit message for the following changes. 18 | 19 | File changes: 20 | {{input}} 21 | 22 | 23 | 24 | "#, 25 | "input" 26 | ))) 27 | .llm(llm) 28 | .build() 29 | .expect("Failed to build LLMChain"); 30 | 31 | let shell_command = r#" 32 | git diff --cached --name-only --diff-filter=ACM | while read -r file; do echo "\n---------------------------\n name:$file"; git diff --cached "$file" | sed 's/^/changes:/'; done 33 | "#; 34 | 35 | let output = Command::new("sh") 36 | .arg("-c") 37 | .arg(shell_command) 38 | .stdout(Stdio::piped()) 39 | .spawn()? 
40 | .stdout 41 | .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Could not capture stdout."))?; 42 | 43 | let reader = io::BufReader::new(output); 44 | 45 | let complete_changes = reader 46 | .lines() 47 | .map(|line| line.unwrap()) 48 | .collect::<Vec<String>>() 49 | .join("\n"); 50 | 51 | let res = chain 52 | .invoke(prompt_args! { 53 | "input"=>complete_changes, 54 | }) 55 | .await 56 | .expect("Failed to invoke chain"); 57 | 58 | println!("{}", res); 59 | Ok(()) 60 | } 61 | -------------------------------------------------------------------------------- /examples/llm_chain_qwen.rs: -------------------------------------------------------------------------------- 1 | use langchain_rust::{ 2 | chain::{chain_trait::Chain, llm_chain::LLMChainBuilder}, 3 | language_models::options::CallOptions, 4 | llm::{Qwen, QwenModel}, 5 | prompt::{PromptTemplate, TemplateFormat}, 6 | prompt_args, 7 | }; 8 | use std::env; 9 | 10 | #[tokio::main] 11 | async fn main() { 12 | // Get API key from environment variable 13 | let api_key = env::var("QWEN_API_KEY").expect("QWEN_API_KEY environment variable must be set"); 14 | 15 | // Setup the Qwen client with desired model and parameters 16 | let qwen = Qwen::new() 17 | .with_api_key(api_key) 18 | .with_model(QwenModel::QwenTurbo.to_string()) 19 | .with_options( 20 | CallOptions::default() 21 | .with_max_tokens(800) 22 | .with_temperature(0.8), 23 | ); 24 | 25 | // Create a prompt template 26 | let template = r#" 27 | You are a helpful assistant that provides detailed information. 28 | 29 | User question: {question} 30 | 31 | Please provide a comprehensive answer: 32 | "#; 33 | 34 | let prompt = PromptTemplate::new( 35 | template.to_owned(), 36 | vec!["question".to_owned()], 37 | TemplateFormat::FString, 38 | ); 39 | 40 | // Create an LLMChain using the builder pattern 41 | let chain = LLMChainBuilder::new() 42 | .prompt(prompt) 43 | .llm(qwen) 44 | .build() 45 | .unwrap(); 46 | 47 | // Execute the chain with a question 48 | let inputs = prompt_args! { 49 | "question" => "Explain the importance of quantum computing and its potential applications." 50 | }; 51 | 52 | let result = chain.call(inputs).await.unwrap(); 53 | 54 | println!( 55 | "Question: Explain the importance of quantum computing and its potential applications."
56 | ); 57 | println!("\nQwen's response:"); 58 | println!("{}", result.generation); 59 | } 60 | -------------------------------------------------------------------------------- /src/schemas/convert.rs: -------------------------------------------------------------------------------- 1 | pub trait LangchainIntoOpenAI<T>: Sized { 2 | fn into_openai(self) -> T; 3 | } 4 | 5 | pub trait LangchainFromOpenAI<T>: Sized { 6 | fn from_openai(openai: T) -> Self; 7 | } 8 | 9 | pub trait OpenAiIntoLangchain<T>: Sized { 10 | fn into_langchain(self) -> T; 11 | } 12 | 13 | pub trait OpenAIFromLangchain<T>: Sized { 14 | fn from_langchain(langchain: T) -> Self; 15 | } 16 | 17 | impl<T, U> LangchainIntoOpenAI<U> for T 18 | where 19 | U: OpenAIFromLangchain<T>, 20 | { 21 | fn into_openai(self) -> U { 22 | U::from_langchain(self) 23 | } 24 | } 25 | 26 | impl<T, U> OpenAiIntoLangchain<U> for T 27 | where 28 | U: LangchainFromOpenAI<T>, 29 | { 30 | fn into_langchain(self) -> U { 31 | U::from_openai(self) 32 | } 33 | } 34 | 35 | // Try into and from OpenAI 36 | 37 | pub trait TryLangchainIntoOpenAI<T>: Sized { 38 | type Error; 39 | 40 | fn try_into_openai(self) -> Result<T, Self::Error>; 41 | } 42 | 43 | pub trait TryLangchainFromOpenAI<T>: Sized { 44 | type Error; 45 | 46 | fn try_from_openai(openai: T) -> Result<Self, Self::Error>; 47 | } 48 | 49 | pub trait TryOpenAiIntoLangchain<T>: Sized { 50 | type Error; 51 | 52 | fn try_into_langchain(self) -> Result<T, Self::Error>; 53 | } 54 | 55 | pub trait TryOpenAiFromLangchain<T>: Sized { 56 | type Error; 57 | 58 | fn try_from_langchain(langchain: T) -> Result<Self, Self::Error>; 59 | } 60 | 61 | impl<T, U> TryLangchainIntoOpenAI<U> for T 62 | where 63 | U: TryOpenAiFromLangchain<T>, 64 | { 65 | type Error = U::Error; 66 | 67 | fn try_into_openai(self) -> Result<U, Self::Error> { 68 | U::try_from_langchain(self) 69 | } 70 | } 71 | 72 | impl<T, U> TryOpenAiIntoLangchain<U> for T 73 | where 74 | U: TryLangchainFromOpenAI<T>, 75 | { 76 | type Error = U::Error; 77 | 78 | fn try_into_langchain(self) -> Result<U, Self::Error> { 79 | U::try_from_openai(self) 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /examples/open_ai_tools_agent.rs: -------------------------------------------------------------------------------- 1 | use std::{error::Error, sync::Arc}; 2 | 3 | use async_trait::async_trait; 4 | use langchain_rust::{ 5 | agent::{AgentExecutor, OpenAiToolAgentBuilder}, 6 | chain::{options::ChainCallOptions, Chain}, 7 | llm::openai::OpenAI, 8 | memory::SimpleMemory, 9 | prompt_args, 10 | tools::{CommandExecutor, DuckDuckGoSearchResults, SerpApi, Tool}, 11 | }; 12 | 13 | use serde_json::Value; struct Date {} 14 | 15 | #[async_trait] 16 | impl Tool for Date { 17 | fn name(&self) -> String { 18 | "Date".to_string() 19 | } 20 | fn description(&self) -> String { 21 | "Useful when you need to get the date, input is a query".to_string() 22 | } 23 | async fn run(&self, _input: Value) -> Result<String, Box<dyn Error>> { 24 | Ok("25 of November of 2025".to_string()) 25 | } 26 | } 27 | 28 | #[tokio::main] 29 | async fn main() { 30 | let llm = OpenAI::default(); 31 | let memory = SimpleMemory::new(); 32 | let serpapi_tool = SerpApi::default(); 33 | let duckduckgo_tool = DuckDuckGoSearchResults::default(); 34 | let tool_calc = Date {}; 35 | let command_executor = CommandExecutor::default(); 36 | let agent = OpenAiToolAgentBuilder::new() 37 | .tools(&[ 38 | Arc::new(serpapi_tool), 39 | Arc::new(tool_calc), 40 | Arc::new(command_executor), 41 | Arc::new(duckduckgo_tool), 42 | ]) 43 | .options(ChainCallOptions::new().with_max_tokens(1000)) 44 | .build(llm) 45 | .unwrap(); 46 | 47 | 48 | let executor =
AgentExecutor::from_agent(agent).with_memory(memory.into()); 49 | 50 | let input_variables = prompt_args! { 51 | "input" => "What's the name of the current dir, and what date is today?", 52 | }; 53 | 54 | match executor.invoke(input_variables).await { 55 | Ok(result) => { 56 | println!("Result: {:?}", result.replace("\n", " ")); 57 | } 58 | Err(e) => panic!("Error invoking LLMChain: {:?}", e), 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /examples/llm_chain_deepseek.rs: -------------------------------------------------------------------------------- 1 | use langchain_rust::{ 2 | chain::{chain_trait::Chain, llm_chain::LLMChainBuilder}, 3 | language_models::options::CallOptions, 4 | llm::{Deepseek, DeepseekModel}, 5 | prompt::{PromptTemplate, TemplateFormat}, 6 | prompt_args, 7 | }; 8 | use std::env; 9 | 10 | #[tokio::main] 11 | async fn main() { 12 | // Get API key from environment variable 13 | let api_key = 14 | env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY environment variable must be set"); 15 | 16 | // Setup the Deepseek client with desired model and parameters 17 | let deepseek = Deepseek::new() 18 | .with_api_key(api_key) 19 | .with_model(DeepseekModel::DeepseekChat.to_string()) 20 | .with_options( 21 | CallOptions::default() 22 | .with_max_tokens(800) 23 | .with_temperature(1.3), // Using recommended temperature for general conversation 24 | ); 25 | 26 | // Create a prompt template 27 | let template = r#" 28 | You are a helpful assistant that provides detailed information. 29 | 30 | User question: {question} 31 | 32 | Please provide a comprehensive answer: 33 | "#; 34 | 35 | let prompt = PromptTemplate::new( 36 | template.to_owned(), 37 | vec!["question".to_owned()], 38 | TemplateFormat::FString, 39 | ); 40 | 41 | // Create an LLMChain using the builder pattern 42 | let chain = LLMChainBuilder::new() 43 | .prompt(prompt) 44 | .llm(deepseek) 45 | .build() 46 | .unwrap(); 47 | 48 | // Execute the chain with a question 49 | let inputs = prompt_args! { 50 | "question" => "Explain the importance of quantum computing and its potential applications." 51 | }; 52 | 53 | let result = chain.call(inputs).await.unwrap(); 54 | 55 | println!( 56 | "Question: Explain the importance of quantum computing and its potential applications."
57 | ); 58 | println!("\nDeepseek's response:"); 59 | println!("{}", result.generation); 60 | } 61 | -------------------------------------------------------------------------------- /src/output_parsers/markdown_parser.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | use regex::Regex; 3 | 4 | use super::{OutputParser, OutputParserError}; 5 | 6 | pub struct MarkdownParser { 7 | expresion: String, 8 | trim: bool, 9 | } 10 | impl MarkdownParser { 11 | pub fn new() -> Self { 12 | Self { 13 | expresion: r"```(?:\w+)?\s*([\s\S]+?)\s*```".to_string(), 14 | trim: false, 15 | } 16 | } 17 | 18 | pub fn with_custom_expresion(mut self, expresion: &str) -> Self { 19 | self.expresion = expresion.to_string(); 20 | self 21 | } 22 | 23 | pub fn with_trim(mut self, trim: bool) -> Self { 24 | self.trim = trim; 25 | self 26 | } 27 | } 28 | impl Default for MarkdownParser { 29 | fn default() -> Self { 30 | Self::new() 31 | } 32 | } 33 | 34 | #[async_trait] 35 | impl OutputParser for MarkdownParser { 36 | async fn parse(&self, output: &str) -> Result<String, OutputParserError> { 37 | // Compile the configured expression so with_custom_expresion takes effect 38 | let re = Regex::new(&self.expresion)?; 39 | if let Some(cap) = re.captures(output) { 40 | let find = cap[1].to_string(); 41 | if self.trim { 42 | Ok(find.trim().to_string()) 43 | } else { 44 | Ok(find) 45 | } 46 | } else { 47 | Err(OutputParserError::ParsingError( 48 | "No code block found".into(), 49 | )) 50 | } 51 | } 52 | } 53 | 54 | #[cfg(test)] 55 | mod tests { 56 | use super::*; 57 | 58 | #[tokio::test] 59 | async fn test_markdown_parser_finds_code_block() { 60 | let parser = MarkdownParser::new(); 61 | let markdown_content = r#" 62 | ```rust 63 | fn main() { 64 | println!("Hello, world!"); 65 | } 66 | ``` 67 | "#; 68 | let result = parser.parse(markdown_content).await; 69 | println!("{:?}", result); 70 | 71 | let correct = r#"fn main() { 72 | println!("Hello, world!"); 73 | }"#; 74 | assert!(result.is_ok()); 75 | assert_eq!(result.unwrap(), correct); 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /src/agent/open_ai_tools/builder.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use crate::{ 4 | agent::AgentError, 5 | chain::{options::ChainCallOptions, LLMChainBuilder}, 6 | language_models::{llm::LLM, options::CallOptions}, 7 | schemas::FunctionDefinition, 8 | tools::Tool, 9 | }; 10 | 11 | use super::{prompt::PREFIX, OpenAiToolAgent}; 12 | 13 | pub struct OpenAiToolAgentBuilder { 14 | tools: Option<Vec<Arc<dyn Tool>>>, 15 | prefix: Option<String>, 16 | options: Option<ChainCallOptions>, 17 | } 18 | 19 | impl OpenAiToolAgentBuilder { 20 | pub fn new() -> Self { 21 | Self { 22 | tools: None, 23 | prefix: None, 24 | options: None, 25 | } 26 | } 27 | 28 | pub fn tools(mut self, tools: &[Arc<dyn Tool>]) -> Self { 29 | self.tools = Some(tools.to_vec()); 30 | self 31 | } 32 | 33 | pub fn prefix<S: Into<String>>(mut self, prefix: S) -> Self { 34 | self.prefix = Some(prefix.into()); 35 | self 36 | } 37 | 38 | pub fn options(mut self, options: ChainCallOptions) -> Self { 39 | self.options = Some(options); 40 | self 41 | } 42 | 43 | pub fn build<L: LLM + 'static>(self, llm: L) -> Result<OpenAiToolAgent, AgentError> { 44 | let tools = self.tools.unwrap_or_default(); 45 | let prefix = self.prefix.unwrap_or_else(|| PREFIX.to_string()); 46 | let mut llm = llm; 47 | 48 | let prompt = OpenAiToolAgent::create_prompt(&prefix)?; 49 | let default_options = ChainCallOptions::default().with_max_tokens(1000); 50 | let functions = tools 51 | .iter() 52 |
.map(FunctionDefinition::from_langchain_tool) 53 | .collect::<Vec<FunctionDefinition>>(); 54 | llm.add_options(CallOptions::new().with_functions(functions)); 55 | let chain = Box::new( 56 | LLMChainBuilder::new() 57 | .prompt(prompt) 58 | .llm(llm) 59 | .options(self.options.unwrap_or(default_options)) 60 | .build()?, 61 | ); 62 | 63 | Ok(OpenAiToolAgent { chain, tools }) 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /src/schemas/document.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use serde::{Deserialize, Serialize}; 4 | use serde_json::Value; 5 | 6 | /// The `Document` struct represents a document with content, metadata, and a score. 7 | /// The `page_content` field is a string that contains the content of the document. 8 | /// The `metadata` field is a `HashMap` where the keys represent metadata properties and the values represent property values. 9 | /// The `score` field represents a relevance score for the document and is a floating point number. 10 | /// 11 | /// # Usage 12 | /// ```rust,ignore 13 | /// let my_doc = Document::new("This is the document content.".to_string()) 14 | ///     .with_metadata({ 15 | ///         let mut metadata = HashMap::new(); 16 | ///         metadata.insert("author".to_string(), json!("John Doe")); 17 | ///         metadata 18 | ///     }) 19 | ///     .with_score(0.75); 20 | /// ``` 21 | #[derive(Debug, Clone, Serialize, Deserialize)] 22 | pub struct Document { 23 | pub page_content: String, 24 | pub metadata: HashMap<String, Value>, 25 | pub score: f64, 26 | } 27 | 28 | impl Document { 29 | /// Constructs a new `Document` with provided `page_content`, an empty `metadata` map and a `score` of 0. 30 | pub fn new<S: Into<String>>(page_content: S) -> Self { 31 | Document { 32 | page_content: page_content.into(), 33 | metadata: HashMap::new(), 34 | score: 0.0, 35 | } 36 | } 37 | 38 | /// Sets the `metadata` Map of the `Document` to the provided HashMap. 39 | pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self { 40 | self.metadata = metadata; 41 | self 42 | } 43 | 44 | /// Sets the `score` of the `Document` to the provided float. 45 | pub fn with_score(mut self, score: f64) -> Self { 46 | self.score = score; 47 | self 48 | } 49 | } 50 | 51 | impl Default for Document { 52 | /// Provides a default `Document` with an empty `page_content`, an empty `metadata` map and a `score` of 0.
53 | fn default() -> Self { 54 | Document { 55 | page_content: "".to_string(), 56 | metadata: HashMap::new(), 57 | score: 0.0, 58 | } 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /src/agent/chat/builder.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use crate::{ 4 | agent::AgentError, 5 | chain::{llm_chain::LLMChainBuilder, options::ChainCallOptions}, 6 | language_models::llm::LLM, 7 | tools::Tool, 8 | }; 9 | 10 | use super::{ 11 | output_parser::ChatOutputParser, 12 | prompt::{PREFIX, SUFFIX}, 13 | ConversationalAgent, 14 | }; 15 | 16 | pub struct ConversationalAgentBuilder { 17 | tools: Option<Vec<Arc<dyn Tool>>>, 18 | prefix: Option<String>, 19 | suffix: Option<String>, 20 | options: Option<ChainCallOptions>, 21 | } 22 | 23 | impl ConversationalAgentBuilder { 24 | pub fn new() -> Self { 25 | Self { 26 | tools: None, 27 | prefix: None, 28 | suffix: None, 29 | options: None, 30 | } 31 | } 32 | 33 | pub fn tools(mut self, tools: &[Arc<dyn Tool>]) -> Self { 34 | self.tools = Some(tools.to_vec()); 35 | self 36 | } 37 | 38 | pub fn prefix<S: Into<String>>(mut self, prefix: S) -> Self { 39 | self.prefix = Some(prefix.into()); 40 | self 41 | } 42 | 43 | pub fn suffix<S: Into<String>>(mut self, suffix: S) -> Self { 44 | self.suffix = Some(suffix.into()); 45 | self 46 | } 47 | 48 | pub fn options(mut self, options: ChainCallOptions) -> Self { 49 | self.options = Some(options); 50 | self 51 | } 52 | 53 | pub fn build<L: Into<Box<dyn LLM>>>(self, llm: L) -> Result<ConversationalAgent, AgentError> { 54 | let tools = self.tools.unwrap_or_default(); 55 | let prefix = self.prefix.unwrap_or_else(|| PREFIX.to_string()); 56 | let suffix = self.suffix.unwrap_or_else(|| SUFFIX.to_string()); 57 | 58 | let prompt = ConversationalAgent::create_prompt(&tools, &suffix, &prefix)?; 59 | let default_options = ChainCallOptions::default().with_max_tokens(1000); 60 | let chain = Box::new( 61 | LLMChainBuilder::new() 62 | .prompt(prompt) 63 | .llm(llm) 64 | .options(self.options.unwrap_or(default_options)) 65 | .build()?, 66 | ); 67 | 68 | Ok(ConversationalAgent { 69 | chain, 70 | tools, 71 | output_parser: ChatOutputParser::new(), 72 | }) 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /src/llm/qwen/error.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | #[derive(Error, Debug)] 4 | pub enum QwenError { 5 | #[error("Qwen API error: Invalid parameter - {0}")] 6 | InvalidParameterError(String), 7 | 8 | #[error("Qwen API error: Invalid API Key - {0}")] 9 | InvalidApiKeyError(String), 10 | 11 | #[error("Qwen API error: Network error - {0}")] 12 | NetworkError(String), 13 | 14 | #[error("Qwen API error: Model Unavailable - {0}")] 15 | ModelUnavailableError(String), 16 | 17 | #[error("Qwen API error: Rate limit exceeded - {0}")] 18 | ModelServingError(String), 19 | 20 | #[error("Qwen API error: Internal error - {0}")] 21 | InternalError(String), 22 | 23 | #[error("Qwen API error: System error - {0}")] 24 | SystemError(String), 25 | 26 | #[error("Qwen API error: Billing issue - {0}")] 27 | BillingError(String), 28 | 29 | #[error("Qwen API error: Mismatched model - {0}")] 30 | MismatchedModelError(String), 31 | 32 | #[error("Qwen API error: Duplicate custom ID - {0}")] 33 | DuplicateCustomIdError(String), 34 | 35 | #[error("Qwen API error: Model not found - {0}")] 36 | ModelNotFoundError(String), 37 | 38 | #[error("Qwen API error: Connection error - {0}")] 39 | APIConnectionError(String), 40 | 41 | #[error("Qwen API error: Prepaid
bill overdue - {0}")] 42 | PrepaidBillOverdueError(String), 43 | 44 | #[error("Qwen API error: Postpaid bill overdue - {0}")] 45 | PostpaidBillOverdueError(String), 46 | 47 | #[error("Qwen API error: Commodity not purchased - {0}")] 48 | CommodityNotPurchasedError(String), 49 | 50 | #[error("Qwen API error: Internal algorithm error - {0}")] 51 | InternalAlgorithmError(String), 52 | 53 | #[error("Qwen API error: Timeout - {0}")] 54 | TimeoutError(String), 55 | 56 | #[error("Qwen API error: Rewrite failed - {0}")] 57 | RewriteFailedError(String), 58 | 59 | #[error("Qwen API error: Retrieval failed - {0}")] 60 | RetrievalFailedError(String), 61 | 62 | #[error("Qwen API error: Application process failed - {0}")] 63 | AppProcessFailedError(String), 64 | 65 | #[error("Qwen API error: Model service failed - {0}")] 66 | ModelServiceFailedError(String), 67 | 68 | #[error("Qwen API error: Plugin invocation failed - {0}")] 69 | InvokePluginFailedError(String), 70 | } 71 | -------------------------------------------------------------------------------- /src/embedding/fastembed/fastembed.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | 3 | use crate::embedding::{Embedder, EmbedderError}; 4 | use fastembed::TextEmbedding; 5 | 6 | pub struct FastEmbed { 7 | model: TextEmbedding, 8 | batch_size: Option<usize>, 9 | } 10 | 11 | impl FastEmbed { 12 | pub fn try_new() -> Result<Self, EmbedderError> { 13 | Ok(Self { 14 | model: TextEmbedding::try_new(Default::default()) 15 | .map_err(|e| EmbedderError::FastEmbedError(e.to_string()))?, 16 | batch_size: None, 17 | }) 18 | } 19 | 20 | pub fn with_batch_size(mut self, batch_size: usize) -> Self { 21 | self.batch_size = Some(batch_size); 22 | self 23 | } 24 | } 25 | 26 | impl From<TextEmbedding> for FastEmbed { 27 | fn from(model: TextEmbedding) -> Self { 28 | Self { 29 | model, 30 | batch_size: None, 31 | } 32 | } 33 | } 34 | 35 | #[async_trait] 36 | impl Embedder for FastEmbed { 37 | async fn embed_documents(&self, documents: &[String]) -> Result<Vec<Vec<f64>>, EmbedderError> { 38 | let embeddings = self 39 | .model 40 | .embed(documents.to_vec(), self.batch_size) 41 | .map_err(|e| EmbedderError::FastEmbedError(e.to_string()))?; 42 | 43 | Ok(embeddings 44 | .into_iter() 45 | .map(|inner_vec| { 46 | inner_vec 47 | .into_iter() 48 | .map(|x| x as f64) 49 | .collect::<Vec<f64>>() 50 | }) 51 | .collect::<Vec<Vec<f64>>>()) 52 | } 53 | 54 | async fn embed_query(&self, text: &str) -> Result<Vec<f64>, EmbedderError> { 55 | let embedding = self 56 | .model 57 | .embed(vec![text], self.batch_size) 58 | .map_err(|e| EmbedderError::FastEmbedError(e.to_string()))?; 59 | 60 | Ok(embedding[0].iter().map(|x| *x as f64).collect()) 61 | } 62 | } 63 | 64 | #[cfg(test)] 65 | mod tests { 66 | use super::*; 67 | #[tokio::test] 68 | async fn test_fastembed() { 69 | let fastembed = FastEmbed::try_new().unwrap(); 70 | let embeddings = fastembed 71 | .embed_documents(&["hello world".to_string(), "foo bar".to_string()]) 72 | .await 73 | .unwrap(); 74 | assert_eq!(embeddings.len(), 2); 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /src/language_models/mod.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use serde::{Deserialize, Serialize}; 4 | 5 | pub mod llm; 6 | pub mod options; 7 | 8 | mod error; 9 | pub use error::*; 10 | 11 | //TODO: check if this should have a `data: serde_json::Value` field to save all other things, like OpenAI 12 | //function responses 13 | #[derive(Debug,
Serialize, Deserialize, Clone, Default)] 14 | pub struct GenerateResult { 15 | pub tokens: Option<TokenUsage>, 16 | pub generation: String, 17 | } 18 | 19 | impl GenerateResult { 20 | pub fn to_hashmap(&self) -> HashMap<String, String> { 21 | let mut map = HashMap::new(); 22 | 23 | // Insert the 'generation' field into the hashmap 24 | map.insert("generation".to_string(), self.generation.clone()); 25 | 26 | // Check if 'tokens' is Some and insert its fields into the hashmap 27 | if let Some(ref tokens) = self.tokens { 28 | map.insert( 29 | "prompt_tokens".to_string(), 30 | tokens.prompt_tokens.to_string(), 31 | ); 32 | map.insert( 33 | "completion_tokens".to_string(), 34 | tokens.completion_tokens.to_string(), 35 | ); 36 | map.insert("total_tokens".to_string(), tokens.total_tokens.to_string()); 37 | } 38 | 39 | map 40 | } 41 | } 42 | 43 | #[derive(Debug, Serialize, Deserialize, Clone, Default)] 44 | pub struct TokenUsage { 45 | pub prompt_tokens: u32, 46 | pub completion_tokens: u32, 47 | pub total_tokens: u32, 48 | } 49 | 50 | impl TokenUsage { 51 | pub fn sum(&self, other: &TokenUsage) -> TokenUsage { 52 | TokenUsage { 53 | prompt_tokens: self.prompt_tokens + other.prompt_tokens, 54 | completion_tokens: self.completion_tokens + other.completion_tokens, 55 | total_tokens: self.total_tokens + other.total_tokens, 56 | } 57 | } 58 | 59 | pub fn add(&mut self, other: &TokenUsage) { 60 | self.prompt_tokens += other.prompt_tokens; 61 | self.completion_tokens += other.completion_tokens; 62 | self.total_tokens += other.total_tokens; 63 | } 64 | } 65 | 66 | impl TokenUsage { 67 | pub fn new(prompt_tokens: u32, completion_tokens: u32) -> Self { 68 | Self { 69 | prompt_tokens, 70 | completion_tokens, 71 | total_tokens: prompt_tokens + completion_tokens, 72 | } 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /examples/llm_qwen_advanced.rs: -------------------------------------------------------------------------------- 1 | use langchain_rust::{ 2 | language_models::{llm::LLM, options::CallOptions}, 3 | llm::{Qwen, QwenModel}, 4 | schemas::Message, 5 | }; 6 | use std::{env, io::Write}; 7 | 8 | #[tokio::main] 9 | async fn main() { 10 | // Get API key from environment variable 11 | let api_key = env::var("QWEN_API_KEY").expect("QWEN_API_KEY environment variable must be set"); 12 | // Example 1: Basic generation with options 13 | println!("=== Example 1: Basic Generation with Options ==="); 14 | let qwen = Qwen::new() 15 | .with_api_key(api_key.clone()) 16 | .with_model(QwenModel::QwenTurbo.to_string()) 17 | .with_options( 18 | CallOptions::default() 19 | .with_max_tokens(500) 20 | .with_temperature(0.7) 21 | .with_top_p(0.9), 22 | ); 23 | 24 | // Create a system and user message 25 | let messages = vec![ 26 | Message::new_system_message("You are a helpful AI assistant who responds in Chinese."), 27 | Message::new_human_message( 28 | "What are the three most popular programming languages in 2023?", 29 | ), 30 | ]; 31 | 32 | let response = qwen.generate(&messages).await.unwrap(); 33 | println!("Response: {}", response.generation); 34 | println!("Tokens used: {:?}", response.tokens); 35 | println!("\n"); 36 | 37 | // Example 2: Streaming response 38 | println!("=== Example 2: Streaming Response ==="); 39 | 40 | // Create a streaming callback function 41 | let callback = |content: String| { 42 | print!("{}", content); 43 | let _ = std::io::stdout().flush(); 44 | async { Ok(()) } 45 | }; 46 | 47 | let streaming_options = CallOptions::default() 48 | .with_max_tokens(100) 49 | 
.with_streaming_func(callback); 50 | 51 | let streaming_qwen = Qwen::new() 52 | .with_api_key(api_key.clone()) 53 | .with_model(QwenModel::QwenPlus.to_string()) 54 | .with_options(streaming_options); 55 | 56 | let stream_messages = vec![Message::new_human_message( 57 | "Write a short poem about artificial intelligence.", 58 | )]; 59 | 60 | println!("Streaming response:"); 61 | let streaming_response = streaming_qwen.generate(&stream_messages).await.unwrap(); 62 | println!( 63 | "\n\nDone streaming. Total tokens: {:?}", 64 | streaming_response.tokens 65 | ); 66 | } 67 | -------------------------------------------------------------------------------- /src/llm/claude/models.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | use crate::schemas::{Message, MessageType}; 4 | 5 | #[derive(Serialize, Deserialize)] 6 | pub(crate) struct ClaudeMessage { 7 | pub role: String, 8 | pub content: String, 9 | } 10 | impl ClaudeMessage { 11 | pub fn new>(role: S, content: S) -> Self { 12 | Self { 13 | role: role.into(), 14 | content: content.into(), 15 | } 16 | } 17 | 18 | pub fn from_message(message: &Message) -> Self { 19 | match message.message_type { 20 | MessageType::SystemMessage => Self::new("system", &message.content), 21 | MessageType::AIMessage => Self::new("assistant", &message.content), 22 | MessageType::HumanMessage => Self::new("user", &message.content), 23 | MessageType::ToolMessage => Self::new("tool", &message.content), 24 | } 25 | } 26 | } 27 | 28 | #[derive(Serialize, Deserialize)] 29 | pub(crate) struct Payload { 30 | pub model: String, 31 | pub messages: Vec, 32 | pub max_tokens: u32, 33 | #[serde(skip_serializing_if = "Option::is_none")] 34 | pub system: Option, 35 | #[serde(skip_serializing_if = "Option::is_none")] 36 | pub stream: Option, 37 | #[serde(skip_serializing_if = "Option::is_none")] 38 | pub stop_sequences: Option>, 39 | #[serde(skip_serializing_if = "Option::is_none")] 40 | pub temperature: Option, 41 | #[serde(skip_serializing_if = "Option::is_none")] 42 | pub top_p: Option, 43 | #[serde(skip_serializing_if = "Option::is_none")] 44 | pub top_k: Option, 45 | } 46 | 47 | #[derive(Debug, Serialize, Deserialize, Clone)] 48 | pub(crate) struct ApiResponse { 49 | pub content: Vec, 50 | pub id: String, 51 | pub model: String, 52 | pub role: String, 53 | pub stop_reason: Option, 54 | pub stop_sequence: Option, // Adjust based on actual stop_sequence type 55 | #[serde(rename = "type")] 56 | pub message_type: String, 57 | pub usage: Usage, 58 | } 59 | 60 | #[derive(Debug, Serialize, Deserialize, Clone)] 61 | pub(crate) struct Content { 62 | pub text: String, 63 | #[serde(rename = "type")] 64 | pub content_type: String, 65 | } 66 | 67 | #[derive(Debug, Serialize, Deserialize, Clone)] 68 | pub(crate) struct Usage { 69 | pub input_tokens: u32, 70 | pub output_tokens: u32, 71 | } 72 | -------------------------------------------------------------------------------- /src/text_splitter/plain_text_splitter.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | 3 | use super::{TextSplitter, TextSplitterError}; 4 | 5 | // Options is a struct that contains options for a plain text splitter. 
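// `chunk_size` is the maximum length of a chunk, `chunk_overlap` is how much
// consecutive chunks share, and `trim_chunks` controls whether surrounding
// whitespace is trimmed from each chunk (see the builder methods below).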
6 | #[derive(Debug, Clone)] 7 | pub struct PlainTextSplitterOptions { 8 | pub chunk_size: usize, 9 | pub chunk_overlap: usize, 10 | pub trim_chunks: bool, 11 | } 12 | 13 | impl Default for PlainTextSplitterOptions { 14 | fn default() -> Self { 15 | Self::new() 16 | } 17 | } 18 | 19 | impl PlainTextSplitterOptions { 20 | pub fn new() -> Self { 21 | PlainTextSplitterOptions { 22 | chunk_size: 512, 23 | chunk_overlap: 0, 24 | trim_chunks: false, 25 | } 26 | } 27 | 28 | pub fn with_chunk_size(mut self, chunk_size: usize) -> Self { 29 | self.chunk_size = chunk_size; 30 | self 31 | } 32 | 33 | pub fn with_chunk_overlap(mut self, chunk_overlap: usize) -> Self { 34 | self.chunk_overlap = chunk_overlap; 35 | self 36 | } 37 | 38 | pub fn with_trim_chunks(mut self, trim_chunks: bool) -> Self { 39 | self.trim_chunks = trim_chunks; 40 | self 41 | } 42 | 43 | pub fn chunk_size(&self) -> usize { 44 | self.chunk_size 45 | } 46 | 47 | pub fn chunk_overlap(&self) -> usize { 48 | self.chunk_overlap 49 | } 50 | 51 | pub fn trim_chunks(&self) -> bool { 52 | self.trim_chunks 53 | } 54 | } 55 | 56 | pub struct PlainTextSplitter { 57 | splitter_options: PlainTextSplitterOptions, 58 | } 59 | 60 | impl Default for PlainTextSplitter { 61 | fn default() -> Self { 62 | PlainTextSplitter::new(PlainTextSplitterOptions::default()) 63 | } 64 | } 65 | 66 | impl PlainTextSplitter { 67 | pub fn new(options: PlainTextSplitterOptions) -> PlainTextSplitter { 68 | PlainTextSplitter { 69 | splitter_options: options, 70 | } 71 | } 72 | } 73 | 74 | #[async_trait] 75 | impl TextSplitter for PlainTextSplitter { 76 | async fn split_text(&self, text: &str) -> Result<Vec<String>, TextSplitterError> { 77 | let splitter = text_splitter::TextSplitter::new( 78 | text_splitter::ChunkConfig::new(self.splitter_options.chunk_size) 79 | .with_trim(self.splitter_options.trim_chunks) 80 | .with_overlap(self.splitter_options.chunk_overlap)?, 81 | ); 82 | 83 | Ok(splitter.chunks(text).map(|x| x.to_string()).collect()) 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /src/vectorstore/opensearch/builder.rs: -------------------------------------------------------------------------------- 1 | use crate::embedding::Embedder; 2 | use crate::vectorstore::opensearch::Store; 3 | use opensearch::OpenSearch; 4 | use std::error::Error; 5 | use std::sync::Arc; 6 | 7 | pub struct StoreBuilder { 8 | client: Option<OpenSearch>, 9 | embedder: Option<Arc<dyn Embedder>>, 10 | k: i32, 11 | index: Option<String>, 12 | vector_field: String, 13 | content_field: String, 14 | } 15 | 16 | impl StoreBuilder { 17 | // Returns a new StoreBuilder instance with default values for each option 18 | pub fn new() -> Self { 19 | StoreBuilder { 20 | client: None, 21 | embedder: None, 22 | k: 2, 23 | index: None, 24 | vector_field: "vector_field".to_string(), 25 | content_field: "page_content".to_string(), 26 | } 27 | } 28 | 29 | pub fn client(mut self, client: OpenSearch) -> Self { 30 | self.client = Some(client); 31 | self 32 | } 33 | 34 | pub fn embedder<E: Embedder + 'static>(mut self, embedder: E) -> Self { 35 | self.embedder = Some(Arc::new(embedder)); 36 | self 37 | } 38 | 39 | pub fn k(mut self, k: i32) -> Self { 40 | self.k = k; 41 | self 42 | } 43 | 44 | pub fn index(mut self, index: &str) -> Self { 45 | self.index = Some(index.to_string()); 46 | self 47 | } 48 | 49 | pub fn vector_field(mut self, vector_field: &str) -> Self { 50 | self.vector_field = vector_field.to_string(); 51 | self 52 | } 53 | 54 | pub fn content_field(mut self, content_field: &str) -> Self { 55 | self.content_field = 
content_field.to_string(); 56 | self 57 | } 58 | 59 | // Finalize the builder and construct the Store object 60 | pub async fn build(self) -> Result<Store, Box<dyn Error>> { 61 | if self.client.is_none() { 62 | return Err("Client is required".into()); 63 | } 64 | 65 | if self.embedder.is_none() { 66 | return Err("Embedder is required".into()); 67 | } 68 | 69 | if self.index.is_none() { 70 | return Err("Index is required".into()); 71 | } 72 | 73 | Ok(Store { 74 | client: self.client.unwrap(), 75 | embedder: self.embedder.unwrap(), 76 | k: self.k, 77 | index: self.index.unwrap(), 78 | vector_field: self.vector_field, 79 | content_field: self.content_field, 80 | }) 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /src/embedding/mistralai/mistralai_embedder.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use crate::embedding::{embedder_trait::Embedder, EmbedderError}; 4 | use async_trait::async_trait; 5 | use mistralai_client::v1::{client::Client, constants::EmbedModel}; 6 | 7 | pub struct MistralAIEmbedder { 8 | client: Arc<Client>, 9 | model: EmbedModel, 10 | } 11 | 12 | impl MistralAIEmbedder { 13 | pub fn try_new() -> Result<Self, EmbedderError> { 14 | Ok(Self { 15 | client: Arc::new( 16 | Client::new(None, None, None, None).map_err(EmbedderError::MistralAIClientError)?, 17 | ), 18 | model: EmbedModel::MistralEmbed, 19 | }) 20 | } 21 | } 22 | 23 | #[async_trait] 24 | impl Embedder for MistralAIEmbedder { 25 | async fn embed_documents(&self, documents: &[String]) -> Result<Vec<Vec<f64>>, EmbedderError> { 26 | log::debug!("Embedding documents: {:?}", documents); 27 | 28 | let response = self 29 | .client 30 | .embeddings_async(self.model.clone(), documents.into(), None) 31 | .await 32 | .map_err(EmbedderError::MistralAIApiError)?; 33 | 34 | Ok(response 35 | .data 36 | .into_iter() 37 | .map(|item| item.embedding.into_iter().map(|x| x as f64).collect()) 38 | .collect::<Vec<Vec<f64>>>()) 39 | } 40 | 41 | async fn embed_query(&self, text: &str) -> Result<Vec<f64>, EmbedderError> { 42 | log::debug!("Embedding query: {:?}", text); 43 | 44 | let response = self 45 | .client 46 | .embeddings_async(self.model.clone(), vec![text.to_string()], None) 47 | .await 48 | .map_err(EmbedderError::MistralAIApiError)?; 49 | 50 | Ok(response.data[0] 51 | .embedding 52 | .iter() 53 | .map(|x| *x as f64) 54 | .collect()) 55 | } 56 | } 57 | 58 | #[cfg(test)] 59 | mod tests { 60 | use super::*; 61 | 62 | #[tokio::test] 63 | #[ignore] 64 | async fn test_mistralai_embed_query() { 65 | let mistralai = MistralAIEmbedder::try_new().unwrap(); 66 | let embeddings = mistralai.embed_query("Why is the sky blue?").await.unwrap(); 67 | assert_eq!(embeddings.len(), 1024); 68 | } 69 | 70 | #[tokio::test] 71 | #[ignore] 72 | async fn test_mistralai_embed_documents() { 73 | let mistralai = MistralAIEmbedder::try_new().unwrap(); 74 | let embeddings = mistralai 75 | .embed_documents(&["hello world".to_string(), "foo bar".to_string()]) 76 | .await 77 | .unwrap(); 78 | assert_eq!(embeddings.len(), 2); 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /src/vectorstore/vectorstore.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | 3 | use async_trait::async_trait; 4 | 5 | use crate::schemas::{self, Document}; 6 | 7 | use super::VecStoreOptions; 8 | 9 | // VectorStore is the trait for saving and querying documents in the 10 | // form of vector embeddings. 
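// `add_documents` returns one `String` per stored document (typically its ID),
// and `similarity_search` returns the `limit` most similar documents; the
// associated `Options` type carries store-specific settings such as filters
// or score thresholds.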
11 | #[async_trait] 12 | pub trait VectorStore: Send + Sync { 13 | type Options; 14 | 15 | async fn add_documents( 16 | &self, 17 | docs: &[Document], 18 | opt: &Self::Options, 19 | ) -> Result<Vec<String>, Box<dyn Error>>; 20 | 21 | async fn similarity_search( 22 | &self, 23 | query: &str, 24 | limit: usize, 25 | opt: &Self::Options, 26 | ) -> Result<Vec<Document>, Box<dyn Error>>; 27 | } 28 | 29 | impl<VS> From<VS> for Box<dyn VectorStore<Options = VS::Options>> 30 | where 31 | VS: 'static + VectorStore, 32 | { 33 | fn from(vector_store: VS) -> Self { 34 | Box::new(vector_store) 35 | } 36 | } 37 | 38 | #[macro_export] 39 | macro_rules! add_documents { 40 | ($obj:expr, $docs:expr) => { 41 | $obj.add_documents($docs, &$crate::vectorstore::VecStoreOptions::default()) 42 | }; 43 | ($obj:expr, $docs:expr, $opt:expr) => { 44 | $obj.add_documents($docs, $opt) 45 | }; 46 | } 47 | 48 | #[macro_export] 49 | macro_rules! similarity_search { 50 | ($obj:expr, $query:expr, $limit:expr) => { 51 | $obj.similarity_search( 52 | $query, 53 | $limit, 54 | &$crate::vectorstore::VecStoreOptions::default(), 55 | ) 56 | }; 57 | ($obj:expr, $query:expr, $limit:expr, $opt:expr) => { 58 | $obj.similarity_search($query, $limit, $opt) 59 | }; 60 | } 61 | 62 | // Retriever is a retriever for vector stores. 63 | pub struct Retriever<F> { 64 | vstore: Box<dyn VectorStore<Options = VecStoreOptions<F>>>, 65 | num_docs: usize, 66 | options: VecStoreOptions<F>, 67 | } 68 | impl<F> Retriever<F> { 69 | pub fn new<V: Into<Box<dyn VectorStore<Options = VecStoreOptions<F>>>>>( 70 | vstore: V, 71 | num_docs: usize, 72 | ) -> Self { 73 | Retriever { 74 | vstore: vstore.into(), 75 | num_docs, 76 | options: VecStoreOptions::<F>::new(), 77 | } 78 | } 79 | 80 | pub fn with_options(mut self, options: VecStoreOptions<F>) -> Self { 81 | self.options = options; 82 | self 83 | } 84 | } 85 | 86 | #[async_trait] 87 | impl<F: Send + Sync> schemas::Retriever for Retriever<F> { 88 | async fn get_relevant_documents(&self, query: &str) -> Result<Vec<Document>, Box<dyn Error>> { 89 | self.vstore 90 | .similarity_search(query, self.num_docs, &self.options) 91 | .await 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /examples/vector_store_qdrant.rs: -------------------------------------------------------------------------------- 1 | // To run this example execute: cargo run --example vector_store_qdrant --features qdrant 2 | 3 | #[cfg(feature = "qdrant")] 4 | use langchain_rust::{ 5 | embedding::openai::openai_embedder::OpenAiEmbedder, 6 | schemas::Document, 7 | vectorstore::qdrant::{Qdrant, StoreBuilder}, 8 | vectorstore::VectorStore, 9 | }; 10 | #[cfg(feature = "qdrant")] 11 | use std::io::Write; 12 | 13 | #[cfg(feature = "qdrant")] 14 | #[tokio::main] 15 | async fn main() { 16 | // Initialize Embedder 17 | 18 | use langchain_rust::vectorstore::VecStoreOptions; 19 | 20 | // Requires OpenAI API key to be set in the environment variable OPENAI_API_KEY 21 | let embedder = OpenAiEmbedder::default(); 22 | 23 | // Initialize the qdrant_client::Qdrant 24 | // Ensure Qdrant is running at localhost, with gRPC port at 6334 25 | // docker run -p 6334:6334 qdrant/qdrant 26 | let client = Qdrant::from_url("http://localhost:6334").build().unwrap(); 27 | 28 | let store = StoreBuilder::new() 29 | .embedder(embedder) 30 | .client(client) 31 | .collection_name("langchain-rs") 32 | .build() 33 | .await 34 | .unwrap(); 35 | 36 | // Add documents to the database 37 | let doc1 = Document::new( 38 | "langchain-rust is a port of the langchain python library to rust and was written in 2024.", 39 | ); 40 | let doc2 = Document::new( 41 | "langchaingo is a port of the langchain python library to go language and was written in 2023."
42 | ); 43 | let doc3 = Document::new( 44 | "Capital of United States of America (USA) is Washington D.C. and the capital of France is Paris." 45 | ); 46 | let doc4 = Document::new("Capital of France is Paris."); 47 | 48 | store 49 | .add_documents(&vec![doc1, doc2, doc3, doc4], &VecStoreOptions::default()) 50 | .await 51 | .unwrap(); 52 | 53 | // Ask for user input 54 | print!("Query> "); 55 | std::io::stdout().flush().unwrap(); 56 | let mut query = String::new(); 57 | std::io::stdin().read_line(&mut query).unwrap(); 58 | 59 | let results = store 60 | .similarity_search(&query, 2, &VecStoreOptions::default()) 61 | .await 62 | .unwrap(); 63 | 64 | if results.is_empty() { 65 | println!("No results found."); 66 | return; 67 | } else { 68 | results.iter().for_each(|r| { 69 | println!("Document: {}", r.page_content); 70 | }); 71 | } 72 | } 73 | 74 | #[cfg(not(feature = "qdrant"))] 75 | fn main() { 76 | println!("This example requires the 'qdrant' feature to be enabled."); 77 | println!("Please run the command as follows:"); 78 | println!("cargo run --example vector_store_qdrant --features qdrant"); 79 | } 80 | -------------------------------------------------------------------------------- /examples/conversational_chain.rs: -------------------------------------------------------------------------------- 1 | use std::io::{stdout, Write}; 2 | 3 | use futures_util::StreamExt; 4 | use langchain_rust::{ 5 | chain::{builder::ConversationalChainBuilder, Chain}, 6 | // fmt_message, fmt_template, 7 | llm::openai::{OpenAI, OpenAIModel}, 8 | memory::SimpleMemory, 9 | // message_formatter, 10 | // prompt::HumanMessagePromptTemplate, 11 | prompt_args, 12 | // schemas::Message, 13 | // template_fstring, 14 | }; 15 | 16 | #[tokio::main] 17 | async fn main() { 18 | let llm = OpenAI::default().with_model(OpenAIModel::Gpt35); 19 | //We initialise a simple memory. By default conversational chain have this memory, but we 20 | //initialise it as an example, if you dont want to have memory use DummyMemory 21 | let memory = SimpleMemory::new(); 22 | 23 | let chain = ConversationalChainBuilder::new() 24 | .llm(llm) 25 | //IF YOU WANT TO ADD A CUSTOM PROMPT YOU CAN UN COMMENT THIS: 26 | // .prompt(message_formatter![ 27 | // fmt_message!(Message::new_system_message("You are a helpful assistant")), 28 | // fmt_template!(HumanMessagePromptTemplate::new( 29 | // template_fstring!(" 30 | // The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know. 31 | // 32 | // Current conversation: 33 | // {history} 34 | // Human: {input} 35 | // AI: 36 | // ", 37 | // "input","history"))) 38 | // 39 | // ]) 40 | .memory(memory.into()) 41 | .build() 42 | .expect("Error building ConversationalChain"); 43 | 44 | let input_variables = prompt_args! { 45 | "input" => "Im from Peru", 46 | }; 47 | 48 | let mut stream = chain.stream(input_variables).await.unwrap(); 49 | while let Some(result) = stream.next().await { 50 | match result { 51 | Ok(data) => { 52 | //If you just want to print to stdout, you can use data.to_stdout().unwrap(); 53 | print!("{}", data.content); 54 | stdout().flush().unwrap(); 55 | } 56 | Err(e) => { 57 | println!("Error: {:?}", e); 58 | } 59 | } 60 | } 61 | 62 | let input_variables = prompt_args! 
{ 63 | "input" => "Which are the typical dish", 64 | }; 65 | match chain.invoke(input_variables).await { 66 | Ok(result) => { 67 | println!("\n"); 68 | println!("Result: {:?}", result); 69 | } 70 | Err(e) => panic!("Error invoking LLMChain: {:?}", e), 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /src/chain/stuff_documents/builder.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | chain::{options::ChainCallOptions, ChainError, LLMChainBuilder}, 3 | language_models::llm::LLM, 4 | output_parsers::OutputParser, 5 | prompt::FormatPrompter, 6 | template_jinja2, 7 | }; 8 | 9 | use super::StuffDocument; 10 | 11 | pub struct StuffDocumentBuilder { 12 | llm: Option>, 13 | options: Option, 14 | output_key: Option, 15 | output_parser: Option>, 16 | prompt: Option>, 17 | } 18 | impl StuffDocumentBuilder { 19 | pub fn new() -> Self { 20 | Self { 21 | llm: None, 22 | options: None, 23 | output_key: None, 24 | output_parser: None, 25 | prompt: None, 26 | } 27 | } 28 | 29 | pub fn llm>>(mut self, llm: L) -> Self { 30 | self.llm = Some(llm.into()); 31 | self 32 | } 33 | 34 | pub fn options(mut self, options: ChainCallOptions) -> Self { 35 | self.options = Some(options); 36 | self 37 | } 38 | 39 | pub fn output_key>(mut self, output_key: S) -> Self { 40 | self.output_key = Some(output_key.into()); 41 | self 42 | } 43 | 44 | ///If you want to add a custom prompt,keep in mind which variables are obligatory. 45 | pub fn prompt>>(mut self, prompt: P) -> Self { 46 | self.prompt = Some(prompt.into()); 47 | self 48 | } 49 | 50 | pub fn build(self) -> Result { 51 | let llm = self 52 | .llm 53 | .ok_or_else(|| ChainError::MissingObject("LLM must be set".into()))?; 54 | let prompt = match self.prompt { 55 | Some(prompt) => prompt, 56 | None => Box::new(template_jinja2!( 57 | DEFAULT_STUFF_QA_TEMPLATE, 58 | "context", 59 | "question" 60 | )), 61 | }; 62 | 63 | let llm_chain = { 64 | let mut builder = LLMChainBuilder::new() 65 | .prompt(prompt) 66 | .options(self.options.unwrap_or_default()) 67 | .llm(llm); 68 | if let Some(output_parser) = self.output_parser { 69 | builder = builder.output_parser(output_parser); 70 | } 71 | 72 | builder.build()? 73 | }; 74 | 75 | Ok(StuffDocument::new(llm_chain)) 76 | } 77 | } 78 | 79 | const DEFAULT_STUFF_QA_TEMPLATE: &str = r#"Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. 
80 | 81 | {{context}} 82 | 83 | Question:{{question}} 84 | Helpful Answer: 85 | "#; 86 | -------------------------------------------------------------------------------- /examples/vector_store_postgres.rs: -------------------------------------------------------------------------------- 1 | // To run this example execute: cargo run --example vector_store_postgres --features postgres 2 | // To start pgvector docker run ./scripts/run-pgvector 3 | 4 | #[cfg(feature = "postgres")] 5 | use langchain_rust::{ 6 | add_documents, 7 | embedding::openai::openai_embedder::OpenAiEmbedder, 8 | schemas::Document, 9 | similarity_search, 10 | vectorstore::{pgvector::StoreBuilder, VectorStore}, 11 | }; 12 | #[cfg(feature = "postgres")] 13 | use std::io::Write; 14 | #[cfg(feature = "postgres")] 15 | use tokio::io::{self, AsyncBufReadExt, BufReader}; 16 | 17 | #[cfg(feature = "postgres")] 18 | #[tokio::main] 19 | async fn main() { 20 | // Initialize Embedder 21 | let embedder = OpenAiEmbedder::default(); 22 | 23 | // Initialize the Postgres Vector Store 24 | let store = StoreBuilder::new() 25 | .embedder(embedder) 26 | .pre_delete_collection(true) 27 | .connection_url("postgresql://username:password@localhost:5432/langchain-rust") 28 | .vector_dimensions(1536) 29 | .build() 30 | .await 31 | .unwrap(); 32 | 33 | // Get input with words list 34 | let mut input = String::new(); 35 | print!("Please enter a list separated by commas: "); 36 | std::io::stdout().flush().unwrap(); 37 | let mut reader = BufReader::new(io::stdin()); 38 | reader.read_line(&mut input).await.unwrap(); 39 | let input = input.trim_end(); 40 | let list: Vec<&str> = input.split(',').collect(); 41 | 42 | // Transform it to a list of documents 43 | let documents: Vec = list 44 | .iter() 45 | .map(|text| Document::new(text.trim().to_string())) 46 | .collect(); 47 | 48 | // Add documents to the database 49 | let _ = add_documents!(store, &documents).await.map_err(|e| { 50 | println!("Error adding documents: {:?}", e); 51 | }); 52 | 53 | // Get the input to search 54 | let mut search_input = String::new(); 55 | print!("Please enter the text you want to search: "); 56 | std::io::stdout().flush().unwrap(); 57 | 58 | reader.read_line(&mut search_input).await.unwrap(); 59 | let search_input = search_input.trim_end(); 60 | 61 | // Perform a similarity search in the database 62 | let data = similarity_search!(store, search_input, 10) 63 | .await 64 | .map_err(|e| { 65 | println!("Error searching documents: {:?}", e); 66 | }) 67 | .unwrap(); 68 | 69 | data.iter().for_each(|d| println!("{:?}", d.page_content)); 70 | } 71 | 72 | #[cfg(not(feature = "postgres"))] 73 | fn main() { 74 | println!("This example requires the 'postgres' feature to be enabled."); 75 | println!("Please run the command as follows:"); 76 | println!("cargo run --example vector_store_postgres --features postgres"); 77 | } 78 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | branches: 7 | - main 8 | push: 9 | branches: 10 | - main 11 | tags: 12 | - 'v0.[0-9]+.[0-9]+' 13 | - 'v0.[0-9]+.[0-9]+-beta.[0-9]+' 14 | - 'v0.[0-9]+.[0-9]+-alpha.[0-9]+' 15 | paths-ignore: 16 | - 'renovate.json' 17 | 18 | jobs: 19 | build: 20 | runs-on: ubuntu-latest 21 | env: 22 | # emit backtraces on panics. 
23 | RUST_BACKTRACE: 1 24 | steps: 25 | - uses: actions/checkout@v4 26 | with: 27 | fetch-depth: 1 28 | - name: Get the build metadata 29 | shell: bash 30 | run: | 31 | echo "TAG_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV 32 | echo "CARGO_VERSION=$(grep -m 1 '^version = ' Cargo.toml | cut -f 3 -d ' ' | tr -d \")" >> $GITHUB_ENV 33 | - name: Validate git tag and Cargo.toml version 34 | shell: bash 35 | if: startsWith(github.ref, 'refs/tags/') 36 | run: | 37 | if [ "${{ env.TAG_VERSION }}" != "v${{ env.CARGO_VERSION }}" ]; then 38 | echo "git tag version (${{ env.TAG_VERSION }}) does not match Cargo.toml version (v${{ env.CARGO_VERSION }})" 39 | exit 1 40 | fi 41 | - name: install pandoc 42 | run: sudo apt-get install pandoc 43 | - name: Install Rust 44 | uses: actions-rs/toolchain@v1 45 | with: 46 | toolchain: stable 47 | override: true 48 | profile: minimal 49 | components: rustfmt, clippy 50 | # - uses: swatinem/rust-cache@v2.7.3 51 | - name: Rust fmt 52 | uses: actions-rs/cargo@v1 53 | with: 54 | command: fmt 55 | args: --all -- --check 56 | - name: Build release 57 | uses: actions-rs/cargo@v1 58 | with: 59 | command: build 60 | args: --verbose --all --release --all-features 61 | - name: Run Test 62 | uses: actions-rs/cargo@v1 63 | with: 64 | command: test 65 | args: --release --all-features 66 | 67 | publish_crate: 68 | if: startsWith(github.ref, 'refs/tags/') 69 | needs: 70 | - build 71 | runs-on: ubuntu-latest 72 | steps: 73 | - uses: actions/checkout@v4 74 | with: 75 | fetch-depth: 1 76 | - name: Install Rust 77 | uses: actions-rs/toolchain@v1 78 | with: 79 | toolchain: stable 80 | profile: minimal 81 | override: true 82 | - name: Login to crates.io 83 | uses: actions-rs/cargo@v1 84 | with: 85 | command: login 86 | args: ${{ secrets.CRATES_TOKEN }} 87 | - name: Publish langchain-rust to crates.io 88 | uses: actions-rs/cargo@v1 89 | with: 90 | command: publish 91 | args: --all-features 92 | -------------------------------------------------------------------------------- /examples/vector_store_sqlite_vec.rs: -------------------------------------------------------------------------------- 1 | // Make sure vec0 libraries are installed in the system or the path of the executable. 2 | // To run this example execute: cargo run --example vector_store_sqlite_vec --features sqlite-vec 3 | // Download the libraries from https://github.com/asg017/sqlite-vec 4 | 5 | #[cfg(feature = "sqlite-vec")] 6 | use langchain_rust::{ 7 | embedding::openai::openai_embedder::OpenAiEmbedder, 8 | schemas::Document, 9 | vectorstore::{sqlite_vec::StoreBuilder, VecStoreOptions, VectorStore}, 10 | }; 11 | 12 | #[cfg(feature = "sqlite-vec")] 13 | use std::io::Write; 14 | 15 | #[cfg(feature = "sqlite-vec")] 16 | #[tokio::main] 17 | async fn main() { 18 | // Initialize Embedder 19 | let embedder = OpenAiEmbedder::default(); 20 | 21 | let database_url = std::env::var("DATABASE_URL").unwrap_or("sqlite::memory:".to_string()); 22 | 23 | // Initialize the Sqlite Vector Store 24 | let store = StoreBuilder::new() 25 | .embedder(embedder) 26 | .connection_url(database_url) 27 | .table("documents") 28 | .vector_dimensions(1536) 29 | .build() 30 | .await 31 | .unwrap(); 32 | 33 | // Initialize the tables in the database. This is required to be done only once. 
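// (Re-running it against an already-initialized database may be a no-op or may
// fail, depending on the backing store; guard it in long-lived applications.)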
34 | store.initialize().await.unwrap(); 35 | 36 | // Add documents to the database 37 | let doc1 = Document::new( 38 | "langchain-rust is a port of the langchain python library to rust and was written in 2024.", 39 | ); 40 | let doc2 = Document::new( 41 | "langchaingo is a port of the langchain python library to go language and was written in 2023." 42 | ); 43 | let doc3 = Document::new( 44 | "Capital of United States of America (USA) is Washington D.C. and the capital of France is Paris." 45 | ); 46 | let doc4 = Document::new("Capital of France is Paris."); 47 | 48 | store 49 | .add_documents(&vec![doc1, doc2, doc3, doc4], &VecStoreOptions::default()) 50 | .await 51 | .unwrap(); 52 | 53 | // Ask for user input 54 | print!("Query> "); 55 | std::io::stdout().flush().unwrap(); 56 | let mut query = String::new(); 57 | std::io::stdin().read_line(&mut query).unwrap(); 58 | 59 | let results = store 60 | .similarity_search(&query, 2, &VecStoreOptions::default()) 61 | .await 62 | .unwrap(); 63 | 64 | if results.is_empty() { 65 | println!("No results found."); 66 | return; 67 | } else { 68 | results.iter().for_each(|r| { 69 | println!("Document: {}", r.page_content); 70 | }); 71 | } 72 | } 73 | 74 | #[cfg(not(feature = "sqlite-vec"))] 75 | fn main() { 76 | println!("This example requires the 'sqlite-vec' feature to be enabled."); 77 | println!("Please run the command as follows:"); 78 | println!("cargo run --example vector_store_sqlite_vec --features sqlite-vec"); 79 | } 80 | -------------------------------------------------------------------------------- /src/embedding/openai/openai_embedder.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code)] 2 | 3 | use crate::embedding::{embedder_trait::Embedder, EmbedderError}; 4 | pub use async_openai::config::{AzureConfig, Config, OpenAIConfig}; 5 | use async_openai::{ 6 | types::{CreateEmbeddingRequestArgs, EmbeddingInput}, 7 | Client, 8 | }; 9 | use async_trait::async_trait; 10 | 11 | #[derive(Debug)] 12 | pub struct OpenAiEmbedder { 13 | config: C, 14 | model: String, 15 | } 16 | 17 | impl Into> for OpenAiEmbedder { 18 | fn into(self) -> Box { 19 | Box::new(self) 20 | } 21 | } 22 | 23 | impl OpenAiEmbedder { 24 | pub fn new(config: C) -> Self { 25 | OpenAiEmbedder { 26 | config, 27 | model: String::from("text-embedding-ada-002"), 28 | } 29 | } 30 | 31 | pub fn with_model>(mut self, model: S) -> Self { 32 | self.model = model.into(); 33 | self 34 | } 35 | 36 | pub fn with_config(mut self, config: C) -> Self { 37 | self.config = config; 38 | self 39 | } 40 | } 41 | 42 | impl Default for OpenAiEmbedder { 43 | fn default() -> Self { 44 | OpenAiEmbedder::new(OpenAIConfig::default()) 45 | } 46 | } 47 | 48 | #[async_trait] 49 | impl Embedder for OpenAiEmbedder { 50 | async fn embed_documents(&self, documents: &[String]) -> Result>, EmbedderError> { 51 | let client = Client::with_config(self.config.clone()); 52 | 53 | let request = CreateEmbeddingRequestArgs::default() 54 | .model(&self.model) 55 | .input(EmbeddingInput::StringArray(documents.into())) 56 | .build()?; 57 | 58 | let response = client.embeddings().create(request).await?; 59 | 60 | let embeddings = response 61 | .data 62 | .into_iter() 63 | .map(|item| item.embedding) 64 | .map(|embedding| { 65 | embedding 66 | .into_iter() 67 | .map(|x| x as f64) 68 | .collect::>() 69 | }) 70 | .collect(); 71 | 72 | Ok(embeddings) 73 | } 74 | 75 | async fn embed_query(&self, text: &str) -> Result, EmbedderError> { 76 | let client = 
Client::with_config(self.config.clone()); 77 | 78 | let request = CreateEmbeddingRequestArgs::default() 79 | .model(&self.model) 80 | .input(text) 81 | .build()?; 82 | 83 | let mut response = client.embeddings().create(request).await?; 84 | 85 | let item = response.data.swap_remove(0); 86 | 87 | Ok(item 88 | .embedding 89 | .into_iter() 90 | .map(|x| x as f64) 91 | .collect::>()) 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /examples/sql_chain.rs: -------------------------------------------------------------------------------- 1 | // To run this example execute: cargo run --example sql_chain --features postgres 2 | 3 | #[cfg(feature = "postgres")] 4 | use langchain_rust::{ 5 | chain::{options::ChainCallOptions, Chain, SQLDatabaseChainBuilder}, 6 | llm::openai::OpenAI, 7 | tools::{postgres::PostgreSQLEngine, SQLDatabaseBuilder}, 8 | }; 9 | 10 | #[cfg(feature = "postgres")] 11 | use std::io::{self, Write}; // Include io Library for terminal input 12 | 13 | #[cfg(feature = "postgres")] 14 | #[tokio::main] 15 | async fn main() { 16 | let options = ChainCallOptions::default(); 17 | let llm = OpenAI::default(); 18 | 19 | let db = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set"); 20 | let engine = PostgreSQLEngine::new(&db).await.unwrap(); 21 | let db = SQLDatabaseBuilder::new(engine).build().await.unwrap(); 22 | let chain = SQLDatabaseChainBuilder::new() 23 | .llm(llm) 24 | .top_k(4) 25 | .database(db) 26 | .options(options) 27 | .build() 28 | .expect("Failed to build LLMChain"); 29 | 30 | print!("Please enter a question: "); 31 | io::stdout().flush().unwrap(); 32 | 33 | let mut input = String::new(); 34 | io::stdin().read_line(&mut input).unwrap(); 35 | 36 | let input = input.trim(); 37 | let input_variables = chain.prompt_builder().query(input).build(); 38 | match chain.invoke(input_variables).await { 39 | Ok(result) => { 40 | println!("Result: {:?}", result); 41 | } 42 | Err(e) => panic!("Error invoking LLMChain: {:?}", e), 43 | } 44 | } 45 | 46 | #[cfg(not(feature = "postgres"))] 47 | fn main() { 48 | println!("This example requires the 'postgres' feature to be enabled."); 49 | println!("Please run the command as follows:"); 50 | println!("cargo run --example sql_chain --features postgres"); 51 | } 52 | 53 | //You can use this docker migrations for example, you can ask , whats the phone number of John 54 | 55 | // -- Migrations file 56 | // 57 | // -- Create the 'users' table 58 | // CREATE TABLE users ( 59 | // id serial PRIMARY KEY, 60 | // name varchar(255), 61 | // address text 62 | // ); 63 | // 64 | // -- Create the 'more_info' table 65 | // CREATE TABLE more_info ( 66 | // id serial PRIMARY KEY, 67 | // user_id int references users(id), 68 | // ig_nickname varchar(255), 69 | // phone_number varchar(255) 70 | // ); 71 | // 72 | // 73 | // -- Dummy Data 74 | // 75 | // -- Inserting into 'users' table 76 | // INSERT INTO users(name, address) 77 | // VALUES 78 | // ('John Doe', '123 Main St'), 79 | // ('Jane Doe', '456 Oak St'), 80 | // ('Jim Doe', '789 Pine St'); 81 | // 82 | // -- Inserting into 'more_info' table 83 | // INSERT INTO more_info(user_id, ig_nickname, phone_number) 84 | // VALUES 85 | // (1, 'john_ig', '123-456-7890'), 86 | // (2, 'jane_ig', '456-789-0123'), 87 | // (3, 'jim_ig', '789-012-3456'); 88 | -------------------------------------------------------------------------------- /src/agent/chat/prompt.rs: -------------------------------------------------------------------------------- 1 | pub const 
PREFIX: &str = r#" 2 | 3 | Assistant is designed to be able to assist with a wide range of tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. As a language model, Assistant is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand. 4 | 5 | Assistant is constantly learning and improving, and its capabilities are constantly evolving. It is able to process and understand large amounts of text, and can use this knowledge to provide accurate and informative responses to a wide range of questions. Additionally, Assistant is able to generate its own text based on the input it receives, allowing it to engage in discussions and provide explanations and descriptions on a wide range of topics. 6 | 7 | Overall, Assistant is a powerful system that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics. Whether you need help with a specific question or just want to have a conversation about a particular topic, Assistant is here to assist."#; 8 | 9 | pub const FORMAT_INSTRUCTIONS: &str = r#"RESPONSE FORMAT INSTRUCTIONS 10 | ---------------------------- 11 | 12 | When responding to me, please output a response in one of two formats: 13 | 14 | **Option 1:** 15 | Use this if you want the human to use a tool. 16 | Markdown code snippet formatted in the following schema: 17 | 18 | ```json 19 | { 20 | "action": string, \\ The action to take. Must be one of {{tool_names}} 21 | "action_input": string \\ The input to the action 22 | } 23 | ``` 24 | 25 | **Option #2:** 26 | Use this if you want to respond directly to the human. Markdown code snippet formatted in the following schema: 27 | 28 | ```json 29 | { 30 | "action": "Final Answer", 31 | "action_input": string \\ You should put what you want to return to use here 32 | } 33 | 34 | ```"#; 35 | 36 | pub const SUFFIX: &str = r#"TOOLS 37 | ------ 38 | Assistant can ask the user to use tools to look up information that may be helpful in answering the users original question. The tools the human can use are: 39 | 40 | {{tools}} 41 | 42 | {{format_instructions}} 43 | 44 | USER'S INPUT 45 | Here is the user's input (remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else): 46 | 47 | {{input}}"#; 48 | 49 | pub const TEMPLATE_TOOL_RESPONSE: &str = r#"TOOL RESPONSE: 50 | --------------------- 51 | {{observation}} 52 | 53 | USER'S INPUT 54 | -------------------- 55 | 56 | Okay, so what is the response to my last comment? If using information obtained from the tools you must mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES! 
Remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else."#; 57 | -------------------------------------------------------------------------------- /src/vectorstore/sqlite_vec/builder.rs: -------------------------------------------------------------------------------- 1 | use std::{error::Error, str::FromStr, sync::Arc}; 2 | 3 | use sqlx::{ 4 | sqlite::{SqliteConnectOptions, SqlitePoolOptions}, 5 | Pool, Sqlite, 6 | }; 7 | 8 | use super::Store; 9 | use crate::embedding::embedder_trait::Embedder; 10 | 11 | pub struct StoreBuilder { 12 | pool: Option>, 13 | connection_url: Option, 14 | table: String, 15 | vector_dimensions: i32, 16 | embedder: Option>, 17 | } 18 | 19 | impl StoreBuilder { 20 | pub fn new() -> Self { 21 | StoreBuilder { 22 | pool: None, 23 | connection_url: None, 24 | table: "documents".to_string(), 25 | vector_dimensions: 0, 26 | embedder: None, 27 | } 28 | } 29 | 30 | pub fn pool(mut self, pool: Pool) -> Self { 31 | self.pool = Some(pool); 32 | self.connection_url = None; 33 | self 34 | } 35 | 36 | pub fn connection_url>(mut self, connection_url: S) -> Self { 37 | self.connection_url = Some(connection_url.into()); 38 | self.pool = None; 39 | self 40 | } 41 | 42 | pub fn table(mut self, table: &str) -> Self { 43 | self.table = table.into(); 44 | self 45 | } 46 | 47 | pub fn vector_dimensions(mut self, vector_dimensions: i32) -> Self { 48 | self.vector_dimensions = vector_dimensions; 49 | self 50 | } 51 | 52 | pub fn embedder(mut self, embedder: E) -> Self { 53 | self.embedder = Some(Arc::new(embedder)); 54 | self 55 | } 56 | 57 | // Finalize the builder and construct the Store object 58 | pub async fn build(self) -> Result> { 59 | if self.embedder.is_none() { 60 | return Err("Embedder is required".into()); 61 | } 62 | 63 | Ok(Store { 64 | pool: self.get_pool().await?, 65 | table: self.table, 66 | vector_dimensions: self.vector_dimensions, 67 | embedder: self.embedder.unwrap(), 68 | }) 69 | } 70 | 71 | async fn get_pool(&self) -> Result, Box> { 72 | match &self.pool { 73 | Some(pool) => Ok(pool.clone()), 74 | None => { 75 | let connection_url = self 76 | .connection_url 77 | .as_ref() 78 | .ok_or("Connection URL or DB is required")?; 79 | 80 | let pool: Pool = SqlitePoolOptions::new() 81 | .connect_with( 82 | SqliteConnectOptions::from_str(connection_url)? 
83 | .create_if_missing(true) 84 | .extension("vec0"), 85 | ) 86 | .await?; 87 | 88 | Ok(pool) 89 | } 90 | } 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /src/semantic_router/index/memory_index.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use async_trait::async_trait; 4 | 5 | use crate::semantic_router::{utils::cosine_similarity, IndexError, Router}; 6 | 7 | use super::Index; 8 | 9 | pub struct MemoryIndex { 10 | routers: HashMap, 11 | } 12 | impl MemoryIndex { 13 | pub fn new() -> Self { 14 | Self { 15 | routers: HashMap::new(), 16 | } 17 | } 18 | } 19 | 20 | #[async_trait] 21 | impl Index for MemoryIndex { 22 | async fn add(&mut self, routers: &[Router]) -> Result<(), IndexError> { 23 | for router in routers { 24 | if router.embedding.is_none() { 25 | return Err(IndexError::MissingEmbedding(router.name.clone())); 26 | } 27 | if self.routers.contains_key(&router.name) { 28 | log::warn!("Router {} already exists in the index", router.name); 29 | } 30 | self.routers.insert(router.name.clone(), router.clone()); 31 | } 32 | 33 | Ok(()) 34 | } 35 | 36 | async fn delete(&mut self, router_name: &str) -> Result<(), IndexError> { 37 | if self.routers.remove(router_name).is_none() { 38 | log::warn!("Router {} not found in the index", router_name); 39 | } 40 | Ok(()) 41 | } 42 | 43 | async fn query(&self, vector: &[f64], top_k: usize) -> Result, IndexError> { 44 | let mut all_similarities: Vec<(String, f64)> = Vec::new(); 45 | 46 | // Compute similarity for each embedding of each router 47 | for (name, router) in &self.routers { 48 | if let Some(embeddings) = &router.embedding { 49 | for embedding in embeddings { 50 | let similarity = cosine_similarity(vector, embedding); 51 | all_similarities.push((name.clone(), similarity)); 52 | } 53 | } 54 | } 55 | 56 | // Sort all similarities by descending similarity score 57 | all_similarities 58 | .sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); 59 | 60 | // Only keep the top_k similarities 61 | let top_similarities: Vec<(String, f64)> = 62 | all_similarities.into_iter().take(top_k).collect(); 63 | 64 | Ok(top_similarities) 65 | } 66 | 67 | async fn get_routers(&self) -> Result, IndexError> { 68 | let routes = self.routers.values().cloned().collect(); 69 | Ok(routes) 70 | } 71 | 72 | async fn get_router(&self, route_name: &str) -> Result { 73 | return self 74 | .routers 75 | .get(route_name) 76 | .cloned() 77 | .ok_or(IndexError::RouterNotFound(route_name.into())); 78 | } 79 | 80 | async fn delete_index(&mut self) -> Result<(), IndexError> { 81 | self.routers.clear(); 82 | Ok(()) 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /src/vectorstore/sqlite_vss/builder.rs: -------------------------------------------------------------------------------- 1 | use std::{error::Error, str::FromStr, sync::Arc}; 2 | 3 | use sqlx::{ 4 | sqlite::{SqliteConnectOptions, SqlitePoolOptions}, 5 | Pool, Sqlite, 6 | }; 7 | 8 | use super::Store; 9 | use crate::embedding::embedder_trait::Embedder; 10 | 11 | pub struct StoreBuilder { 12 | pool: Option>, 13 | connection_url: Option, 14 | table: String, 15 | vector_dimensions: i32, 16 | embedder: Option>, 17 | } 18 | 19 | impl StoreBuilder { 20 | pub fn new() -> Self { 21 | StoreBuilder { 22 | pool: None, 23 | connection_url: None, 24 | table: "documents".to_string(), 25 | vector_dimensions: 0, 26 | embedder: 
None, 27 | } 28 | } 29 | 30 | pub fn pool(mut self, pool: Pool) -> Self { 31 | self.pool = Some(pool); 32 | self.connection_url = None; 33 | self 34 | } 35 | 36 | pub fn connection_url>(mut self, connection_url: S) -> Self { 37 | self.connection_url = Some(connection_url.into()); 38 | self.pool = None; 39 | self 40 | } 41 | 42 | pub fn table(mut self, table: &str) -> Self { 43 | self.table = table.into(); 44 | self 45 | } 46 | 47 | pub fn vector_dimensions(mut self, vector_dimensions: i32) -> Self { 48 | self.vector_dimensions = vector_dimensions; 49 | self 50 | } 51 | 52 | pub fn embedder(mut self, embedder: E) -> Self { 53 | self.embedder = Some(Arc::new(embedder)); 54 | self 55 | } 56 | 57 | // Finalize the builder and construct the Store object 58 | pub async fn build(self) -> Result> { 59 | if self.embedder.is_none() { 60 | return Err("Embedder is required".into()); 61 | } 62 | 63 | Ok(Store { 64 | pool: self.get_pool().await?, 65 | table: self.table, 66 | vector_dimensions: self.vector_dimensions, 67 | embedder: self.embedder.unwrap(), 68 | }) 69 | } 70 | 71 | async fn get_pool(&self) -> Result, Box> { 72 | match &self.pool { 73 | Some(pool) => Ok(pool.clone()), 74 | None => { 75 | let connection_url = self 76 | .connection_url 77 | .as_ref() 78 | .ok_or("Connection URL or DB is required")?; 79 | 80 | let pool: Pool = SqlitePoolOptions::new() 81 | .connect_with( 82 | SqliteConnectOptions::from_str(connection_url)? 83 | .create_if_missing(true) 84 | .extension("vector0") 85 | .extension("vss0"), 86 | ) 87 | .await?; 88 | 89 | Ok(pool) 90 | } 91 | } 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /examples/vector_store_sqlite_vss.rs: -------------------------------------------------------------------------------- 1 | // To run this example execute: cargo run --example vector_store_sqlite_vss --features sqlite-vss 2 | // Make sure vector0 and vss0 libraries are installed in the system or the path of the executable. 3 | // Download the libraries from https://github.com/asg017/sqlite-vss 4 | // For static compilation of sqlite-vss extension refer to the following link: 5 | // https://github.com/launchbadge/sqlx/issues/3147. 6 | 7 | #[cfg(feature = "sqlite-vss")] 8 | use langchain_rust::{ 9 | embedding::openai::openai_embedder::OpenAiEmbedder, 10 | schemas::Document, 11 | vectorstore::{sqlite_vss::StoreBuilder, VecStoreOptions, VectorStore}, 12 | }; 13 | #[cfg(feature = "sqlite-vss")] 14 | use std::io::Write; 15 | 16 | #[cfg(feature = "sqlite-vss")] 17 | #[tokio::main] 18 | async fn main() { 19 | // Initialize Embedder 20 | let embedder = OpenAiEmbedder::default(); 21 | 22 | let database_url = std::env::var("DATABASE_URL").unwrap_or("sqlite::memory:".to_string()); 23 | 24 | // Initialize the Sqlite Vector Store 25 | let store = StoreBuilder::new() 26 | .embedder(embedder) 27 | .connection_url(database_url) 28 | .table("documents") 29 | .vector_dimensions(1536) 30 | .build() 31 | .await 32 | .unwrap(); 33 | 34 | // Initialize the tables in the database. This is required to be done only once. 35 | store.initialize().await.unwrap(); 36 | 37 | // Add documents to the database 38 | let doc1 = Document::new( 39 | "langchain-rust is a port of the langchain python library to rust and was written in 2024.", 40 | ); 41 | let doc2 = Document::new( 42 | "langchaingo is a port of the langchain python library to go language and was written in 2023." 
43 | ); 44 | let doc3 = Document::new( 45 | "Capital of United States of America (USA) is Washington D.C. and the capital of France is Paris." 46 | ); 47 | let doc4 = Document::new("Capital of France is Paris."); 48 | 49 | store 50 | .add_documents(&vec![doc1, doc2, doc3, doc4], &VecStoreOptions::default()) 51 | .await 52 | .unwrap(); 53 | 54 | // Ask for user input 55 | print!("Query> "); 56 | std::io::stdout().flush().unwrap(); 57 | let mut query = String::new(); 58 | std::io::stdin().read_line(&mut query).unwrap(); 59 | 60 | let results = store 61 | .similarity_search(&query, 2, &VecStoreOptions::default()) 62 | .await 63 | .unwrap(); 64 | 65 | if results.is_empty() { 66 | println!("No results found."); 67 | return; 68 | } else { 69 | results.iter().for_each(|r| { 70 | println!("Document: {}", r.page_content); 71 | }); 72 | } 73 | } 74 | 75 | #[cfg(not(feature = "sqlite-vss"))] 76 | fn main() { 77 | println!("This example requires the 'sqlite-vss' feature to be enabled."); 78 | println!("Please run the command as follows:"); 79 | println!("cargo run --example vector_store_sqlite_vss --features sqlite-vss"); 80 | } 81 | -------------------------------------------------------------------------------- /src/tools/tool.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | use std::string::String; 3 | 4 | use async_trait::async_trait; 5 | use serde_json::{json, Value}; 6 | 7 | #[async_trait] 8 | pub trait Tool: Send + Sync { 9 | /// Returns the name of the tool. 10 | fn name(&self) -> String; 11 | 12 | /// Provides a description of what the tool does and when to use it. 13 | fn description(&self) -> String; 14 | /// These are the parameters for an OpenAI-like function call. 15 | /// You should return a JSON like this one: 16 | /// ```json 17 | /// { 18 | /// "type": "object", 19 | /// "properties": { 20 | /// "command": { 21 | /// "type": "string", 22 | /// "description": "The raw command you want executed" 23 | /// } 24 | /// }, 25 | /// "required": ["command"] 26 | /// } 27 | /// 28 | /// If there is no implementation, the default will be self.description(). 29 | ///``` 30 | fn parameters(&self) -> Value { 31 | json!({ 32 | "type": "object", 33 | "properties": { 34 | "input": { 35 | "type": "string", 36 | "description": self.description() 37 | } 38 | }, 39 | "required": ["input"] 40 | }) 41 | } 42 | 43 | /// Processes an input string and executes the tool's functionality, returning a `Result`. 44 | /// 45 | /// This function utilizes `parse_input` to parse the input and then calls `run`. 46 | /// It's used by the Agent. 47 | async fn call(&self, input: &str) -> Result<String, Box<dyn Error>> { 48 | let input = self.parse_input(input).await; 49 | self.run(input).await 50 | } 51 | 52 | /// Executes the core functionality of the tool. 53 | /// 54 | /// Example implementation: 55 | /// ```rust,ignore 56 | /// async fn run(&self, input: Value) -> Result<String, Box<dyn Error>> { 57 | /// let input_str = input.as_str().ok_or("Input should be a string")?; 58 | /// self.simple_search(input_str).await 59 | /// } 60 | /// ``` 61 | async fn run(&self, input: Value) -> Result<String, Box<dyn Error>>; 62 | 63 | /// Parses the input string, which could be a JSON value or a raw string, depending on the LLM model. 64 | /// 65 | /// Implement this function to extract the parameters needed for your tool. If a simple 66 | /// string is sufficient, the default implementation can be used. 
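/// For example, with the default implementation below (illustrative):
/// ```rust,ignore
/// // {"input": "ls -la"}   -> Value::String("ls -la")
/// // "just a plain string" -> Value::String("just a plain string")
/// ```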
67 | async fn parse_input(&self, input: &str) -> Value { 68 | log::info!("Using default implementation: {}", input); 69 | match serde_json::from_str::<Value>(input) { 70 | Ok(input) => { 71 | if input["input"].is_string() { 72 | Value::String(input["input"].as_str().unwrap().to_string()) 73 | } else { 74 | Value::String(input.to_string()) 75 | } 76 | } 77 | Err(_) => Value::String(input.to_string()), 78 | } 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /src/text_splitter/options.rs: -------------------------------------------------------------------------------- 1 | use text_splitter::ChunkConfig; 2 | use tiktoken_rs::{get_bpe_from_model, get_bpe_from_tokenizer, tokenizer::Tokenizer, CoreBPE}; 3 | 4 | use super::TextSplitterError; 5 | 6 | // Options is a struct that contains options for a text splitter. 7 | #[derive(Debug, Clone)] 8 | pub struct SplitterOptions { 9 | pub chunk_size: usize, 10 | pub chunk_overlap: usize, 11 | pub model_name: String, 12 | pub encoding_name: String, 13 | pub trim_chunks: bool, 14 | } 15 | 16 | impl Default for SplitterOptions { 17 | fn default() -> Self { 18 | Self::new() 19 | } 20 | } 21 | 22 | impl SplitterOptions { 23 | pub fn new() -> Self { 24 | SplitterOptions { 25 | chunk_size: 512, 26 | chunk_overlap: 0, 27 | model_name: String::from("gpt-3.5-turbo"), 28 | encoding_name: String::from("cl100k_base"), 29 | trim_chunks: false, 30 | } 31 | } 32 | } 33 | 34 | // Builder pattern for Options struct 35 | impl SplitterOptions { 36 | pub fn with_chunk_size(mut self, chunk_size: usize) -> Self { 37 | self.chunk_size = chunk_size; 38 | self 39 | } 40 | 41 | pub fn with_chunk_overlap(mut self, chunk_overlap: usize) -> Self { 42 | self.chunk_overlap = chunk_overlap; 43 | self 44 | } 45 | 46 | pub fn with_model_name(mut self, model_name: &str) -> Self { 47 | self.model_name = String::from(model_name); 48 | self 49 | } 50 | 51 | pub fn with_encoding_name(mut self, encoding_name: &str) -> Self { 52 | self.encoding_name = String::from(encoding_name); 53 | self 54 | } 55 | 56 | pub fn with_trim_chunks(mut self, trim_chunks: bool) -> Self { 57 | self.trim_chunks = trim_chunks; 58 | self 59 | } 60 | 61 | pub fn get_tokenizer_from_str(s: &str) -> Option<Tokenizer> { 62 | match s.to_lowercase().as_str() { 63 | "cl100k_base" => Some(Tokenizer::Cl100kBase), 64 | "p50k_base" => Some(Tokenizer::P50kBase), 65 | "r50k_base" => Some(Tokenizer::R50kBase), 66 | "p50k_edit" => Some(Tokenizer::P50kEdit), 67 | "gpt2" => Some(Tokenizer::Gpt2), 68 | _ => None, 69 | } 70 | } 71 | } 72 |
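// A configuration sketch: a token-based splitter with a 32-token overlap. The conversion implemented below rejects unknown encoding names with `TextSplitterError::TokenizerNotFound`. // // let options = SplitterOptions::new() // .with_chunk_size(256) // .with_chunk_overlap(32) // .with_encoding_name("cl100k_base"); // let chunk_config = ChunkConfig::try_from(&options)?;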
73 | impl TryFrom<&SplitterOptions> for ChunkConfig<CoreBPE> { 74 | type Error = TextSplitterError; 75 | 76 | fn try_from(options: &SplitterOptions) -> Result<Self, Self::Error> { 77 | let tk = if !options.encoding_name.is_empty() { 78 | let tokenizer = SplitterOptions::get_tokenizer_from_str(&options.encoding_name) 79 | .ok_or(TextSplitterError::TokenizerNotFound)?; 80 | 81 | get_bpe_from_tokenizer(tokenizer).map_err(|_| TextSplitterError::InvalidTokenizer)? 82 | } else { 83 | get_bpe_from_model(&options.model_name).map_err(|_| TextSplitterError::InvalidModel)? 84 | }; 85 | 86 | Ok(ChunkConfig::new(options.chunk_size) 87 | .with_sizer(tk) 88 | .with_trim(options.trim_chunks) 89 | .with_overlap(options.chunk_overlap)?) 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /examples/git_commits.rs: -------------------------------------------------------------------------------- 1 | // To run this example execute: cargo run --example git_commits --features sqlite-vss,git -- /path/to/git/repo 2 | // Make sure the vector0 and vss0 libraries are installed on the system or next to the executable. 3 | // Download the libraries from https://github.com/asg017/sqlite-vss 4 | // For static compilation of the sqlite-vss extension refer to the following link: 5 | // https://github.com/launchbadge/sqlx/issues/3147. 6 | 7 | #[cfg(feature = "sqlite-vss")] 8 | use futures_util::StreamExt; 9 | #[cfg(feature = "sqlite-vss")] 10 | use langchain_rust::{ 11 | document_loaders::GitCommitLoader, 12 | document_loaders::Loader, 13 | embedding::openai::OpenAiEmbedder, 14 | vectorstore::{sqlite_vss::StoreBuilder, VecStoreOptions, VectorStore}, 15 | }; 16 | #[cfg(feature = "sqlite-vss")] 17 | use std::io::Write; 18 | 19 | #[cfg(feature = "sqlite-vss")] 20 | #[tokio::main] 21 | async fn main() { 22 | // Initialize Embedder 23 | let embedder = OpenAiEmbedder::default(); 24 | 25 | let database_url = std::env::var("DATABASE_URL").unwrap_or("sqlite::memory:".to_string()); 26 | 27 | let repo_path = std::env::args() 28 | .nth(1) 29 | .expect("Please provide the path to the git repository."); 30 | 31 | // Initialize the Sqlite Vector Store 32 | let store = StoreBuilder::new() 33 | .embedder(embedder) 34 | .connection_url(database_url) 35 | .table("documents") 36 | .vector_dimensions(1536) 37 | .build() 38 | .await 39 | .unwrap(); 40 | 41 | // Initialize the tables in the database. This is required to be done only once. 42 | store.initialize().await.unwrap(); 43 | 44 | let git_commit_loader = GitCommitLoader::from_path(repo_path).unwrap(); 45 | 46 | let mut stream = git_commit_loader.load().await.unwrap(); 47 | while let Some(result) = stream.next().await { 48 | match result { 49 | Ok(document) => { 50 | store 51 | .add_documents(&[document], &VecStoreOptions::default()) 52 | .await 53 | .unwrap(); 54 | } 55 | Err(e) => panic!("Error fetching git commits {:?}", e), 56 | } 57 | } 58 | 59 | // Ask for user input 60 | print!("Query> "); 61 | std::io::stdout().flush().unwrap(); 62 | let mut input = String::new(); 63 | std::io::stdin().read_line(&mut input).unwrap(); 64 | 65 | let results = store 66 | .similarity_search(&input, 2, &VecStoreOptions::default()) 67 | .await 68 | .unwrap(); 69 | 70 | if results.is_empty() { 71 | println!("No results found."); 72 | return; 73 | } else { 74 | results.iter().for_each(|r| { 75 | println!("Document: {}", r.page_content); 76 | }); 77 | } 78 | } 79 | 80 | #[cfg(not(feature = "sqlite-vss"))] 81 | fn main() { 82 | println!("This example requires the 'sqlite-vss' and 'git' features to be enabled."); 83 | println!("Please run the command as follows:"); 84 | println!("cargo run --example git_commits --features sqlite-vss,git -- /path/to/git/repo"); 85 | } 86 | -------------------------------------------------------------------------------- /examples/vector_store_surrealdb/src/main.rs: -------------------------------------------------------------------------------- 1 | // To run this example execute: `cargo run` in the folder.
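// The example uses an in-memory SurrealDB instance by default; setting DATABASE_URL to a running server endpoint (for example "ws://localhost:8000") persists the data instead.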
2 | 3 | use langchain_rust::{ 4 | embedding::openai::openai_embedder::OpenAiEmbedder, 5 | schemas::Document, 6 | vectorstore::{surrealdb::StoreBuilder, VecStoreOptions, VectorStore}, 7 | }; 8 | use std::io::Write; 9 | 10 | #[tokio::main] 11 | async fn main() { 12 | // Initialize Embedder 13 | let embedder = OpenAiEmbedder::default(); 14 | 15 | let database_url = std::env::var("DATABASE_URL").unwrap_or("memory".to_string()); 16 | 17 | let surrealdb_config = surrealdb::opt::Config::new() 18 | .set_strict(true) 19 | .capabilities(surrealdb::opt::capabilities::Capabilities::all()); 20 | // Uncomment the following lines to enable authentication 21 | // .user(surrealdb::opt::auth::Root { 22 | // username: "root".into(), 23 | // password: "root".into(), 24 | // }); 25 | 26 | let db = surrealdb::engine::any::connect((database_url, surrealdb_config)) 27 | .await 28 | .unwrap(); 29 | db.query("DEFINE NAMESPACE test;") 30 | .await 31 | .unwrap() 32 | .check() 33 | .unwrap(); 34 | db.query("USE NAMESPACE test; DEFINE DATABASE test;") 35 | .await 36 | .unwrap() 37 | .check() 38 | .unwrap(); 39 | 40 | db.use_ns("test").await.unwrap(); 41 | db.use_db("test").await.unwrap(); 42 | 43 | // Initialize the SurrealDB Vector Store 44 | let store = StoreBuilder::new() 45 | .embedder(embedder) 46 | .db(db) 47 | .vector_dimensions(1536) 48 | .build() 49 | .await 50 | .unwrap(); 51 | 52 | // Initialize the tables in the database. This is required to be done only once. 53 | store.initialize().await.unwrap(); 54 | 55 | // Add documents to the database 56 | let doc1 = Document::new( 57 | "langchain-rust is a port of the langchain python library to rust and was written in 2024.", 58 | ); 59 | let doc2 = Document::new( 60 | "langchaingo is a port of the langchain python library to go language and was written in 2023." 61 | ); 62 | let doc3 = Document::new( 63 | "Capital of United States of America (USA) is Washington D.C. and the capital of France is Paris." 64 | ); 65 | let doc4 = Document::new("Capital of France is Paris."); 66 | 67 | store 68 | .add_documents(&vec![doc1, doc2, doc3, doc4], &VecStoreOptions::default()) 69 | .await 70 | .unwrap(); 71 | 72 | // Ask for user input 73 | print!("Query> "); 74 | std::io::stdout().flush().unwrap(); 75 | let mut query = String::new(); 76 | std::io::stdin().read_line(&mut query).unwrap(); 77 | 78 | let results = store 79 | .similarity_search( 80 | &query, 81 | 2, 82 | &VecStoreOptions::default().with_score_threshold(0.6), 83 | ) 84 | .await 85 | .unwrap(); 86 | 87 | if results.is_empty() { 88 | println!("No results found."); 89 | return; 90 | } else { 91 | results.iter().for_each(|r| { 92 | println!("Document: {}", r.page_content); 93 | }); 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /src/llm/ollama/openai.rs: -------------------------------------------------------------------------------- 1 | use async_openai::config::Config; 2 | use reqwest::header::HeaderMap; 3 | use secrecy::SecretString; 4 | use serde::Deserialize; 5 | 6 | const OLLAMA_API_BASE: &str = "http://localhost:11434/v1"; 7 | 8 | /// Ollama has [OpenAI compatibility](https://ollama.com/blog/openai-compatibility), meaning that you can use it as an OpenAI API. 9 | /// 10 | /// This struct implements the `Config` trait of OpenAI, and has the necessary setup for OpenAI configurations for you to use Ollama.
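/// Ollama itself does not check the API key, but the OpenAI client requires one, so the default configuration fills in the placeholder value `ollama`.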
11 | /// 12 | /// ## Example 13 | /// 14 | /// ```rs 15 | /// let ollama = OpenAI::new(OllamaConfig::default()).with_model("llama3.2"); 16 | /// let response = ollama.invoke("Say hello!").await.unwrap(); 17 | /// ``` 18 | #[derive(Clone, Debug, Deserialize)] 19 | #[serde(default)] 20 | pub struct OllamaConfig { 21 | api_base: String, 22 | api_key: SecretString, 23 | } 24 | 25 | impl OllamaConfig { 26 | pub fn new() -> Self { 27 | Self::default() 28 | } 29 | 30 | pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self { 31 | self.api_key = SecretString::from(api_key.into()); 32 | self 33 | } 34 | 35 | pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self { 36 | self.api_base = api_base.into(); 37 | self 38 | } 39 | } 40 | 41 | impl Config for OllamaConfig { 42 | fn api_key(&self) -> &SecretString { 43 | &self.api_key 44 | } 45 | 46 | fn api_base(&self) -> &str { 47 | &self.api_base 48 | } 49 | 50 | fn headers(&self) -> HeaderMap { 51 | HeaderMap::default() 52 | } 53 | 54 | fn query(&self) -> Vec<(&str, &str)> { 55 | vec![] 56 | } 57 | 58 | fn url(&self, path: &str) -> String { 59 | format!("{}{}", self.api_base(), path) 60 | } 61 | } 62 | 63 | impl Default for OllamaConfig { 64 | fn default() -> Self { 65 | Self { 66 | api_base: OLLAMA_API_BASE.to_string(), 67 | api_key: SecretString::from("ollama".to_string()), 68 | } 69 | } 70 | } 71 | 72 | #[cfg(test)] 73 | mod tests { 74 | use super::*; 75 | use crate::{language_models::llm::LLM, llm::openai::OpenAI, schemas::Message}; 76 | use tokio::io::AsyncWriteExt; 77 | use tokio_stream::StreamExt; 78 | 79 | #[tokio::test] 80 | #[ignore] 81 | async fn test_ollama_openai() { 82 | let ollama = OpenAI::new(OllamaConfig::default()).with_model("llama2"); 83 | let response = ollama.invoke("hello").await.unwrap(); 84 | println!("{}", response); 85 | } 86 | 87 | #[tokio::test] 88 | #[ignore] 89 | async fn test_ollama_openai_stream() { 90 | let ollama = OpenAI::new(OllamaConfig::default()).with_model("phi3"); 91 | 92 | let message = Message::new_human_message("Why does water boil at 100 degrees?"); 93 | let mut stream = ollama.stream(&[message]).await.unwrap(); 94 | let mut stdout = tokio::io::stdout(); 95 | while let Some(res) = stream.next().await { 96 | let data = res.unwrap(); 97 | stdout.write_all(data.content.as_bytes()).await.unwrap(); 98 | } 99 | stdout.flush().await.unwrap(); 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /examples/conversational_retriever_simple_chain.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | 3 | use async_trait::async_trait; 4 | use futures_util::StreamExt; 5 | use langchain_rust::{ 6 | chain::{Chain, ConversationalRetrieverChainBuilder}, 7 | fmt_message, fmt_template, 8 | llm::{OpenAI, OpenAIModel}, 9 | memory::SimpleMemory, 10 | message_formatter, 11 | prompt::HumanMessagePromptTemplate, 12 | prompt_args, 13 | schemas::{Document, Message, Retriever}, 14 | template_jinja2, 15 | }; 16 | 17 | struct RetrieverMock {} 18 | #[async_trait] 19 | impl Retriever for RetrieverMock { 20 | async fn get_relevant_documents( 21 | &self, 22 | _question: &str, 23 | ) -> Result<Vec<Document>, Box<dyn Error>> { 24 | Ok(vec![ 25 | Document::new(format!( 26 | "\nQuestion: {}\nAnswer: {}\n", 27 | "Which is the favorite text editor of Luis", "Nvim" 28 | )), 29 | Document::new(format!( 30 | "\nQuestion: {}\nAnswer: {}\n", 31 | "How old is Luis", "24" 32 | )), 33 | Document::new(format!( 34 | "\nQuestion: {}\nAnswer: {}\n", 35 | "Where does Luis live", "Peru" 36
| )), 37 | Document::new(format!( 38 | "\nQuestion: {}\nAnswer: {}\n", 39 | "What's his favorite food", "Pan con chicharron" 40 | )), 41 | ]) 42 | } 43 | } 44 | #[tokio::main] 45 | async fn main() { 46 | let llm = OpenAI::default().with_model(OpenAIModel::Gpt35.to_string()); 47 | let prompt = message_formatter![ 48 | fmt_message!(Message::new_system_message("You are a helpful assistant")), 49 | fmt_template!(HumanMessagePromptTemplate::new( 50 | template_jinja2!(" 51 | Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. 52 | 53 | {{context}} 54 | 55 | Question:{{question}} 56 | Helpful Answer: 57 | 58 | ", 59 | "context","question"))) 60 | 61 | ]; 62 | let chain = ConversationalRetrieverChainBuilder::new() 63 | .llm(llm) 64 | .rephrase_question(true) 65 | .retriever(RetrieverMock {}) 66 | .memory(SimpleMemory::new().into()) 67 | // If you want to use the default prompt, remove the .prompt() call. 68 | // Keep in mind that if you change the prompt, this chain needs the {{context}} variable. 69 | .prompt(prompt) 70 | .build() 71 | .expect("Error building ConversationalChain"); 72 | 73 | let input_variables = prompt_args! { 74 | "question" => "Hi", 75 | }; 76 | 77 | let result = chain.invoke(input_variables).await; 78 | if let Ok(result) = result { 79 | println!("Result: {:?}", result); 80 | } 81 | 82 | let input_variables = prompt_args! { 83 | "question" => "Which is Luis's favorite food?", 84 | }; 85 | 86 | // If you want to stream 87 | let mut stream = chain.stream(input_variables).await.unwrap(); 88 | while let Some(result) = stream.next().await { 89 | match result { 90 | Ok(data) => data.to_stdout().unwrap(), 91 | Err(e) => { 92 | println!("Error: {:?}", e); 93 | } 94 | } 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /examples/llm_chain.rs: -------------------------------------------------------------------------------- 1 | use langchain_rust::{ 2 | chain::{Chain, LLMChainBuilder}, 3 | fmt_message, fmt_placeholder, fmt_template, 4 | language_models::llm::LLM, 5 | llm::openai::OpenAI, 6 | message_formatter, 7 | prompt::HumanMessagePromptTemplate, 8 | prompt_args, 9 | schemas::messages::Message, 10 | template_fstring, 11 | }; 12 | 13 | #[tokio::main] 14 | async fn main() { 15 | // We can then initialize the model: 16 | // If you'd prefer not to set an environment variable, you can pass the key in directly with `with_api_key` when initializing the OpenAI LLM: 17 | // let open_ai = OpenAI::default().with_api_key("..."); 18 | let open_ai = OpenAI::default(); 19 | 20 | // Once you've initialized the LLM of your choice, we can try using it! Let's ask it a question. 21 | let resp = open_ai.invoke("What is rust").await.unwrap(); 22 | println!("{}", resp); 23 | 24 | // We can also guide its response with a prompt template. Prompt templates are used to convert raw user input to a better input to the LLM. 25 | let prompt = message_formatter![ 26 | fmt_message!(Message::new_system_message( 27 | "You are world class technical documentation writer."
28 | )), 29 | fmt_template!(HumanMessagePromptTemplate::new(template_fstring!( 30 | "{input}", "input" 31 | ))) 32 | ]; 33 | 34 | // We can now combine these into a simple LLM chain: 35 | 36 | let chain = LLMChainBuilder::new() 37 | .prompt(prompt) 38 | .llm(open_ai.clone()) 39 | .build() 40 | .unwrap(); 41 | 42 | // We can now invoke it and ask another question. This time it should respond in the proper tone for a technical writer! 43 | 44 | match chain 45 | .invoke(prompt_args! { 46 | "input" => "Who is the writer of 20,000 Leagues Under the Sea?", 47 | }) 48 | .await 49 | { 50 | Ok(result) => { 51 | println!("Result: {:?}", result); 52 | } 53 | Err(e) => panic!("Error invoking LLMChain: {:?}", e), 54 | } 55 | 56 | // If you want the prompt to have a list of messages, you can use the `fmt_placeholder` macro. 57 | 58 | let prompt = message_formatter![ 59 | fmt_message!(Message::new_system_message( 60 | "You are world class technical documentation writer." 61 | )), 62 | fmt_placeholder!("history"), 63 | fmt_template!(HumanMessagePromptTemplate::new(template_fstring!( 64 | "{input}", "input" 65 | ))), 66 | ]; 67 | 68 | let chain = LLMChainBuilder::new() 69 | .prompt(prompt) 70 | .llm(open_ai) 71 | .build() 72 | .unwrap(); 73 | match chain 74 | .invoke(prompt_args! { 75 | "input" => "Who is the writer of 20,000 Leagues Under the Sea, and what is my name?", 76 | "history" => vec![ 77 | Message::new_human_message("My name is: luis"), 78 | Message::new_ai_message("Hi luis"), 79 | ], 80 | 81 | }) 82 | .await 83 | { 84 | Ok(result) => { 85 | println!("Result: {:?}", result); 86 | } 87 | Err(e) => panic!("Error invoking LLMChain: {:?}", e), 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /src/embedding/ollama/ollama_embedder.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use crate::embedding::{embedder_trait::Embedder, EmbedderError}; 4 | use async_trait::async_trait; 5 | use ollama_rs::{ 6 | generation::{ 7 | embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest}, 8 | options::GenerationOptions, 9 | }, 10 | Ollama as OllamaClient, 11 | }; 12 | 13 | #[derive(Debug)] 14 | pub struct OllamaEmbedder { 15 | pub(crate) client: Arc<OllamaClient>, 16 | pub(crate) model: String, 17 | pub(crate) options: Option<GenerationOptions>, 18 | } 19 | 20 | /// [nomic-embed-text](https://ollama.com/library/nomic-embed-text) is a 137M parameters, 274MB model.
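/// A minimal usage sketch (assuming a local Ollama daemon with the model already pulled, e.g. via `ollama pull nomic-embed-text`): /// ```rust,ignore /// let embedder = OllamaEmbedder::default(); /// let vector = embedder.embed_query("Why is the sky blue?").await?; /// assert_eq!(vector.len(), 768); /// ```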
21 | const DEFAULT_MODEL: &str = "nomic-embed-text"; 22 | 23 | impl OllamaEmbedder { 24 | pub fn new<S: Into<String>>( 25 | client: Arc<OllamaClient>, 26 | model: S, 27 | options: Option<GenerationOptions>, 28 | ) -> Self { 29 | Self { 30 | client, 31 | model: model.into(), 32 | options, 33 | } 34 | } 35 | 36 | pub fn with_model<S: Into<String>>(mut self, model: S) -> Self { 37 | self.model = model.into(); 38 | self 39 | } 40 | 41 | pub fn with_options(mut self, options: GenerationOptions) -> Self { 42 | self.options = Some(options); 43 | self 44 | } 45 | } 46 | 47 | impl Default for OllamaEmbedder { 48 | fn default() -> Self { 49 | let client = Arc::new(OllamaClient::default()); 50 | Self::new(client, String::from(DEFAULT_MODEL), None) 51 | } 52 | } 53 | 54 | #[async_trait] 55 | impl Embedder for OllamaEmbedder { 56 | async fn embed_documents(&self, documents: &[String]) -> Result<Vec<Vec<f64>>, EmbedderError> { 57 | log::debug!("Embedding documents: {:?}", documents); 58 | 59 | let response = self 60 | .client 61 | .generate_embeddings(GenerateEmbeddingsRequest::new( 62 | self.model.clone(), 63 | EmbeddingsInput::Multiple(documents.to_vec()), 64 | )) 65 | .await?; 66 | 67 | let embeddings = response 68 | .embeddings 69 | .into_iter() 70 | .map(|embedding| embedding.into_iter().map(f64::from).collect()) 71 | .collect(); 72 | 73 | Ok(embeddings) 74 | } 75 | 76 | async fn embed_query(&self, text: &str) -> Result<Vec<f64>, EmbedderError> { 77 | log::debug!("Embedding query: {:?}", text); 78 | 79 | let response = self 80 | .client 81 | .generate_embeddings(GenerateEmbeddingsRequest::new( 82 | self.model.clone(), 83 | EmbeddingsInput::Single(text.into()), 84 | )) 85 | .await?; 86 | 87 | let embeddings = response 88 | .embeddings 89 | .into_iter() 90 | .next() 91 | .unwrap() 92 | .into_iter() 93 | .map(f64::from) 94 | .collect(); 95 | 96 | Ok(embeddings) 97 | } 98 | } 99 | 100 | #[cfg(test)] 101 | mod tests { 102 | use super::*; 103 | 104 | #[tokio::test] 105 | #[ignore] 106 | async fn test_ollama_embed() { 107 | let ollama = OllamaEmbedder::default() 108 | .with_model("nomic-embed-text") 109 | .with_options(GenerationOptions::default().temperature(0.5)); 110 | 111 | let response = ollama.embed_query("Why is the sky blue?").await.unwrap(); 112 | 113 | assert_eq!(response.len(), 768); 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /src/chain/sql_datbase/builder.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | chain::{ 3 | llm_chain::LLMChainBuilder, options::ChainCallOptions, ChainError, DEFAULT_OUTPUT_KEY, 4 | }, 5 | language_models::llm::LLM, 6 | output_parsers::OutputParser, 7 | prompt::HumanMessagePromptTemplate, 8 | template_jinja2, 9 | tools::SQLDatabase, 10 | }; 11 | 12 | use super::{ 13 | chain::SQLDatabaseChain, 14 | prompt::{DEFAULT_SQLSUFFIX, DEFAULT_SQLTEMPLATE}, 15 | STOP_WORD, 16 | }; 17 | 18 | pub struct SQLDatabaseChainBuilder { 19 | llm: Option<Box<dyn LLM>>, 20 | options: Option<ChainCallOptions>, 21 | top_k: Option<usize>, 22 | database: Option<SQLDatabase>, 23 | output_key: Option<String>, 24 | output_parser: Option<Box<dyn OutputParser>>, 25 | } 26 | 27 | impl SQLDatabaseChainBuilder { 28 | pub fn new() -> Self { 29 | Self { 30 | llm: None, 31 | options: None, 32 | top_k: None, 33 | database: None, 34 | output_key: None, 35 | output_parser: None, 36 | } 37 | } 38 |
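// A wiring sketch for this builder: `open_ai` and `db` stand in for any configured `LLM` and `SQLDatabase`; `llm`, `top_k`, and `database` are all required by `build()`. // // let chain = SQLDatabaseChainBuilder::new() // .llm(open_ai) // .top_k(4) // .database(db) // .build()?;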
39 | pub fn llm<L: Into<Box<dyn LLM>>>(mut self, llm: L) -> Self { 40 | self.llm = Some(llm.into()); 41 | self 42 | } 43 | 44 | pub fn output_key<S: Into<String>>(mut self, output_key: S) -> Self { 45 | self.output_key = Some(output_key.into()); 46 | self 47 | } 48 | 49 | pub fn output_parser<P: Into<Box<dyn OutputParser>>>(mut self, output_parser: P) -> Self { 50 | self.output_parser = Some(output_parser.into()); 51 | self 52 | } 53 | 54 | pub fn options(mut self, options: ChainCallOptions) -> Self { 55 | self.options = Some(options); 56 | self 57 | } 58 | 59 | pub fn top_k(mut self, top_k: usize) -> Self { 60 | self.top_k = Some(top_k); 61 | self 62 | } 63 | 64 | pub fn database(mut self, database: SQLDatabase) -> Self { 65 | self.database = Some(database); 66 | self 67 | } 68 | 69 | pub fn build(self) -> Result<SQLDatabaseChain, ChainError> { 70 | let llm = self 71 | .llm 72 | .ok_or_else(|| ChainError::MissingObject("LLM must be set".into()))?; 73 | let top_k = self 74 | .top_k 75 | .ok_or_else(|| ChainError::MissingObject("Top K must be set".into()))?; 76 | let database = self 77 | .database 78 | .ok_or_else(|| ChainError::MissingObject("Database must be set".into()))?; 79 | 80 | let prompt = HumanMessagePromptTemplate::new(template_jinja2!( 81 | format!("{}{}", DEFAULT_SQLTEMPLATE, DEFAULT_SQLSUFFIX), 82 | "dialect", 83 | "table_info", 84 | "top_k", 85 | "input" 86 | )); 87 | 88 | let llm_chain = { 89 | let mut builder = LLMChainBuilder::new() 90 | .prompt(prompt) 91 | .output_key(self.output_key.unwrap_or_else(|| DEFAULT_OUTPUT_KEY.into())) 92 | .llm(llm); 93 | 94 | let mut options = self.options.unwrap_or_default(); 95 | options = options.with_stop_words(vec![STOP_WORD.to_string()]); 96 | builder = builder.options(options); 97 | 98 | if let Some(output_parser) = self.output_parser { 99 | builder = builder.output_parser(output_parser); 100 | } 101 | 102 | builder.build()? 103 | }; 104 | 105 | Ok(SQLDatabaseChain { 106 | llmchain: llm_chain, 107 | top_k, 108 | database, 109 | }) 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /src/schemas/tools_openai_like.rs: -------------------------------------------------------------------------------- 1 | use crate::schemas::convert::{OpenAIFromLangchain, TryOpenAiFromLangchain}; 2 | use crate::tools::Tool; 3 | use async_openai::types::{ 4 | ChatCompletionNamedToolChoice, ChatCompletionTool, ChatCompletionToolArgs, 5 | ChatCompletionToolChoiceOption, ChatCompletionToolType, FunctionName, FunctionObjectArgs, 6 | }; 7 | use serde::{Deserialize, Serialize}; 8 | use serde_json::Value; 9 | use std::ops::Deref; 10 | 11 | #[derive(Clone, Debug)] 12 | pub enum FunctionCallBehavior { 13 | None, 14 | Auto, 15 | Named(String), 16 | } 17 | 18 | impl OpenAIFromLangchain<FunctionCallBehavior> for ChatCompletionToolChoiceOption { 19 | fn from_langchain(langchain: FunctionCallBehavior) -> Self { 20 | match langchain { 21 | FunctionCallBehavior::Auto => ChatCompletionToolChoiceOption::Auto, 22 | FunctionCallBehavior::None => ChatCompletionToolChoiceOption::None, 23 | FunctionCallBehavior::Named(name) => { 24 | ChatCompletionToolChoiceOption::Named(ChatCompletionNamedToolChoice { 25 | r#type: ChatCompletionToolType::Function, 26 | function: FunctionName { 27 | name: name.to_owned(), 28 | }, 29 | }) 30 | } 31 | } 32 | } 33 | } 34 | 35 | #[derive(Clone, Debug)] 36 | pub struct FunctionDefinition { 37 | pub name: String, 38 | pub description: String, 39 | pub parameters: Value, 40 | } 41 | 42 | impl FunctionDefinition { 43 | pub fn new(name: &str, description: &str, parameters: Value) -> Self { 44 | FunctionDefinition { 45 | name: name.trim().replace(" ", "_"), 46 | description: description.to_string(), 47 | parameters, 48 | } 49 | } 50 | 51 | /// Generic function that can be used with Arc, Box, and direct references 52 | pub fn from_langchain_tool<T>(tool: &T) -> FunctionDefinition 53 |
where 54 | T: Deref<Target = dyn Tool> + ?Sized, 55 | { 56 | FunctionDefinition { 57 | name: tool.name().trim().replace(" ", "_"), 58 | description: tool.description(), 59 | parameters: tool.parameters(), 60 | } 61 | } 62 | } 63 | 64 | impl TryOpenAiFromLangchain<FunctionDefinition> for ChatCompletionTool { 65 | type Error = async_openai::error::OpenAIError; 66 | fn try_from_langchain(langchain: FunctionDefinition) -> Result<Self, Self::Error> { 67 | let tool = FunctionObjectArgs::default() 68 | .name(langchain.name) 69 | .description(langchain.description) 70 | .parameters(langchain.parameters) 71 | .build()?; 72 | 73 | ChatCompletionToolArgs::default() 74 | .r#type(ChatCompletionToolType::Function) 75 | .function(tool) 76 | .build() 77 | } 78 | } 79 | 80 | #[derive(Serialize, Deserialize, Debug)] 81 | pub struct FunctionCallResponse { 82 | pub id: String, 83 | #[serde(rename = "type")] 84 | pub type_field: String, 85 | pub function: FunctionDetail, 86 | } 87 | 88 | #[derive(Serialize, Deserialize, Debug)] 89 | pub struct FunctionDetail { 90 | pub name: String, 91 | /// This should be a string that is passed to the tool as-is, to 92 | /// then be deserialized inside the tool, because only the tool knows the names of its arguments. 93 | pub arguments: String, 94 | } 95 | 96 | impl FunctionCallResponse { 97 | pub fn from_str(s: &str) -> Result<Self, serde_json::Error> { 98 | serde_json::from_str(s) 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /src/llm/qwen/models.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | use crate::schemas::{Message, MessageType}; 4 | 5 | #[derive(Serialize, Deserialize)] 6 | pub(crate) struct QwenMessage { 7 | pub role: String, 8 | pub content: String, 9 | } 10 | 11 | impl QwenMessage { 12 | pub fn new<S: Into<String>>(role: S, content: S) -> Self { 13 | Self { 14 | role: role.into(), 15 | content: content.into(), 16 | } 17 | } 18 | 19 | pub fn from_message(message: &Message) -> Self { 20 | match message.message_type { 21 | MessageType::SystemMessage => Self::new("system", &message.content), 22 | MessageType::AIMessage => Self::new("assistant", &message.content), 23 | MessageType::HumanMessage => Self::new("user", &message.content), 24 | // Qwen may not have direct support for tool messages in the same way as Claude 25 | // For now, handle them as user messages 26 | MessageType::ToolMessage => Self::new("user", &message.content), 27 | } 28 | } 29 | } 30 | 31 | #[derive(Serialize, Deserialize)] 32 | pub(crate) struct Payload { 33 | pub model: String, 34 | pub messages: Vec<QwenMessage>, 35 | #[serde(skip_serializing_if = "Option::is_none")] 36 | pub max_tokens: Option<u32>, 37 | #[serde(skip_serializing_if = "Option::is_none")] 38 | pub stream: Option<bool>, 39 | #[serde(skip_serializing_if = "Option::is_none")] 40 | pub stop: Option<Vec<String>>, 41 | #[serde(skip_serializing_if = "Option::is_none")] 42 | pub temperature: Option<f32>, 43 | #[serde(skip_serializing_if = "Option::is_none")] 44 | pub top_p: Option<f32>, 45 | #[serde(skip_serializing_if = "Option::is_none")] 46 | pub seed: Option<u64>, 47 | #[serde(skip_serializing_if = "Option::is_none")] 48 | pub result_format: Option<String>, 49 | } 50 | 51 | #[derive(Debug, Serialize, Deserialize, Clone)] 52 | pub(crate) struct ApiResponse { 53 | pub id: String, 54 | pub created: u64, 55 | pub model: String, 56 | pub choices: Vec<Choice>, 57 | pub usage: Usage, 58 | } 59 | 60 | #[derive(Debug, Serialize, Deserialize, Clone)] 61 | pub(crate) struct Choice { 62 | pub message: ResponseMessage, 63 | pub finish_reason: Option<String>, 64 | pub
index: u32, 65 | } 66 | 67 | #[derive(Debug, Serialize, Deserialize, Clone)] 68 | pub(crate) struct ResponseMessage { 69 | pub role: String, 70 | pub content: String, 71 | } 72 | 73 | #[derive(Debug, Serialize, Deserialize, Clone)] 74 | pub(crate) struct Usage { 75 | pub prompt_tokens: u32, 76 | pub completion_tokens: u32, 77 | pub total_tokens: u32, 78 | } 79 | 80 | // Stream response structures 81 | #[derive(Debug, Serialize, Deserialize, Clone)] 82 | pub(crate) struct StreamResponse { 83 | pub id: String, 84 | pub model: String, 85 | pub created: u64, 86 | pub choices: Vec<StreamChoice>, 87 | } 88 | 89 | #[derive(Debug, Serialize, Deserialize, Clone)] 90 | pub(crate) struct StreamChoice { 91 | pub delta: Delta, 92 | pub finish_reason: Option<String>, 93 | pub index: u32, 94 | } 95 | 96 | #[derive(Debug, Serialize, Deserialize, Clone)] 97 | pub(crate) struct Delta { 98 | #[serde(skip_serializing_if = "Option::is_none")] 99 | pub role: Option<String>, 100 | #[serde(skip_serializing_if = "Option::is_none")] 101 | pub content: Option<String>, 102 | } 103 | 104 | // Error response structure 105 | #[derive(Debug, Serialize, Deserialize)] 106 | pub(crate) struct ErrorResponse { 107 | pub request_id: String, 108 | pub code: String, 109 | pub message: String, 110 | } 111 | -------------------------------------------------------------------------------- /src/tools/scraper/scraper.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | use regex::Regex; 3 | use scraper::{ElementRef, Html, Selector}; 4 | use serde_json::Value; 5 | use std::{error::Error, sync::Arc}; 6 | 7 | use crate::tools::Tool; 8 | 9 | pub struct WebScrapper {} 10 | 11 | impl WebScrapper { 12 | pub fn new() -> Self { 13 | Self {} 14 | } 15 | } 16 | 17 | #[async_trait] 18 | impl Tool for WebScrapper { 19 | fn name(&self) -> String { 20 | String::from("Web Scraper") 21 | } 22 | fn description(&self) -> String { 23 | String::from( 24 | "Web Scraper will scan a url and return the content of the web page.
25 | Input should be a working url.", 26 | ) 27 | } 28 | async fn run(&self, input: Value) -> Result<String, Box<dyn Error>> { 29 | let input = input.as_str().ok_or("Invalid input")?; 30 | match scrape_url(input).await { 31 | Ok(content) => Ok(content), 32 | Err(e) => Ok(format!("Error scraping {}: {}\n", input, e)), 33 | } 34 | } 35 | } 36 | 37 | impl Into<Arc<dyn Tool>> for WebScrapper { 38 | fn into(self) -> Arc<dyn Tool> { 39 | Arc::new(self) 40 | } 41 | } 42 | 43 | async fn scrape_url(url: &str) -> Result<String, Box<dyn Error>> { 44 | let res = reqwest::get(url).await?.text().await?; 45 | 46 | let document = Html::parse_document(&res); 47 | let body_selector = Selector::parse("body").unwrap(); 48 | 49 | let mut text = Vec::new(); 50 | for element in document.select(&body_selector) { 51 | collect_text_not_in_script(&element, &mut text); 52 | } 53 | 54 | let joined_text = text.join(" "); 55 | let cleaned_text = joined_text.replace(['\n', '\t'], " "); 56 | let re = Regex::new(r"\s+").unwrap(); 57 | let final_text = re.replace_all(&cleaned_text, " "); 58 | Ok(final_text.to_string()) 59 | } 60 | 61 | fn collect_text_not_in_script(element: &ElementRef, text: &mut Vec<String>) { 62 | for node in element.children() { 63 | if node.value().is_element() { 64 | let tag_name = node.value().as_element().unwrap().name(); 65 | if tag_name == "script" { 66 | continue; 67 | } 68 | collect_text_not_in_script(&ElementRef::wrap(node).unwrap(), text); 69 | } else if node.value().is_text() { 70 | text.push(node.value().as_text().unwrap().text.to_string()); 71 | } 72 | } 73 | } 74 | 75 | #[cfg(test)] 76 | mod tests { 77 | use super::*; 78 | use tokio; 79 | 80 | #[tokio::test] 81 | async fn test_scrape_url() { 82 | // Request a new server from the pool 83 | let mut server = mockito::Server::new_async().await; 84 | 85 | // Create a mock on the server 86 | let mock = server 87 | .mock("GET", "/") 88 | .with_status(200) 89 | .with_header("content-type", "text/plain") 90 | .with_body("Hello World") 91 | .create(); 92 | 93 | // Instantiate your WebScrapper 94 | let scraper = WebScrapper::new(); 95 | 96 | // Use the server URL for scraping 97 | let url = server.url(); 98 | 99 | // Call the WebScrapper with the mocked URL 100 | let result = scraper.call(&url).await; 101 | 102 | // Assert that the result is Ok and contains "Hello World" 103 | assert!(result.is_ok()); 104 | let content = result.unwrap(); 105 | assert_eq!(content.trim(), "Hello World"); 106 | 107 | // Verify that the mock was called as expected 108 | mock.assert(); 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /src/agent/chat/output_parser.rs: -------------------------------------------------------------------------------- 1 | use std::collections::VecDeque; 2 | 3 | use regex::Regex; 4 | use serde::Deserialize; 5 | use serde_json::Value; 6 | 7 | use crate::{ 8 | agent::AgentError, 9 | schemas::agent::{AgentAction, AgentEvent, AgentFinish}, 10 | }; 11 | 12 | use super::prompt::FORMAT_INSTRUCTIONS; 13 | 14 | #[derive(Debug, Deserialize)] 15 | struct AgentOutput { 16 | action: String, 17 | action_input: String, 18 | } 19 | 20 | pub struct ChatOutputParser {} 21 | impl ChatOutputParser { 22 | pub fn new() -> Self { 23 | Self {} 24 | } 25 | } 26 |
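// A parsing sketch: the agent's reply is expected to carry a fenced JSON block; an `action` of `Final Answer` ends the loop, while any other value becomes a tool invocation. For example, the following reply parses to `AgentEvent::Finish`: // // ```json // {"action": "Final Answer", "action_input": "Paris"} // ```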
27 | impl ChatOutputParser { 28 | pub fn parse(&self, text: &str) -> Result<AgentEvent, AgentError> { 29 | log::debug!("Parsing to Agent Action: {}", text); 30 | match parse_json_markdown(text) { 31 | Some(value) => { 32 | // Deserialize the Value into AgentOutput 33 | let agent_output: AgentOutput = serde_json::from_value(value)?; 34 | 35 | if agent_output.action == "Final Answer" { 36 | Ok(AgentEvent::Finish(AgentFinish { 37 | output: agent_output.action_input, 38 | })) 39 | } else { 40 | Ok(AgentEvent::Action(vec![AgentAction { 41 | tool: agent_output.action, 42 | tool_input: agent_output.action_input, 43 | log: text.to_string(), 44 | }])) 45 | } 46 | } 47 | None => { 48 | log::debug!("No JSON found or malformed JSON in text: {}", text); 49 | Ok(AgentEvent::Finish(AgentFinish { 50 | output: text.to_string(), 51 | })) 52 | } 53 | } 54 | } 55 | 56 | pub fn get_format_instructions(&self) -> &str { 57 | FORMAT_INSTRUCTIONS 58 | } 59 | } 60 | 61 | fn parse_partial_json(s: &str, strict: bool) -> Option<Value> { 62 | // First, attempt to parse the string as-is. 63 | match serde_json::from_str::<Value>(s) { 64 | Ok(val) => return Some(val), 65 | Err(_) if !strict => (), 66 | Err(_) => return None, 67 | } 68 | 69 | let mut new_s = String::new(); 70 | let mut stack: VecDeque<char> = VecDeque::new(); 71 | let mut is_inside_string = false; 72 | let mut escaped = false; 73 | 74 | for char in s.chars() { 75 | match char { 76 | '"' if !escaped => is_inside_string = !is_inside_string, 77 | '{' if !is_inside_string => stack.push_back('}'), 78 | '[' if !is_inside_string => stack.push_back(']'), 79 | '}' | ']' if !is_inside_string => { 80 | if let Some(c) = stack.pop_back() { 81 | if c != char { 82 | return None; // Mismatched closing character 83 | } 84 | } else { 85 | return None; // Unbalanced closing character 86 | } 87 | } 88 | '\\' if is_inside_string => escaped = !escaped, 89 | _ => escaped = false, 90 | } 91 | new_s.push(char); 92 | } 93 | 94 | // Close any open structures. 95 | while let Some(c) = stack.pop_back() { 96 | new_s.push(c); 97 | } 98 | 99 | // Attempt to parse again. 100 | serde_json::from_str(&new_s).ok() 101 | } 102 | 103 | fn parse_json_markdown(json_markdown: &str) -> Option<Value> { 104 | let re = Regex::new(r"```(?:json)?\s*([\s\S]+?)\s*```").unwrap(); 105 | if let Some(caps) = re.captures(json_markdown) { 106 | if let Some(json_str) = caps.get(1) { 107 | return parse_partial_json(json_str.as_str(), false); 108 | } 109 | } 110 | None 111 | } 112 | -------------------------------------------------------------------------------- /src/document_loaders/git_commit_loader/git_commit_loader.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::pin::Pin; 3 | 4 | use crate::document_loaders::{process_doc_stream, LoaderError}; 5 | use crate::{document_loaders::Loader, schemas::Document, text_splitter::TextSplitter}; 6 | use async_trait::async_trait; 7 | use futures::Stream; 8 | use gix::ThreadSafeRepository; 9 | use serde_json::Value; 10 | 11 | #[derive(Clone)] 12 | pub struct GitCommitLoader { 13 | repo: ThreadSafeRepository, 14 | } 15 | 16 | impl GitCommitLoader { 17 | pub fn new(repo: ThreadSafeRepository) -> Self { 18 | Self { repo } 19 | } 20 | 21 | pub fn from_path<P: AsRef<std::path::Path>>(directory: P) -> Result<Self, LoaderError> { 22 | let repo = ThreadSafeRepository::discover(directory)?; 23 | Ok(Self::new(repo)) 24 | } 25 | } 26 | 27 | #[async_trait] 28 | impl Loader for GitCommitLoader { 29 | async fn load( 30 | mut self, 31 | ) -> Result< 32 | Pin<Box<dyn Stream<Item = Result<Document, LoaderError>> + Send + 'static>>, 33 | LoaderError, 34 | > { 35 | let repo = self.repo.to_thread_local(); 36 | 37 | // Since commits_iter can't be shared across threads safely, use channels as a workaround.
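// flume's receiver can be converted into a `Stream` (see `rx.into_stream()` below), and the bounded(1) capacity applies backpressure: the spawned walker only advances as fast as the consumer polls.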
38 | let (tx, rx) = flume::bounded(1); 39 | 40 | tokio::spawn(async move { 41 | let commits_iter = repo 42 | .rev_walk(Some(repo.head_id().unwrap().detach())) 43 | .all() 44 | .unwrap() 45 | .filter_map(Result::ok) 46 | .map(|oid| { 47 | let commit = oid.object().unwrap(); 48 | let commit_id = commit.id; 49 | let author = commit.author().unwrap(); 50 | let email = author.email.to_string(); 51 | let name = author.name.to_string(); 52 | let message = format!("{}", commit.message().unwrap().title); 53 | 54 | let mut document = Document::new(format!( 55 | "commit {commit_id}\nAuthor: {name} <{email}>\n\n {message}" 56 | )); 57 | let mut metadata = HashMap::new(); 58 | metadata.insert("commit".to_string(), Value::from(commit_id.to_string())); 59 | 60 | document.metadata = metadata; 61 | Ok(document) 62 | }); 63 | 64 | for document in commits_iter { 65 | if tx.send(document).is_err() { 66 | // stream might have been dropped 67 | break; 68 | } 69 | } 70 | }); 71 | 72 | Ok(Box::pin(rx.into_stream())) 73 | } 74 | 75 | async fn load_and_split<TS: TextSplitter + 'static>( 76 | mut self, 77 | splitter: TS, 78 | ) -> Result< 79 | Pin<Box<dyn Stream<Item = Result<Document, LoaderError>> + Send + 'static>>, 80 | LoaderError, 81 | > { 82 | let doc_stream = self.load().await?; 83 | let stream = process_doc_stream(doc_stream, splitter).await; 84 | Ok(Box::pin(stream)) 85 | } 86 | } 87 | 88 | #[cfg(test)] 89 | mod tests { 90 | use futures_util::StreamExt; 91 | 92 | use super::*; 93 | 94 | #[tokio::test] 95 | #[ignore] 96 | async fn git_commit_loader() { 97 | let git_commit_loader = GitCommitLoader::from_path("/code/langchain-rust").unwrap(); 98 | 99 | let documents = git_commit_loader 100 | .load() 101 | .await 102 | .unwrap() 103 | .map(|x| x.unwrap()) 104 | .collect::<Vec<_>>() 105 | .await; 106 | 107 | dbg!(&documents); 108 | // assert_eq!(documents[0].page_content, ""); 109 | todo!() 110 | } 111 | } 112 | --------------------------------------------------------------------------------