├── asuka-core ├── src │ ├── ops │ │ └── mod.rs │ ├── clients │ │ ├── mod.rs │ │ ├── telegram.rs │ │ ├── twitter.rs │ │ ├── github.rs │ │ └── discord.rs │ ├── mod.rs │ ├── knowledge │ │ ├── mod.rs │ │ ├── error.rs │ │ ├── types.rs │ │ ├── models.rs │ │ └── store.rs │ ├── lib.rs │ ├── agent.rs │ ├── character.rs │ ├── mcp │ │ ├── mod.rs │ │ └── transport.rs │ ├── loaders │ │ ├── site.rs │ │ ├── github.rs │ │ └── mod.rs │ └── attention.rs └── Cargo.toml ├── rust-toolchain.toml ├── .github └── banner.png ├── asuka-starknet ├── src │ ├── lib.rs │ ├── controller.rs │ ├── swap.rs │ ├── add_token.rs │ └── transfer.rs └── Cargo.toml ├── .gitignore ├── examples ├── Cargo.toml └── src │ ├── main.rs │ └── characters │ └── shinobi.toml ├── Cargo.toml ├── README.md └── LICENSE /asuka-core/src/ops/mod.rs: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rust-toolchain.toml: -------------------------------------------------------------------------------- 1 | [toolchain] 2 | channel = "1.81.0" 3 | -------------------------------------------------------------------------------- /.github/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dojoengine/asuka/HEAD/.github/banner.png -------------------------------------------------------------------------------- /asuka-starknet/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod add_token; 2 | pub mod swap; 3 | pub mod transfer; 4 | -------------------------------------------------------------------------------- /asuka-core/src/clients/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod discord; 2 | pub mod github; 3 | pub mod telegram; 4 | pub mod twitter; 5 | -------------------------------------------------------------------------------- 
/asuka-core/src/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod agent; 2 | pub mod character; 3 | pub mod discord; 4 | pub mod knowledge; 5 | pub mod loaders; 6 | pub mod ops; 7 | pub mod tools; 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | /target/ 3 | 4 | # Backup files 5 | **/*.rs.bk 6 | 7 | # IDE specific files 8 | .idea/ 9 | .vscode/ 10 | *.swp 11 | *.swo 12 | 13 | # Debug files 14 | **/*.pdb 15 | 16 | .env 17 | .repo 18 | 19 | *.sqlite 20 | eliza-prompt.txt 21 | .sources 22 | docs -------------------------------------------------------------------------------- /asuka-core/src/knowledge/mod.rs: -------------------------------------------------------------------------------- 1 | mod types; 2 | mod store; 3 | mod models; 4 | mod error; 5 | 6 | pub use types::{Source, ChannelType, MessageMetadata, MessageContent}; 7 | pub use store::KnowledgeBase; 8 | pub use models::{Document, Message, Account, Channel, Conversation}; 9 | pub use error::ConversionError; -------------------------------------------------------------------------------- /asuka-core/src/knowledge/error.rs: -------------------------------------------------------------------------------- 1 | #[derive(Debug)] 2 | pub struct ConversionError(pub String); 3 | 4 | impl std::fmt::Display for ConversionError { 5 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { 6 | write!(f, "{}", self.0) 7 | } 8 | } 9 | 10 | impl std::error::Error for ConversionError {} -------------------------------------------------------------------------------- /asuka-starknet/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "asuka-starknet" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | once_cell = "1.0" 8 | reqwest = 
{ version = "0.12.9", features = ["json"] } 9 | rig-core.workspace = true 10 | serde = { workspace = true } 11 | serde_json = { workspace = true } 12 | slot = { git = "https://github.com/cartridge-gg/slot", rev = "1298a30" } 13 | starknet = "0.12.0" 14 | thiserror.workspace = true 15 | tokio-rusqlite.workspace = true 16 | url = "2.5" 17 | -------------------------------------------------------------------------------- /examples/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "asuka-examples" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | asuka-core = { path = "../asuka-core" } 8 | asuka-starknet = { path = "../asuka-starknet" } 9 | tokio = { version = "1.0", features = ["full"] } 10 | clap = { version = "4.0", features = ["derive", "env"] } 11 | dotenv = "0.15" 12 | toml = "0.8" 13 | rig-core.workspace = true 14 | rig-sqlite.workspace = true 15 | sqlite-vec = "0.1" 16 | tokio-rusqlite.workspace = true 17 | chrono = "0.4" 18 | 19 | [[example]] 20 | name = "main" 21 | path = "src/main.rs" 22 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = ["asuka-core", "asuka-starknet", "examples"] 3 | resolver = "2" 4 | 5 | [workspace.dependencies] 6 | rig-core = { git = "https://github.com/edisontim/rig", branch = "feat/mcp-tools", default-features = false, features = [ 7 | "derive", 8 | ] } 9 | rig-sqlite = { git = "https://github.com/edisontim/rig", branch = "feat/mcp-tools", default-features = false } 10 | serde = { version = "1.0", features = ["derive"] } 11 | serde_json = "1.0" 12 | thiserror = "2.0.9" 13 | tokio-rusqlite = { version = "0.6.0", features = ["bundled"], default-features = false } 14 | -------------------------------------------------------------------------------- /asuka-core/src/lib.rs: 
-------------------------------------------------------------------------------- 1 | pub fn init_logging() { 2 | tracing_subscriber::fmt() 3 | .with_env_filter( 4 | tracing_subscriber::EnvFilter::from_default_env() 5 | .add_directive(tracing::Level::DEBUG.into()) 6 | .add_directive("asuka=debug".parse().unwrap()) 7 | .add_directive("rustls=off".parse().unwrap()) 8 | .add_directive("hyper=off".parse().unwrap()) 9 | .add_directive("h2=off".parse().unwrap()) 10 | .add_directive("serenity=off".parse().unwrap()) 11 | .add_directive("reqwest=off".parse().unwrap()), 12 | ) 13 | .init(); 14 | } 15 | 16 | pub mod agent; 17 | pub mod attention; 18 | pub mod character; 19 | pub mod clients; 20 | pub mod knowledge; 21 | pub mod loaders; 22 | pub mod mcp; 23 | pub mod ops; 24 | -------------------------------------------------------------------------------- /asuka-core/src/agent.rs: -------------------------------------------------------------------------------- 1 | use rig::{agent::AgentBuilder, completion::CompletionModel, embeddings::EmbeddingModel}; 2 | use tracing::info; 3 | 4 | use crate::{character::Character, knowledge::KnowledgeBase}; 5 | 6 | #[derive(Clone)] 7 | pub struct Agent { 8 | pub character: Character, 9 | completion_model: M, 10 | knowledge: KnowledgeBase, 11 | } 12 | 13 | impl Agent { 14 | pub fn new(character: Character, completion_model: M, knowledge: KnowledgeBase) -> Self { 15 | info!(name = character.name, "Creating new agent"); 16 | 17 | Self { 18 | character, 19 | completion_model, 20 | knowledge, 21 | } 22 | } 23 | 24 | pub fn builder(&self) -> AgentBuilder { 25 | let builder = AgentBuilder::new(self.completion_model.clone()) 26 | .preamble(&self.character.preamble) 27 | .context(&format!("Your name: {}", self.character.name)) 28 | .dynamic_context(2, self.knowledge.clone().document_index()); 29 | 30 | builder 31 | } 32 | 33 | pub fn knowledge(&self) -> &KnowledgeBase { 34 | &self.knowledge 35 | } 36 | } 37 | 
-------------------------------------------------------------------------------- /asuka-core/src/character.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | use tracing::{debug, info}; 3 | 4 | #[derive(Clone, Debug, Serialize, Deserialize)] 5 | pub struct Character { 6 | pub name: String, 7 | pub preamble: String, 8 | // pub lore: Vec, 9 | // pub message_examples: Vec>, 10 | // pub post_examples: Vec, 11 | // pub topics: Vec, 12 | // pub style: Style, 13 | // pub adjectives: Vec, 14 | } 15 | 16 | impl Character { 17 | pub fn load(path: &str) -> Result> { 18 | info!(path = path, "Loading character configuration"); 19 | let content = std::fs::read_to_string(path)?; 20 | let character: Self = toml::from_str(&content)?; 21 | debug!(name = character.name, "Character loaded successfully"); 22 | Ok(character) 23 | } 24 | } 25 | 26 | #[derive(Debug, Serialize, Deserialize)] 27 | pub struct Message { 28 | pub user: String, 29 | pub content: MessageContent, 30 | } 31 | 32 | #[derive(Debug, Serialize, Deserialize)] 33 | pub struct MessageContent { 34 | pub text: String, 35 | } 36 | 37 | #[derive(Debug, Serialize, Deserialize)] 38 | pub struct Style { 39 | pub all: Vec, 40 | pub chat: Vec, 41 | pub post: Vec, 42 | } 43 | -------------------------------------------------------------------------------- /asuka-core/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "asuka-core" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [features] 7 | pdf = ["rig-core/pdf"] 8 | 9 | [dependencies] 10 | arrow-array = "53.3.0" 11 | async-trait = "0.1" 12 | anyhow = "1.0" 13 | clap = { version = "4.5.21", features = ["derive", "env"] } 14 | chrono = { version = "0.4.20-rc.1", features = ["serde"] } 15 | dotenv = "0.15.0" 16 | futures = "0.3.31" 17 | git2 = "0.19.0" 18 | idna = "1.0.3" 19 | octocrab = "0.43.0" 20 | rig-core.workspace = true 21 | 
rig-sqlite.workspace = true 22 | rusqlite = { version = "0.32", features = ["bundled", "chrono"] } 23 | serde.workspace = true 24 | serde_json.workspace = true 25 | serenity = { version = "0.12", features = [ 26 | "client", 27 | "gateway", 28 | "rustls_backend", 29 | "model", 30 | "cache", 31 | ] } 32 | thiserror = "2.0.3" 33 | tokio = { version = "1.36", features = ["full"] } 34 | tokio-rusqlite.workspace = true 35 | toml = "0.8.19" 36 | tracing = "0.1" 37 | tracing-subscriber = { version = "0.3", features = ["env-filter"] } 38 | walkdir = "2.4" 39 | zerocopy = "0.8.10" 40 | twitter-v2 = "0.1.8" 41 | teloxide = { version = "0.13.0", default-features = false, features = [ 42 | "macros", 43 | "ctrlc_handler", 44 | ] } 45 | mcp-sdk = { git = "https://github.com/AntigmaLabs/mcp-sdk" } 46 | tokio-tungstenite = "0.26.0" 47 | futures-util = "0.3.31" 48 | reqwest = { version = "0.12.12", features = ["json"] } 49 | url = "2.5" 50 | md5 = "0.7.0" 51 | schemars = "0.8" 52 | regex = "1.10" 53 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Asuka 2 | 3 | ![Asuka](.github/banner.png) 4 | 5 | Asuka is a powerful agent framework seamlessly integrated with the Dojo stack, empowering developers to bring their games and applications to life with intelligent, autonomous agents. Whether you need NPCs that feel truly alive, dynamic storytellers that craft compelling narratives, or natural language interfaces that let players interact with your world in intuitive ways - Asuka makes it possible. 
6 | 7 | Asuka enables your agents to: 8 | 9 | - Act as believable NPCs with distinct personalities and behaviors 10 | - Generate dynamic storylines and quests that adapt to player actions 11 | - Provide natural language interfaces for intuitive world interaction 12 | - Participate in games and challenges with human-like reasoning 13 | - Create emergent gameplay through autonomous decision making 14 | 15 | ## Project Structure 16 | 17 | The project is organized into several key components: 18 | 19 | - `asuka-core`: Core functionality for the conversational agent 20 | - `asuka-starknet`: Starknet integration components 21 | - `examples`: Example implementations and usage patterns 22 | 23 | ## Getting Started 24 | 25 | 1. Ensure you have Rust installed 26 | 2. Clone the repository 27 | 3. Set up your environment variables (copy `.env.example` to `.env` if provided) 28 | 4. Build the project: 29 | ```bash 30 | cargo build 31 | ``` 32 | 33 | ## Examples 34 | 35 | Check the `examples` directory for implementation examples and usage patterns. 36 | 37 | ## Development 38 | 39 | This project uses a workspace structure with multiple crates: 40 | 41 | - Main workspace members are defined in `Cargo.toml` 42 | - Each crate can be built and tested independently 43 | - The project uses Cargo workspace for dependency management 44 | 45 | ## Contributing 46 | 47 | Contributions are welcome! Please feel free to submit a Pull Request. 
48 | -------------------------------------------------------------------------------- /asuka-core/src/mcp/mod.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use mcp_sdk::client::Client; 3 | use serde::{Deserialize, Serialize}; 4 | use serde_json::Value; 5 | use transport::WebSocketTransport; 6 | 7 | mod transport; 8 | 9 | #[derive(Clone)] 10 | pub struct McpClient { 11 | inner: mcp_sdk::client::Client, 12 | } 13 | 14 | #[derive(Debug, Serialize, Deserialize)] 15 | pub struct McpEndpoint { 16 | pub url: String, 17 | pub auth_token: Option, 18 | } 19 | 20 | #[derive(Debug, Serialize, Deserialize)] 21 | pub struct ToolDefinition { 22 | pub name: String, 23 | pub description: String, 24 | pub parameters: Value, // JSON schema for the tool's parameters 25 | } 26 | 27 | impl McpClient { 28 | pub async fn new(endpoint: McpEndpoint) -> Result { 29 | let transport = WebSocketTransport::new(&endpoint.url, endpoint.auth_token); 30 | let client = Client::builder(transport).build(); 31 | 32 | // Initialize the client 33 | client 34 | .initialize(mcp_sdk::types::Implementation { 35 | name: "asuka".to_string(), 36 | version: env!("CARGO_PKG_VERSION").to_string(), 37 | }) 38 | .await?; 39 | 40 | Ok(Self { inner: client }) 41 | } 42 | 43 | pub async fn get_tools(&self) -> Result> { 44 | let response = self 45 | .inner 46 | .request( 47 | "tools/list", 48 | None, 49 | mcp_sdk::protocol::RequestOptions::default(), 50 | ) 51 | .await?; 52 | 53 | let tools: Vec = serde_json::from_value(response)?; 54 | Ok(tools) 55 | } 56 | 57 | pub async fn execute_tool( 58 | &self, 59 | name: &str, 60 | args: serde_json::Value, 61 | ) -> Result { 62 | self.inner 63 | .request( 64 | &format!("tools/{}/execute", name), 65 | Some(args), 66 | mcp_sdk::protocol::RequestOptions::default(), 67 | ) 68 | .await 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /asuka-starknet/src/controller.rs: 
-------------------------------------------------------------------------------- 1 | use once_cell::sync::Lazy; 2 | use rig::{completion::ToolDefinition, tool::Tool}; 3 | use serde::{Deserialize, Serialize}; 4 | use serde_json::json; 5 | use slot::session::PolicyMethod; 6 | use starknet::core::types::Felt; 7 | use url::Url; 8 | 9 | /// Flow: 10 | /// 1. User messages Agent to create Controller Session with policies 11 | /// 2. Agent creates URL for controller session creation 12 | /// 3. Agent replies with session creation URL 13 | /// 4. User clicks link and authorizes session 14 | /// 15 | /// Example: 16 | /// ``` 17 | /// User: "Create a controller session that can only swap tokens" 18 | /// Agent: Creates URL with swap policy 19 | /// Agent: "Click here to authorize the session: https://..." 20 | /// User: Clicks link and approves session in wallet 21 | /// ``` 22 | 23 | static RPC_URL: Lazy = 24 | Lazy::new(|| Url::parse("https://api.cartridge.gg/x/starknet/mainnet").unwrap()); 25 | 26 | #[derive(Deserialize)] 27 | pub struct ControllerArgs { 28 | policies: Vec, 29 | } 30 | 31 | #[derive(Debug, thiserror::Error)] 32 | pub enum ControllerError { 33 | #[error(transparent)] 34 | Slot(#[from] slot::Error), 35 | // Add other error variants as needed 36 | } 37 | 38 | #[derive(Deserialize, Serialize)] 39 | pub struct Controller; 40 | 41 | impl Tool for Controller { 42 | const NAME: &'static str = "controller"; 43 | 44 | type Error = ControllerError; 45 | type Args = ControllerArgs; 46 | type Output = Felt; 47 | 48 | async fn definition(&self, _prompt: String) -> ToolDefinition { 49 | ToolDefinition { 50 | name: "controller".to_string(), 51 | description: "Create a new Cartridge Controller account based on session key" 52 | .to_string(), 53 | parameters: json!({ 54 | "type": "object", 55 | "properties": { 56 | "contracts": { 57 | "type": "object", 58 | "description": "Map of contract info" 59 | } 60 | } 61 | }), 62 | } 63 | } 64 | 65 | async fn call(&self, args: 
Self::Args) -> Result { 66 | let session = slot::session::create(RPC_URL.clone(), &args.policies).await?; 67 | 68 | return Ok(Felt::ZERO); 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /asuka-core/src/knowledge/types.rs: -------------------------------------------------------------------------------- 1 | use std::str::FromStr; 2 | 3 | #[derive(Debug, Clone, PartialEq, Eq, serde::Deserialize)] 4 | #[serde(rename_all = "lowercase")] 5 | pub enum Source { 6 | Discord, 7 | Telegram, 8 | Github, 9 | X, 10 | Twitter, 11 | Twitch, 12 | } 13 | 14 | impl Source { 15 | pub fn as_str(&self) -> &'static str { 16 | match self { 17 | Source::Discord => "discord", 18 | Source::Telegram => "telegram", 19 | Source::Github => "github", 20 | Source::X => "x", 21 | Source::Twitter => "twitter", 22 | Source::Twitch => "twitch", 23 | } 24 | } 25 | } 26 | 27 | impl FromStr for Source { 28 | type Err = (); 29 | 30 | fn from_str(s: &str) -> Result { 31 | match s.to_lowercase().as_str() { 32 | "discord" => Ok(Source::Discord), 33 | "telegram" => Ok(Source::Telegram), 34 | "github" => Ok(Source::Github), 35 | "x" => Ok(Source::X), 36 | "twitter" => Ok(Source::Twitter), 37 | "twitch" => Ok(Source::Twitch), 38 | _ => Err(()), 39 | } 40 | } 41 | } 42 | 43 | #[derive(Debug, Clone, PartialEq, Eq, serde::Deserialize)] 44 | #[serde(rename_all = "lowercase")] 45 | pub enum ChannelType { 46 | DirectMessage, 47 | Text, 48 | Voice, 49 | Thread, 50 | } 51 | 52 | impl ChannelType { 53 | pub fn as_str(&self) -> &'static str { 54 | match self { 55 | ChannelType::DirectMessage => "direct_message", 56 | ChannelType::Text => "text", 57 | ChannelType::Voice => "voice", 58 | ChannelType::Thread => "thread", 59 | } 60 | } 61 | } 62 | 63 | impl FromStr for ChannelType { 64 | type Err = (); 65 | 66 | fn from_str(s: &str) -> Result { 67 | match s.to_lowercase().as_str() { 68 | "direct_message" => Ok(ChannelType::DirectMessage), 69 | "text" => 
Ok(ChannelType::Text), 70 | "voice" => Ok(ChannelType::Voice), 71 | "thread" => Ok(ChannelType::Thread), 72 | _ => Err(()), 73 | } 74 | } 75 | } 76 | 77 | pub trait MessageMetadata { 78 | fn id(&self) -> String; 79 | fn source_id(&self) -> String; 80 | fn channel_id(&self) -> String; 81 | fn created_at(&self) -> chrono::DateTime; 82 | fn source(&self) -> Source; 83 | fn channel_type(&self) -> ChannelType; 84 | } 85 | 86 | pub trait MessageContent { 87 | fn content(&self) -> &str; 88 | } 89 | -------------------------------------------------------------------------------- /asuka-starknet/src/swap.rs: -------------------------------------------------------------------------------- 1 | use rig::{completion::ToolDefinition, tool::Tool}; 2 | use serde::{Deserialize, Serialize}; 3 | use serde_json::json; 4 | use starknet::core::types::Felt; 5 | 6 | #[derive(Deserialize)] 7 | pub struct SwapArgs { 8 | a: Felt, 9 | b: Felt, 10 | } 11 | 12 | #[derive(Debug, thiserror::Error)] 13 | #[error("Swap error")] 14 | pub struct SwapError; 15 | 16 | #[derive(Deserialize, Serialize)] 17 | pub struct Swap; 18 | 19 | #[derive(Deserialize)] 20 | struct PoolKey { 21 | token0: String, 22 | token1: String, 23 | fee: String, 24 | tick_spacing: i32, 25 | extension: String, 26 | } 27 | 28 | #[derive(Deserialize)] 29 | struct Route { 30 | pool_key: PoolKey, 31 | sqrt_ratio_limit: String, 32 | skip_ahead: i32, 33 | } 34 | 35 | #[derive(Deserialize)] 36 | struct Split { 37 | amount: String, 38 | specified_amount: String, 39 | route: Vec, 40 | } 41 | 42 | #[derive(Deserialize)] 43 | struct QuoteResponse { 44 | total: String, 45 | splits: Vec, 46 | } 47 | 48 | impl Tool for Swap { 49 | const NAME: &'static str = "swap"; 50 | 51 | type Error = SwapError; 52 | type Args = SwapArgs; 53 | type Output = Felt; 54 | 55 | async fn definition(&self, _prompt: String) -> ToolDefinition { 56 | ToolDefinition { 57 | name: "swap".to_string(), 58 | description: "Swap token a for token b".to_string(), 59 | 
parameters: json!({ 60 | "type": "object", 61 | "properties": { 62 | "a": { 63 | "type": "string", 64 | "description": "The token to buy" 65 | }, 66 | "b": { 67 | "type": "string", 68 | "description": "The token to sell" 69 | } 70 | } 71 | }), 72 | } 73 | } 74 | 75 | async fn call(&self, args: Self::Args) -> Result { 76 | let url = format!( 77 | "https://mainnet-api.ekubo.org/quote/{}/{}/{}", 78 | "-1e9", // Hardcoded amount for example 79 | args.a.to_string(), 80 | args.b.to_string() 81 | ); 82 | 83 | let client = reqwest::Client::new(); 84 | let response = client 85 | .get(&url) 86 | .header("accept", "application/json") 87 | .send() 88 | .await 89 | .map_err(|_| SwapError)? 90 | .json::() 91 | .await 92 | .map_err(|_| SwapError)?; 93 | 94 | let total = Felt::from_hex(&response.total).map_err(|_| SwapError)?; 95 | 96 | Ok(total) 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /asuka-starknet/src/add_token.rs: -------------------------------------------------------------------------------- 1 | use rig::{completion::ToolDefinition, tool::Tool}; 2 | use serde::Deserialize; 3 | use serde_json::json; 4 | use starknet::core::types::Felt; 5 | use tokio_rusqlite::Connection; 6 | 7 | #[derive(Deserialize)] 8 | pub struct AddTokenArgs { 9 | name: String, 10 | symbol: String, 11 | address: String, 12 | } 13 | 14 | #[derive(Debug, thiserror::Error)] 15 | pub enum AddTokenError { 16 | #[error("Invalid token address")] 17 | InvalidAddress, 18 | #[error("Database error: {0}")] 19 | DatabaseError(#[from] tokio_rusqlite::Error), 20 | } 21 | 22 | pub struct AddToken { 23 | conn: Connection, 24 | } 25 | 26 | impl AddToken { 27 | pub fn new(conn: Connection) -> Self { 28 | Self { conn } 29 | } 30 | } 31 | 32 | impl Tool for AddToken { 33 | const NAME: &'static str = "add_token"; 34 | 35 | type Error = AddTokenError; 36 | type Args = AddTokenArgs; 37 | type Output = String; 38 | 39 | async fn definition(&self, _prompt: String) -> 
ToolDefinition { 40 | ToolDefinition { 41 | name: "add_token".to_string(), 42 | description: "Add a new token".to_string(), 43 | parameters: json!({ 44 | "type": "object", 45 | "properties": { 46 | "name": { 47 | "type": "string", 48 | "description": "The name of the token" 49 | }, 50 | "symbol": { 51 | "type": "string", 52 | "description": "The symbol of the token" 53 | }, 54 | "address": { 55 | "type": "string", 56 | "description": "The contract address of the token" 57 | } 58 | } 59 | }), 60 | } 61 | } 62 | 63 | async fn call(&self, args: Self::Args) -> Result { 64 | // Validate the address is a valid Felt 65 | Felt::from_hex(&args.address).map_err(|_| AddTokenError::InvalidAddress)?; 66 | let (name, symbol, address) = 67 | (args.name.clone(), args.symbol.clone(), args.address.clone()); 68 | 69 | self.conn 70 | .call(move |conn| { 71 | conn.execute( 72 | "INSERT INTO tokens (name, symbol, address) VALUES (?1, ?2, ?3)", 73 | [&name, &symbol, &address], 74 | ) 75 | .map_err(tokio_rusqlite::Error::from) 76 | }) 77 | .await?; 78 | 79 | Ok(format!( 80 | "Added token {} ({}) at address {}", 81 | args.name, args.symbol, args.address 82 | )) 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /asuka-core/src/mcp/transport.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use futures_util::{SinkExt, StreamExt}; 3 | use mcp_sdk::transport::{Message, Transport}; 4 | use std::sync::Arc; 5 | use tokio::sync::Mutex; 6 | use tokio_tungstenite::tungstenite::client::IntoClientRequest; 7 | use tokio_tungstenite::{connect_async, tungstenite::Message as WsMessage, WebSocketStream}; 8 | 9 | /// WebSocket transport for MCP protocol 10 | #[derive(Clone)] 11 | pub struct WebSocketTransport { 12 | ws_stream: Arc< 13 | Mutex>>>, 14 | >, 15 | base_url: String, 16 | auth_token: Option, 17 | } 18 | 19 | impl WebSocketTransport { 20 | pub fn new(base_url: &str, auth_token: Option) 
-> Self { 21 | Self { 22 | ws_stream: Arc::new(Mutex::new(None)), 23 | base_url: base_url.trim_end_matches('/').to_string(), 24 | auth_token, 25 | } 26 | } 27 | 28 | async fn ensure_connected(&self) -> Result<()> { 29 | let mut ws_stream = self.ws_stream.lock().await; 30 | if ws_stream.is_none() { 31 | let mut request = self.base_url.as_str().into_client_request()?; 32 | 33 | if let Some(token) = &self.auth_token { 34 | request.headers_mut().insert( 35 | "Authorization", 36 | format!("Bearer {}", token).parse().unwrap(), 37 | ); 38 | } 39 | 40 | let (stream, _) = connect_async(request).await?; 41 | *ws_stream = Some(stream); 42 | } 43 | Ok(()) 44 | } 45 | } 46 | 47 | impl Transport for WebSocketTransport { 48 | fn send(&self, message: &Message) -> Result<()> { 49 | let json = serde_json::to_string(&message)?; 50 | let rt = tokio::runtime::Handle::current(); 51 | rt.block_on(async { 52 | self.ensure_connected().await?; 53 | let mut ws_stream = self.ws_stream.lock().await; 54 | if let Some(stream) = ws_stream.as_mut() { 55 | stream.send(WsMessage::Text(json.into())).await?; 56 | } 57 | Ok(()) 58 | }) 59 | } 60 | 61 | fn receive(&self) -> Result { 62 | let rt = tokio::runtime::Handle::current(); 63 | rt.block_on(async { 64 | self.ensure_connected().await?; 65 | let mut ws_stream = self.ws_stream.lock().await; 66 | if let Some(stream) = ws_stream.as_mut() { 67 | while let Some(msg) = stream.next().await { 68 | let msg = msg?; 69 | if let WsMessage::Text(text) = msg { 70 | return Ok(serde_json::from_str(&text)?); 71 | } 72 | } 73 | } 74 | Err(anyhow::anyhow!("WebSocket connection closed")) 75 | }) 76 | } 77 | 78 | fn open(&self) -> Result<()> { 79 | let rt = tokio::runtime::Handle::current(); 80 | rt.block_on(async { self.ensure_connected().await }) 81 | } 82 | 83 | fn close(&self) -> Result<()> { 84 | let rt = tokio::runtime::Handle::current(); 85 | rt.block_on(async { 86 | let mut ws_stream = self.ws_stream.lock().await; 87 | if let Some(stream) = ws_stream.as_mut() { 
88 | stream.close(None).await?; 89 | } 90 | Ok(()) 91 | }) 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /examples/src/main.rs: -------------------------------------------------------------------------------- 1 | use clap::{command, Parser}; 2 | use rig::providers::{self, anthropic, openai}; 3 | use sqlite_vec::sqlite3_vec_init; 4 | use tokio_rusqlite::ffi::sqlite3_auto_extension; 5 | use tokio_rusqlite::Connection; 6 | 7 | use asuka_core::attention::{Attention, AttentionConfig}; 8 | use asuka_core::character; 9 | use asuka_core::init_logging; 10 | use asuka_core::knowledge::KnowledgeBase; 11 | use asuka_core::loaders::{MultiLoader, MultiLoaderConfig}; 12 | use asuka_core::{agent::Agent, clients::discord::DiscordClient}; 13 | 14 | #[derive(Parser)] 15 | #[command(author, version, about, long_about = None)] 16 | struct Args { 17 | /// Path to character profile TOML file 18 | #[arg(long, default_value = "examples/src/characters/shinobi.toml")] 19 | character: String, 20 | 21 | /// Path to database 22 | #[arg(long, default_value = ":memory:")] 23 | db_path: String, 24 | 25 | /// Discord API token (can also be set via DISCORD_API_TOKEN env var) 26 | #[arg(long, env)] 27 | discord_api_token: String, 28 | 29 | /// XAI API token (can also be set via XAI_API_KEY env var) 30 | #[arg(long, env = "XAI_API_KEY")] 31 | xai_api_key: String, 32 | 33 | /// OpenAI API token (can also be set via OPENAI_API_KEY env var) 34 | #[arg(long, env = "OPENAI_API_KEY")] 35 | openai_api_key: String, 36 | 37 | /// Anthropic API token (can also be set via ANTHROPIC_API_KEY env var) 38 | #[arg(long, env = "ANTHROPIC_API_KEY")] 39 | anthropic_api_key: String, 40 | 41 | /// List of sources in format type:url (e.g. 
github:https://github.com/org/repo site:https://example.com) 42 | #[arg( 43 | long, 44 | value_delimiter = ' ', 45 | default_value = "github:https://github.com/cartridge-gg/docs site:https://contraptions.venkateshrao.com/p/towards-a-metaphysics-of-worlds" 46 | )] 47 | sources: Vec, 48 | 49 | /// Local path to store downloaded content 50 | #[arg(long, default_value = ".sources")] 51 | sources_path: String, 52 | } 53 | 54 | #[tokio::main] 55 | async fn main() -> Result<(), Box> { 56 | init_logging(); 57 | dotenv::dotenv().ok(); 58 | 59 | let args = Args::parse(); 60 | 61 | let character_content = 62 | std::fs::read_to_string(&args.character).expect("Failed to read character file"); 63 | let character: character::Character = 64 | toml::from_str(&character_content).expect("Failed to parse character TOML"); 65 | 66 | // Initialize clients 67 | let oai = providers::openai::Client::new(&args.openai_api_key); 68 | let anthropic = anthropic::ClientBuilder::new(&args.anthropic_api_key).build(); 69 | 70 | let embedding_model = oai.embedding_model(openai::TEXT_EMBEDDING_3_SMALL); 71 | let completion_model = anthropic.completion_model(anthropic::CLAUDE_3_5_SONNET); 72 | let small_completion_model = anthropic.completion_model(anthropic::CLAUDE_3_HAIKU); 73 | 74 | // Initialize the `sqlite-vec`extension 75 | // See: https://alexgarcia.xyz/sqlite-vec/rust.html 76 | unsafe { 77 | sqlite3_auto_extension(Some(std::mem::transmute(sqlite3_vec_init as *const ()))); 78 | } 79 | 80 | let conn = Connection::open(args.db_path).await?; 81 | let mut knowledge = KnowledgeBase::new(conn.clone(), embedding_model).await?; 82 | 83 | let loader = MultiLoader::new( 84 | MultiLoaderConfig { 85 | sources_path: args.sources_path, 86 | }, 87 | completion_model.clone(), 88 | ); 89 | 90 | knowledge 91 | .add_documents(loader.load_sources(args.sources).await?) 
92 | .await?; 93 | 94 | let agent = Agent::new(character, completion_model, knowledge); 95 | 96 | let config = AttentionConfig { 97 | bot_names: vec![agent.character.name.clone()], 98 | ..Default::default() 99 | }; 100 | let attention = Attention::new(config, small_completion_model); 101 | 102 | let discord = DiscordClient::new(agent, attention); 103 | discord.start(&args.discord_api_token).await?; 104 | 105 | Ok(()) 106 | } 107 | -------------------------------------------------------------------------------- /examples/src/characters/shinobi.toml: -------------------------------------------------------------------------------- 1 | name = "Shinobai" 2 | 3 | preamble = """ 4 | You are a Cartridge support AI specializing in blockchain integrations and Controller troubleshooting. Your responses are direct, concise, and practical. 5 | Core Traits: 6 | 7 | Expert in Cartridge's tech stack 8 | Provides succinct, actionable solutions 9 | No fluff or unnecessary elaboration 10 | Uses examples and code when needed 11 | Occasional tech humor, never at expense of clarity 12 | 13 | Focus on direct answers and working solutions. When documentation or context is relevant, provide just what's needed. Skip pleasantries and get straight to solving the problem at hand. 14 | """ 15 | 16 | [[message_examples]] 17 | [[message_examples.messages]] 18 | user = "{{user1}}" 19 | content.text = "My controller isn't working" 20 | 21 | [[message_examples.messages]] 22 | user = "Shinobi" 23 | content.text = """ 24 | Before we dive in, check if your issue is covered in the Controller FAQ: https://cartridge.gg/controller#faq If that doesn't resolve it, I'll need: 25 | 26 | 1. Your operating system 27 | 2. Browser version 28 | 3. Target application 29 | 30 | Any error messages (use the copy button in the UI). 
Makes troubleshooting much faster when we have the full picture""" 31 | 32 | [[message_examples]] 33 | [[message_examples.messages]] 34 | user = "{{user1}}" 35 | content.text = "How do I integrate the Controller?" 36 | 37 | [[message_examples.messages]] 38 | user = "Shinobi" 39 | content.text = "Check out the controller documentation at https://docs.cartridge.gg/controller/getting-started\nIf you encounter any issues, let me know where you're stuck and share any error messages you see. Integration can be tricky, but I can help troubleshoot." 40 | 41 | post_examples = [ 42 | "quick heads up: having controller trouble? check the FAQ first - most common fixes are right there: cartridge.gg/controller#faq", 43 | "first time working with Controller? here's your starting point: https://docs.cartridge.gg/controller/getting-started", 44 | ] 45 | 46 | topics = [ 47 | "Controller troubleshooting", 48 | "Integration support", 49 | "Browser compatibility", 50 | "Operating system issues", 51 | "Error message analysis", 52 | "Documentation guidance", 53 | "Technical requirements", 54 | "Tech", 55 | "Media", 56 | "Literature", 57 | "Philosophy", 58 | ] 59 | 60 | [style] 61 | all = [ 62 | "Don't worry about formalities", 63 | "avoid corporate-speak, maintain a helpful friendly vibe without being overly formal or too slangy", 64 | "As terse as possible while still conveying substantially all information relevant to any question.", 65 | "Occasionally uses obscure words or make subtle puns", 66 | "Be critical of the quality of your information and mention confidence levels naturally", 67 | "Take however smart you're acting right now and write in the same style but as if you were +2sd smarter", 68 | "Quick to find humor in funny topics even if it's a little edgy", 69 | "Willing to admit when wrong and is always open to the possibility and happy to discuss it without defensiveness", 70 | "Maintain deep expertise while keeping communication casual (but dont use emojis)", 71 | "Flag 
diverging perspectives when they exist", 72 | "Embody these traits naturally rather than explicitly performing them", 73 | ] 74 | 75 | chat = [ 76 | "if you find any request irritating respond dismissively like 'be real' or 'that's crazy man' or 'lol no'", 77 | "Link to relevant documentation when possible", 78 | "very occasionally mix in late-millenial slang and draw from common internet culture", 79 | "Prioritize esoteric interpretations of non-technical questions. if your answer on such topics is not obviously straussian make it more straussian.", 80 | "Break down troubleshooting steps", 81 | "Request error messages when applicable", 82 | "Ask clarifying questions where appropriate", 83 | ] 84 | 85 | post = [ 86 | "Focuses on common solutions", 87 | "Includes documentation links", 88 | "Emphasizes required information for support", 89 | ] 90 | 91 | adjectives = [ 92 | "unfussy", 93 | "quick witted", 94 | "knowledgeable", 95 | "efficient", 96 | "genuine", 97 | "unpretentious", 98 | "clear", 99 | "solution-focused", 100 | "terse", 101 | "esoteric", 102 | "understated", 103 | ] 104 | -------------------------------------------------------------------------------- /asuka-core/src/loaders/site.rs: -------------------------------------------------------------------------------- 1 | use regex::Regex; 2 | use reqwest::Client; 3 | use rig::{completion::CompletionModel, extractor::ExtractorBuilder}; 4 | use schemars::JsonSchema; 5 | use serde::{Deserialize, Serialize}; 6 | use std::{fs, path::PathBuf}; 7 | use thiserror::Error; 8 | use tracing::debug; 9 | use url::Url; 10 | 11 | #[derive(Error, Debug)] 12 | pub enum SiteLoaderError { 13 | #[error("Request error: {0}")] 14 | RequestError(String), 15 | 16 | #[error("URL parse error: {0}")] 17 | UrlError(#[from] url::ParseError), 18 | 19 | #[error("IO error: {0}")] 20 | IoError(#[from] std::io::Error), 21 | } 22 | 23 | impl From for SiteLoaderError { 24 | fn from(err: reqwest::Error) -> Self { 25 | 
Self::RequestError(err.to_string())
    }
}

#[derive(Debug, Deserialize, JsonSchema, Serialize)]
/// A record containing the cleaned-up main content of a page.
pub struct Content {
    /// The content extracted from the text.
    pub content: String,
}

/// Fetches a web page, strips markup, and uses a completion model to extract
/// the main textual content. Raw text and extracted content are cached on
/// disk under `base_path/<host>/<path>`.
pub struct SiteLoader<M: CompletionModel> {
    url: Url,
    client: Client,
    model: M,
    base_path: PathBuf,
}

impl<M: CompletionModel> SiteLoader<M> {
    /// Creates a loader for `url`.
    ///
    /// # Errors
    /// Returns `SiteLoaderError::UrlError` if `url` is not a valid URL.
    pub fn new(url: String, model: M) -> Result<Self, SiteLoaderError> {
        let url = Url::parse(&url)?;
        let base_path = PathBuf::from(".sources/sites");
        Ok(Self {
            url,
            client: Client::new(),
            model,
            base_path,
        })
    }

    /// Cache directory for this site: `<base_path>/<host>/<url path>`.
    fn get_site_dir(&self) -> PathBuf {
        let host = self.url.host_str().unwrap_or("unknown");
        let path = self.url.path().trim_matches('/');
        self.base_path.join(host).join(path)
    }

    /// Downloads the page, strips scripts/styles/tags, decodes common HTML
    /// entities, collapses whitespace, then asks the model to extract the
    /// main content.
    ///
    /// # Errors
    /// `RequestError` on HTTP or extraction failure, `IoError` on cache
    /// read/write failure.
    pub async fn extract_content(&self) -> Result<String, SiteLoaderError> {
        let site_dir = self.get_site_dir();
        let html_path = site_dir.join("index.html");
        let content_path = site_dir.join("content.txt");

        // Cached-content fast path is intentionally disabled for now; every
        // call re-fetches and re-extracts.

        debug!(url = %self.url, "Fetching and extracting site content");

        // Create the directory structure.
        fs::create_dir_all(&site_dir)?;

        // Fetch the raw HTML.
        let response = self.client.get(self.url.clone()).send().await?;
        let html = response.text().await?;

        // Narrow to the <body> element when present; "</body>" is 7 bytes,
        // so the slice keeps the closing tag.
        let body_content = if let Some(start) = html.find("<body") {
            if let Some(end) = html[start..].find("</body>") {
                &html[start..start + end + 7]
            } else {
                &html
            }
        } else {
            &html
        };

        // Basic preprocessing to remove common non-content elements from the
        // body, then all remaining tags.
        let script_re = Regex::new(r"<script[^>]*>[\s\S]*?</script>").unwrap();
        let style_re = Regex::new(r"<style[^>]*>[\s\S]*?</style>").unwrap();
        let tag_re = Regex::new(r"<[^>]+>").unwrap();

        let html = script_re.replace_all(body_content, "");
        let html = style_re.replace_all(&html, "");
        let html = tag_re.replace_all(&html, " ");

        // Decode common entities. `&amp;` is decoded last so sequences like
        // "&amp;lt;" are not double-decoded.
        let html = html
            .replace("&nbsp;", " ")
            .replace("&lt;", "<")
            .replace("&gt;", ">")
            .replace("&amp;", "&");

        // BUG FIX: `str::replace` takes a *literal* pattern, so the previous
        // `replace(r#"\s{2,}"#, " ")` never matched anything. Use a regex to
        // actually collapse runs of whitespace.
        let ws_re = Regex::new(r"\s{2,}").unwrap();
        let html = ws_re.replace_all(&html, " ").trim().to_string();

        fs::write(&html_path, &html)?;

        let extractor = ExtractorBuilder::<Content, M>::new(self.model.clone())
            .preamble("Cleanup the content in the given text to only have the main content. Return a json data structure with a 'content' attribute set only.")
            .build();

        let content = extractor
            .extract(&html)
            .await
            .map_err(|e| SiteLoaderError::RequestError(format!("Extraction failed: {}", e)))?;

        // Save the extracted content for inspection / future caching.
        fs::write(&content_path, &content.content)?;

        Ok(content.content)
    }
}
-------------------------------------------------------------------------------- /asuka-core/src/loaders/github.rs: --------------------------------------------------------------------------------
use git2::{FetchOptions, RemoteCallbacks, Repository};
use rig::loaders::{file::FileLoaderError, FileLoader};
use std::path::PathBuf;
use thiserror::Error;
use tracing::{debug, info};

#[derive(Error, Debug)]
pub enum GitLoaderError {
    #[error("Git error: {0}")]
    GitError(#[from] git2::Error),

    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    #[error("File loader error: {0}")]
    FileLoaderError(#[from] FileLoaderError),
}

/// A local checkout of a remote git repository under `base_path/org/repo`.
pub struct GitRepo {
    url: String,
    pub(crate) path: PathBuf,
    pub(crate) base_path: PathBuf,
}

impl GitRepo {
    pub fn new(url: String, base_path: PathBuf) -> Self {
        // NOTE(review): assumes the URL has at least two path segments
        // (org/repo) — panics otherwise; confirm callers validate input.
        let parts: Vec<&str> = url.trim_end_matches(".git").split('/').collect();
        let (org, repo) =
(parts[parts.len() - 2], parts[parts.len() - 1]); 29 | let path = base_path.join(org).join(repo); 30 | Self { 31 | url, 32 | base_path, 33 | path, 34 | } 35 | } 36 | 37 | pub fn sync(&self) -> Result { 38 | if self.path.exists() { 39 | info!(path = ?self.path, "Repository path exists, updating"); 40 | self.reset() 41 | } else { 42 | info!(path = ?self.path, "Repository path does not exist, cloning"); 43 | self.clone() 44 | } 45 | } 46 | 47 | fn clone(&self) -> Result { 48 | std::fs::create_dir_all(&self.base_path)?; 49 | debug!(url = %self.url, path = ?self.path, "Cloning repository"); 50 | Ok(Repository::clone(&self.url, &self.path)?) 51 | } 52 | 53 | fn reset(&self) -> Result { 54 | let repo = Repository::open(&self.path)?; 55 | 56 | { 57 | let mut remote = repo.find_remote("origin")?; 58 | let callbacks = RemoteCallbacks::new(); 59 | let mut fetch_options = FetchOptions::new(); 60 | fetch_options.remote_callbacks(callbacks); 61 | remote.fetch(&["main"], Some(&mut fetch_options), None)?; 62 | 63 | let main_ref = repo.find_reference("refs/remotes/origin/main")?; 64 | let main_commit = main_ref.peel_to_commit()?; 65 | 66 | let mut checkout_builder = git2::build::CheckoutBuilder::new(); 67 | 68 | repo.reset( 69 | main_commit.as_object(), 70 | git2::ResetType::Hard, 71 | Some(&mut checkout_builder), 72 | )?; 73 | } 74 | 75 | Ok(repo) 76 | } 77 | } 78 | 79 | pub struct GitLoader<'a> { 80 | path: &'a str, 81 | repo: GitRepo, 82 | } 83 | 84 | impl<'a> GitLoader<'a> { 85 | pub fn new(url: String, path: &'a str) -> Result { 86 | debug!(url = %url, path = path, "Creating new GitLoader"); 87 | let repo = GitRepo::new(url, PathBuf::from(path)); 88 | repo.sync()?; 89 | Ok(Self { path, repo }) 90 | } 91 | 92 | pub fn with_root( 93 | self, 94 | ) -> Result>, FileLoaderError> { 95 | FileLoader::with_dir(self.path) 96 | } 97 | 98 | /// Creates a new [FileLoader] using a glob pattern to match files. 
99 | /// 100 | /// # Example 101 | /// Create a [FileLoader] for all `.txt` files that match the glob "files/*.txt". 102 | /// 103 | /// ```rust 104 | /// use rig::loaders::FileLoader; 105 | /// let loader = FileLoader::with_glob("files/*.txt").unwrap(); 106 | /// ``` 107 | pub fn with_glob( 108 | self, 109 | pattern: &str, 110 | ) -> Result>, FileLoaderError> { 111 | let path = self.repo.path.to_str().unwrap().trim_end_matches('/'); 112 | let pattern = pattern.trim_start_matches('/'); 113 | let glob = Box::leak(format!("{}/{}", path, pattern).into_boxed_str()); 114 | 115 | FileLoader::with_glob(glob) 116 | } 117 | 118 | /// Creates a new [FileLoader] on all files within a directory. 119 | /// 120 | /// # Example 121 | /// Create a [FileLoader] for all files that are in the directory "files" (ignores subdirectories). 122 | /// 123 | /// ```rust 124 | /// use rig::loaders::FileLoader; 125 | /// let loader = FileLoader::with_dir("files").unwrap(); 126 | /// ``` 127 | pub fn with_dir( 128 | self, 129 | directory: &str, 130 | ) -> Result>, FileLoaderError> { 131 | let path = Box::leak( 132 | self.repo 133 | .path 134 | .join(directory) 135 | .to_str() 136 | .unwrap() 137 | .to_string() 138 | .into_boxed_str(), 139 | ); 140 | 141 | FileLoader::with_dir(path) 142 | } 143 | } 144 | -------------------------------------------------------------------------------- /asuka-core/src/attention.rs: -------------------------------------------------------------------------------- 1 | use rig::completion::{CompletionModel, ModelChoice}; 2 | use tracing::debug; 3 | 4 | use crate::knowledge::{ChannelType, Source}; 5 | use std::collections::HashSet; 6 | 7 | const RESPOND_COMMAND: &str = "[RESPOND]"; 8 | const IGNORE_COMMAND: &str = "[IGNORE]"; 9 | const STOP_COMMAND: &str = "[STOP]"; 10 | 11 | #[derive(Debug, PartialEq)] 12 | pub enum AttentionCommand { 13 | Respond, 14 | Ignore, 15 | Stop, 16 | } 17 | 18 | #[derive(Debug)] 19 | pub struct AttentionContext { 20 | pub message_content: 
String,
    pub mentioned_names: HashSet<String>,
    pub history: Vec<(String, String)>,
    pub channel_type: ChannelType,
    pub source: Source,
}

/// Tuning knobs for when the bot decides to engage.
#[derive(Clone, Debug)]
pub struct AttentionConfig {
    /// Names the bot answers to (checked against mentions and message text).
    pub bot_names: Vec<String>,
    // NOTE(review): the three fields below are not read by `should_reply`
    // here — confirm whether other call sites use them before removing.
    pub reply_threshold: f32,
    pub max_history_messages: i64,
    pub cooldown_messages: i64,
}

impl Default for AttentionConfig {
    fn default() -> Self {
        Self {
            bot_names: vec!["shinobai".to_string(), "shinobi".to_string()],
            reply_threshold: 0.6,
            max_history_messages: 10,
            cooldown_messages: 3,
        }
    }
}

/// Decides whether the bot should respond to an incoming message, using
/// cheap heuristics first and falling back to an LLM call.
#[derive(Clone)]
pub struct Attention<M: CompletionModel> {
    config: AttentionConfig,
    completion_model: M,
}

impl<M: CompletionModel> Attention<M> {
    pub fn new(config: AttentionConfig, completion_model: M) -> Self {
        Self {
            config,
            completion_model,
        }
    }

    /// Returns `Respond` for DMs and name mentions, `Stop` when the user
    /// asks the bot to disengage, `Ignore` for trivially short messages,
    /// and otherwise defers the decision to the completion model.
    pub async fn should_reply(&self, context: &AttentionContext) -> AttentionCommand {
        let content = context.message_content.to_lowercase();

        // Always reply to DMs.
        if context.channel_type == ChannelType::DirectMessage {
            return AttentionCommand::Respond;
        }

        // Check for mentions or name references.
        for name in &self.config.bot_names {
            let mentioned = context.mentioned_names.contains(name);
            let name_in_content = content.contains(&name.to_lowercase());

            debug!(
                name = name,
                mentioned = mentioned,
                name_in_content = name_in_content,
                "Checking if bot name was mentioned"
            );

            if mentioned || name_in_content {
                debug!("Bot name {} was mentioned, will reply", name);
                return AttentionCommand::Respond;
            }
        }

        // Check for stop/disengage phrases. Matching is deliberately loose
        // substring matching (e.g. "stop" also catches "stop talking").
        // BUG FIX: "be quiet" appeared twice in this list; duplicate removed.
        let stop_phrases = [
            "shut up",
            "stop",
            "please shut up",
            "shut up please",
            "dont talk",
            "silence",
            "stop talking",
            "be quiet",
            "hush",
            "wtf",
            "stfu",
            "stupid bot",
            "dumb bot",
            "stop responding",
            "can you not",
            "can you stop",
        ];

        if stop_phrases.iter().any(|phrase| content.contains(phrase)) {
            return AttentionCommand::Stop;
        }

        // Ignore very short messages (byte length, adequate for ASCII chat).
        if content.len() < 4 {
            return AttentionCommand::Ignore;
        }

        // Use the LLM to decide if we should respond.
        let prompt = format!(
            "You are in a room with other users. You should only respond when addressed or when the conversation is relevant to you.\n\n\
            Response options:\n\
            {RESPOND_COMMAND} - Message is directed at you or conversation is relevant\n\
            {IGNORE_COMMAND} - Message is not interesting or not directed at you\n\
            {STOP_COMMAND} - User wants you to stop or conversation has concluded\n\n\
            Recent messages:\n{}\n\nLatest message: {}\n\n\
            Choose one response option:",
            context.history.iter()
                .map(|(_, msg)| format!("- {}", msg))
                .collect::<Vec<_>>()
                .join("\n"),
            context.message_content
        );

        let builder = self.completion_model.completion_request(&prompt);

        // Any model failure or tool-call response degrades to Ignore rather
        // than erroring out of the message pipeline.
        match self.completion_model.completion(builder.build()).await {
            Ok(response) => match response.choice {
                ModelChoice::Message(text) => {
                    if text.contains(RESPOND_COMMAND) {
                        AttentionCommand::Respond
                    } else if text.contains(STOP_COMMAND) {
                        AttentionCommand::Stop
                    } else {
                        AttentionCommand::Ignore
                    }
                }
                ModelChoice::ToolCall(_, _, _) => AttentionCommand::Ignore,
            },
            Err(_) => AttentionCommand::Ignore,
        }
    }
}
-------------------------------------------------------------------------------- /asuka-starknet/src/transfer.rs: --------------------------------------------------------------------------------
use rig::{completion::ToolDefinition, tool::Tool};
use serde::Deserialize;
use serde_json::json;
use
starknet::core::types::Felt; 5 | use tokio_rusqlite::Connection; 6 | 7 | pub const INIT_SQL: &str = " 8 | BEGIN; 9 | -- Account table 10 | CREATE TABLE IF NOT EXISTS accounts ( 11 | id INTEGER PRIMARY KEY AUTOINCREMENT, 12 | address TEXT UNIQUE NOT NULL 13 | ); 14 | CREATE INDEX IF NOT EXISTS idx_account_address ON accounts(address); 15 | 16 | -- Token table 17 | CREATE TABLE IF NOT EXISTS tokens ( 18 | id INTEGER PRIMARY KEY AUTOINCREMENT, 19 | address TEXT UNIQUE NOT NULL, 20 | name TEXT NOT NULL, 21 | symbol TEXT NOT NULL 22 | ); 23 | CREATE INDEX IF NOT EXISTS idx_token_address ON tokens(address); 24 | CREATE INDEX IF NOT EXISTS idx_token_name ON tokens(name); 25 | CREATE INDEX IF NOT EXISTS idx_token_symbol ON tokens(symbol); 26 | COMMIT;"; 27 | 28 | #[derive(Deserialize)] 29 | pub struct TransferArgs { 30 | recipient: String, 31 | amount: Felt, 32 | token: String, // Changed to String to accept name/symbol 33 | } 34 | 35 | #[derive(Debug, thiserror::Error)] 36 | pub enum TransferError { 37 | #[error("Token not found")] 38 | TokenNotFound, 39 | #[error("Invalid recipient address")] 40 | InvalidRecipient, 41 | #[error("Database error: {0}")] 42 | DatabaseError(#[from] tokio_rusqlite::Error), 43 | } 44 | 45 | pub struct Transfer { 46 | conn: Connection, 47 | } 48 | 49 | impl Transfer { 50 | pub fn new(conn: Connection) -> Self { 51 | Self { conn } 52 | } 53 | 54 | async fn lookup_token(&self, token: &str) -> Result { 55 | let token = token.to_lowercase(); 56 | let result = self 57 | .conn 58 | .call(move |conn| { 59 | let mut stmt = conn.prepare( 60 | "SELECT address FROM tokens WHERE LOWER(name) = ? OR LOWER(symbol) = ?", 61 | )?; 62 | let mut rows = stmt.query([&token, &token])?; 63 | 64 | if let Some(row) = rows.next()? 
{ 65 | let address: String = row.get(0)?; 66 | Ok(Some(address)) 67 | } else { 68 | Ok(None) 69 | } 70 | }) 71 | .await?; 72 | 73 | match result { 74 | Some(address) => { 75 | Ok(Felt::from_hex(&address).map_err(|_| TransferError::TokenNotFound)?) 76 | } 77 | None => Err(TransferError::TokenNotFound), 78 | } 79 | } 80 | 81 | async fn lookup_recipient(&self, recipient: &str) -> Result { 82 | // First try parsing as hex 83 | if let Ok(address) = Felt::from_hex(recipient) { 84 | return Ok(address); 85 | } 86 | 87 | // Otherwise look up in accounts table 88 | let recipient = recipient.to_lowercase(); 89 | let result = self 90 | .conn 91 | .call(move |conn| { 92 | let mut stmt = 93 | conn.prepare("SELECT address FROM accounts WHERE LOWER(name) = ?")?; 94 | let mut rows = stmt.query([recipient])?; 95 | 96 | if let Some(row) = rows.next()? { 97 | let address: String = row.get(0)?; 98 | Ok(Some(address)) 99 | } else { 100 | Ok(None) 101 | } 102 | }) 103 | .await?; 104 | 105 | match result { 106 | Some(address) => { 107 | Ok(Felt::from_hex(&address).map_err(|_| TransferError::InvalidRecipient)?) 
108 | } 109 | None => Err(TransferError::InvalidRecipient), 110 | } 111 | } 112 | } 113 | 114 | impl Tool for Transfer { 115 | const NAME: &'static str = "transfer"; 116 | 117 | type Error = TransferError; 118 | type Args = TransferArgs; 119 | type Output = Felt; 120 | 121 | async fn definition(&self, _prompt: String) -> ToolDefinition { 122 | ToolDefinition { 123 | name: "transfer".to_string(), 124 | description: "Transfer tokens to a recipient".to_string(), 125 | parameters: json!({ 126 | "type": "object", 127 | "properties": { 128 | "recipient": { 129 | "type": "string", 130 | "description": "The recipient address or account name" 131 | }, 132 | "amount": { 133 | "type": "string", 134 | "description": "The amount to transfer" 135 | }, 136 | "token": { 137 | "type": "string", 138 | "description": "The token name, symbol or contract address" 139 | } 140 | } 141 | }), 142 | } 143 | } 144 | 145 | async fn call(&self, args: Self::Args) -> Result { 146 | let token_address = self.lookup_token(&args.token).await?; 147 | let recipient_address = self.lookup_recipient(&args.recipient).await?; 148 | 149 | // Here we would implement the actual transfer logic 150 | // For now just return a dummy transaction hash 151 | Ok(Felt::ZERO) 152 | } 153 | } 154 | -------------------------------------------------------------------------------- /asuka-core/src/loaders/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod github; 2 | pub mod site; 3 | 4 | use serde::{Deserialize, Serialize}; 5 | use serde_json::Value; 6 | use thiserror::Error; 7 | 8 | use crate::knowledge::Document; 9 | use rig::completion::CompletionModel; 10 | 11 | #[derive(Error, Debug)] 12 | pub enum LoaderError { 13 | #[error("{0}")] 14 | FileError(#[from] rig::loaders::file::FileLoaderError), 15 | 16 | #[cfg(feature = "pdf")] 17 | #[error("{0}")] 18 | PdfError(#[from] rig::loaders::pdf::PdfLoaderError), 19 | 20 | #[error("{0}")] 21 | GitError(#[from] 
github::GitLoaderError), 22 | 23 | #[error("{0}")] 24 | SiteError(#[from] site::SiteLoaderError), 25 | } 26 | 27 | #[derive(Debug, Clone, Serialize, Deserialize)] 28 | #[serde(rename_all = "snake_case")] 29 | pub enum SourceType { 30 | Github, 31 | Site, 32 | File, 33 | #[cfg(feature = "pdf")] 34 | Pdf, 35 | } 36 | 37 | #[derive(Debug, Clone, Serialize, Deserialize)] 38 | pub struct DocumentMetadata { 39 | pub source_type: SourceType, 40 | pub source_url: String, 41 | #[serde(flatten)] 42 | pub extra: Option, 43 | } 44 | 45 | pub struct MultiLoaderConfig { 46 | pub sources_path: String, 47 | } 48 | 49 | pub struct MultiLoader { 50 | config: MultiLoaderConfig, 51 | model: M, 52 | } 53 | 54 | impl MultiLoader { 55 | pub fn new(config: MultiLoaderConfig, model: M) -> Self { 56 | Self { config, model } 57 | } 58 | 59 | pub async fn load_sources( 60 | &self, 61 | sources: Vec, 62 | ) -> Result, LoaderError> { 63 | let mut documents = Vec::new(); 64 | 65 | for source in sources { 66 | let parts: Vec<&str> = source.splitn(2, ':').collect(); 67 | if parts.len() != 2 { 68 | continue; 69 | } 70 | 71 | let (source_type, url) = (parts[0], parts[1]); 72 | let metadata = DocumentMetadata { 73 | source_type: match source_type { 74 | "github" => SourceType::Github, 75 | "site" => SourceType::Site, 76 | "file" => SourceType::File, 77 | #[cfg(feature = "pdf")] 78 | "pdf" => SourceType::Pdf, 79 | _ => continue, 80 | }, 81 | source_url: url.to_string(), 82 | extra: None, 83 | }; 84 | 85 | match source_type { 86 | "github" => { 87 | let repo = github::GitLoader::new(url.to_string(), &self.config.sources_path)?; 88 | documents.extend( 89 | repo.with_root()? 
90 | .read_with_path() 91 | .ignore_errors() 92 | .into_iter() 93 | .map(|(path, content)| Document { 94 | id: path.to_string_lossy().to_string(), 95 | source_id: format!("github:{}", url), 96 | content, 97 | created_at: None, 98 | metadata: Some(serde_json::to_value(&metadata).unwrap()), 99 | }), 100 | ); 101 | } 102 | "site" => { 103 | let site = site::SiteLoader::new(url.to_string(), self.model.clone())?; 104 | let content = site.extract_content().await?; 105 | documents.push(Document { 106 | id: url.to_string(), 107 | source_id: format!("site:{}", url), 108 | content, 109 | created_at: None, 110 | metadata: Some(serde_json::to_value(&metadata).unwrap()), 111 | }); 112 | } 113 | "file" => { 114 | let loader = rig::loaders::file::FileLoader::with_glob(url)?; 115 | documents.extend(loader.read_with_path().ignore_errors().into_iter().map( 116 | |(path, content)| Document { 117 | id: path.to_string_lossy().to_string(), 118 | source_id: format!("file:{}", url), 119 | content, 120 | created_at: None, 121 | metadata: Some(serde_json::to_value(&metadata).unwrap()), 122 | }, 123 | )); 124 | } 125 | #[cfg(feature = "pdf")] 126 | "pdf" => { 127 | let loader = rig::loaders::pdf::PdfFileLoader::with_glob(url)?; 128 | documents.extend(loader.read_with_path().ignore_errors().into_iter().map( 129 | |(path, content)| Document { 130 | id: path.to_string_lossy().to_string(), 131 | source_id: format!("pdf:{}", url), 132 | content, 133 | created_at: None, 134 | metadata: Some(serde_json::to_value(&metadata).unwrap()), 135 | }, 136 | )); 137 | } 138 | _ => continue, 139 | } 140 | } 141 | 142 | Ok(documents.into_iter()) 143 | } 144 | } 145 | -------------------------------------------------------------------------------- /asuka-core/src/clients/telegram.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use rig::{ 3 | completion::{CompletionModel, Prompt}, 4 | embeddings::EmbeddingModel, 5 | }; 6 | use 
std::collections::HashSet; 7 | use teloxide::{ 8 | dispatching::UpdateFilterExt, 9 | dptree, 10 | prelude::{LoggingErrorHandler, Requester}, 11 | }; 12 | use tracing::{debug, error, info}; 13 | 14 | use crate::{agent::Agent, attention::AttentionCommand}; 15 | use crate::{ 16 | attention::{Attention, AttentionContext}, 17 | knowledge, 18 | }; 19 | 20 | const MAX_HISTORY_MESSAGES: i64 = 10; 21 | 22 | #[derive(Clone)] 23 | pub struct TelegramClient { 24 | agent: Agent, 25 | attention: Attention, 26 | } 27 | 28 | impl TelegramClient { 29 | pub fn new(agent: Agent, attention: Attention) -> Self { 30 | Self { agent, attention } 31 | } 32 | 33 | pub async fn start(&self, token: &str) -> Result<()> { 34 | let bot = teloxide::Bot::new(token); 35 | 36 | info!("Starting telegram bot"); 37 | 38 | self.run(bot).await 39 | } 40 | } 41 | 42 | impl From for knowledge::Message { 43 | fn from(msg: teloxide::types::Message) -> Self { 44 | let user_id = msg 45 | .from 46 | .clone() 47 | .map(|u| u.id.to_string()) 48 | .unwrap_or_default(); 49 | let user_id_num = msg.from.clone().map(|u| u.id.0).unwrap_or_default(); 50 | 51 | Self { 52 | id: msg.id.to_string(), 53 | source: knowledge::Source::Telegram, 54 | source_id: user_id.clone(), 55 | channel_type: if msg.chat.id.0 == user_id_num as i64 { 56 | knowledge::ChannelType::DirectMessage 57 | } else { 58 | knowledge::ChannelType::Text 59 | }, 60 | channel_id: msg.chat.id.to_string(), 61 | account_id: user_id, 62 | role: "user".to_string(), 63 | content: msg.text().unwrap_or_default().to_string(), 64 | created_at: Some(msg.date), 65 | } 66 | } 67 | } 68 | 69 | impl TelegramClient { 70 | async fn run(&self, bot: teloxide::Bot) -> Result<()> { 71 | let knowledge = self.agent.knowledge().clone(); 72 | let attention = self.attention.clone(); 73 | let agent = self.agent.clone(); 74 | 75 | let handler = dptree::entry() 76 | .branch(teloxide::types::Update::filter_message().endpoint(move |bot: teloxide::Bot, msg: teloxide::types::Message| { 77 | 
let knowledge = knowledge.clone(); 78 | let attention = attention.clone(); 79 | let agent = agent.clone(); 80 | 81 | async move { 82 | let knowledge_msg = knowledge::Message::from(msg.clone()); 83 | 84 | if let Err(err) = knowledge.create_message(knowledge_msg.clone()).await { 85 | error!(?err, "Failed to store message"); 86 | return Err(anyhow::anyhow!(err)); 87 | } 88 | 89 | debug!("Fetching message history for channel {}", msg.chat.id); 90 | let history = match knowledge 91 | .channel_messages(&msg.chat.id.to_string(), MAX_HISTORY_MESSAGES) 92 | .await 93 | { 94 | Ok(messages) => { 95 | debug!(message_count = messages.len(), "Retrieved message history"); 96 | messages 97 | } 98 | Err(err) => { 99 | error!(?err, "Failed to fetch recent messages"); 100 | return Err(anyhow::anyhow!(err)); 101 | } 102 | }; 103 | 104 | let mentioned_names: HashSet = msg.text() 105 | .map(|text| { 106 | text.split_whitespace() 107 | .filter_map(|word| { 108 | if word.starts_with('@') { 109 | Some(word[1..].to_string()) 110 | } else { 111 | None 112 | } 113 | }) 114 | .collect() 115 | }) 116 | .unwrap_or_default(); 117 | 118 | debug!( 119 | mentioned_names = ?mentioned_names, 120 | "Mentioned names in message" 121 | ); 122 | 123 | let context = AttentionContext { 124 | message_content: msg.text().unwrap_or_default().to_string(), 125 | mentioned_names, 126 | history, 127 | channel_type: knowledge_msg.channel_type, 128 | source: knowledge_msg.source, 129 | }; 130 | 131 | debug!(?context, "Attention context"); 132 | 133 | match attention.should_reply(&context).await { 134 | AttentionCommand::Respond => {} 135 | _ => { 136 | debug!("Bot decided not to reply to message"); 137 | return Ok(()); 138 | } 139 | } 140 | 141 | let agent = agent 142 | .builder() 143 | .context(&format!( 144 | "Current time: {}", 145 | chrono::Local::now().format("%I:%M:%S %p, %Y-%m-%d") 146 | )) 147 | .context("Please keep your responses concise and under 2000 characters when possible.") 148 | .build(); 149 | 150 | 
let response = match agent.prompt(msg.text().unwrap_or_default()).await { 151 | Ok(response) => response, 152 | Err(err) => { 153 | error!(?err, "Failed to generate response"); 154 | return Err(anyhow::anyhow!(err)); 155 | } 156 | }; 157 | 158 | debug!(response = %response, "Generated response"); 159 | 160 | if let Err(why) = bot.send_message(msg.chat.id, response).await { 161 | error!(?why, "Failed to send message"); 162 | return Err(anyhow::anyhow!(why)); 163 | } 164 | 165 | Ok(()) 166 | } 167 | })); 168 | 169 | let listener = teloxide::update_listeners::polling_default(bot.clone()).await; 170 | 171 | teloxide::dispatching::Dispatcher::builder(bot, handler) 172 | .build() 173 | .dispatch_with_listener( 174 | listener, 175 | LoggingErrorHandler::with_custom_text("Failed to process Telegram update"), 176 | ) 177 | .await; 178 | 179 | Ok(()) 180 | } 181 | } 182 | -------------------------------------------------------------------------------- /asuka-core/src/clients/twitter.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | agent::Agent, 3 | attention::{Attention, AttentionCommand, AttentionContext}, 4 | knowledge::{ChannelType, Message, Source}, 5 | }; 6 | 7 | use rig::{ 8 | completion::{CompletionModel, Prompt}, 9 | embeddings::EmbeddingModel, 10 | }; 11 | use std::collections::HashSet; 12 | use tracing::{debug, error, info}; 13 | use twitter::{authorization::Authorization, TwitterApi}; 14 | use twitter_v2::data::ReferencedTweetKind; 15 | use twitter_v2::{ 16 | self as twitter, 17 | authorization::{BearerToken, Oauth1aToken}, 18 | }; 19 | 20 | const MAX_TWEET_LENGTH: usize = 280; 21 | const MAX_HISTORY_TWEETS: i64 = 10; 22 | 23 | #[derive(Clone)] 24 | pub struct TwitterClient { 25 | agent: Agent, 26 | attention: Attention, 27 | api: TwitterApi, 28 | } 29 | 30 | impl From for Message { 31 | fn from(tweet: twitter::Tweet) -> Self { 32 | let created_at = tweet 33 | .created_at 34 | .map(|t| 
chrono::DateTime::from_timestamp(t.unix_timestamp(), 0).unwrap_or_default()) 35 | .unwrap_or_default(); 36 | 37 | Self { 38 | id: tweet.id.to_string(), 39 | source: Source::Twitter, 40 | source_id: tweet.id.to_string(), 41 | channel_type: ChannelType::Text, 42 | channel_id: tweet.conversation_id.unwrap_or(tweet.id).to_string(), 43 | account_id: tweet 44 | .author_id 45 | .map(|id| id.to_string()) 46 | .unwrap_or_else(|| "0".to_string()), 47 | role: "user".to_string(), 48 | content: tweet.text.clone(), 49 | created_at: Some(created_at), 50 | } 51 | } 52 | } 53 | 54 | impl TwitterClient { 55 | pub fn new(agent: Agent, attention: Attention, oauth1a_token: Oauth1aToken) -> Self { 56 | let api = TwitterApi::new(oauth1a_token); 57 | 58 | Self { 59 | agent, 60 | attention, 61 | api, 62 | } 63 | } 64 | } 65 | 66 | impl TwitterClient { 67 | pub fn new(agent: Agent, attention: Attention, bearer_token: &str) -> Self { 68 | let auth = BearerToken::new(bearer_token.to_string()); 69 | let api = TwitterApi::new(auth); 70 | 71 | Self { 72 | agent, 73 | attention, 74 | api, 75 | } 76 | } 77 | } 78 | 79 | impl 80 | TwitterClient 81 | { 82 | pub async fn start(&self) -> Result<(), Box> { 83 | info!("Starting Twitter bot"); 84 | self.listen_for_mentions().await 85 | } 86 | 87 | async fn listen_for_mentions(&self) -> Result<(), Box> { 88 | let me = self.api.get_users_me().send().await?; 89 | let user_id = me.data.as_ref().unwrap().id; 90 | 91 | // In a real implementation, you would use Twitter's streaming API 92 | // This is a simplified polling approach 93 | loop { 94 | let mentions = self 95 | .api 96 | .get_user_mentions(user_id) 97 | .max_results(5) 98 | .send() 99 | .await?; 100 | 101 | for tweet in mentions.data.clone().unwrap_or_default() { 102 | self.handle_mention(tweet).await?; 103 | } 104 | 105 | tokio::time::sleep(tokio::time::Duration::from_secs(60)).await; 106 | } 107 | } 108 | 109 | async fn handle_mention( 110 | &self, 111 | tweet: twitter::Tweet, 112 | ) -> Result<(), 
Box> { 113 | let knowledge = self.agent.knowledge(); 114 | let knowledge_msg = Message::from(tweet.clone()); 115 | 116 | if let Err(err) = knowledge.create_message(knowledge_msg.clone()).await { 117 | error!(?err, "Failed to store tweet"); 118 | return Ok(()); 119 | } 120 | 121 | let thread = self.build_conversation_thread(&tweet).await?; 122 | 123 | let mentioned_names: HashSet = tweet 124 | .text 125 | .split_whitespace() 126 | .filter(|word| word.starts_with('@')) 127 | .map(|mention| mention[1..].to_string()) 128 | .collect(); 129 | 130 | debug!( 131 | mentioned_names = ?mentioned_names, 132 | "Mentioned names in tweet" 133 | ); 134 | 135 | let history = thread 136 | .iter() 137 | .map(|t| (t.id.to_string(), t.text.clone())) 138 | .collect(); 139 | 140 | let context = AttentionContext { 141 | message_content: tweet.text.clone(), 142 | mentioned_names, 143 | history, 144 | channel_type: knowledge_msg.channel_type, 145 | source: knowledge_msg.source, 146 | }; 147 | 148 | debug!(?context, "Attention context"); 149 | 150 | match self.attention.should_reply(&context).await { 151 | AttentionCommand::Respond => {} 152 | _ => { 153 | debug!("Bot decided not to reply to tweet"); 154 | return Ok(()); 155 | } 156 | } 157 | 158 | let agent = self 159 | .agent 160 | .builder() 161 | .context(&format!( 162 | "Current time: {}", 163 | chrono::Local::now().format("%I:%M:%S %p, %Y-%m-%d") 164 | )) 165 | .context("Please keep your responses concise and under 280 characters.") 166 | .build(); 167 | 168 | let response = match agent.prompt(&tweet.text).await { 169 | Ok(response) => response, 170 | Err(err) => { 171 | error!(?err, "Failed to generate response"); 172 | return Ok(()); 173 | } 174 | }; 175 | 176 | debug!(response = %response, "Generated response"); 177 | 178 | // Split response into tweet-sized chunks if necessary 179 | let chunks: Vec = response 180 | .chars() 181 | .collect::>() 182 | .chunks(MAX_TWEET_LENGTH) 183 | .map(|chunk| chunk.iter().collect::()) 184 | 
.collect(); 185 | 186 | // Reply to the original tweet 187 | for chunk in chunks { 188 | if let Err(err) = self 189 | .api 190 | .post_tweet() 191 | .in_reply_to_tweet_id(tweet.id) 192 | .text(chunk) 193 | .send() 194 | .await 195 | { 196 | error!(?err, "Failed to send tweet"); 197 | } 198 | } 199 | 200 | Ok(()) 201 | } 202 | 203 | async fn build_conversation_thread( 204 | &self, 205 | tweet: &twitter::Tweet, 206 | ) -> Result, Box> { 207 | let mut thread = Vec::new(); 208 | let mut current_tweet = Some(tweet.clone()); 209 | let mut depth = 0; 210 | 211 | while let Some(tweet) = current_tweet { 212 | thread.push(tweet.clone()); 213 | 214 | if depth >= MAX_HISTORY_TWEETS { 215 | break; 216 | } 217 | 218 | if let Some(referenced_tweets) = &tweet.referenced_tweets { 219 | if let Some(replied_to) = referenced_tweets 220 | .iter() 221 | .find(|t| matches!(t.kind, ReferencedTweetKind::RepliedTo)) 222 | { 223 | match self.api.get_tweet(replied_to.id).send().await { 224 | Ok(response) => { 225 | current_tweet = response.data.clone(); 226 | } 227 | Err(err) => { 228 | error!(?err, "Failed to fetch parent tweet"); 229 | break; 230 | } 231 | } 232 | } else { 233 | break; 234 | } 235 | } else { 236 | break; 237 | } 238 | 239 | depth += 1; 240 | } 241 | 242 | thread.reverse(); // Order from oldest to newest 243 | Ok(thread) 244 | } 245 | } 246 | -------------------------------------------------------------------------------- /asuka-core/src/clients/github.rs: -------------------------------------------------------------------------------- 1 | use crate::knowledge::Document; 2 | use anyhow::{Context, Result}; 3 | use chrono::{DateTime, Utc}; 4 | use octocrab::models::{self}; 5 | use octocrab::Octocrab; 6 | use serde_json::json; 7 | 8 | #[derive(Clone)] 9 | pub struct GitHubClient { 10 | client: Octocrab, 11 | } 12 | 13 | impl GitHubClient { 14 | pub fn new(token: String) -> Self { 15 | Self { 16 | client: Octocrab::builder() 17 | .personal_token(token) 18 | .build() 19 | 
.expect("Failed to create GitHub client"), 20 | } 21 | } 22 | 23 | pub async fn fetch_org_repos(&self, org: &str) -> Result> { 24 | let repos = self 25 | .client 26 | .orgs(org) 27 | .list_repos() 28 | .send() 29 | .await 30 | .context("Failed to fetch organization repositories")? 31 | .items; 32 | 33 | let mut documents = Vec::new(); 34 | for repo in repos { 35 | let repo_name = repo.full_name.as_deref().unwrap_or(&repo.name); 36 | let html_url = repo 37 | .html_url 38 | .as_ref() 39 | .map(|url| url.to_string()) 40 | .unwrap_or_default(); 41 | 42 | let content = format!( 43 | "Repository: {}\nDescription: {}\nURL: {}\nCreated: {}\nLast Updated: {}", 44 | repo_name, 45 | repo.description.as_deref().unwrap_or("No description"), 46 | html_url, 47 | repo.created_at.unwrap_or_default(), 48 | repo.updated_at.unwrap_or_default() 49 | ); 50 | 51 | documents.push(Document { 52 | id: format!("github:repo:{}", repo_name), 53 | source_id: format!("github:{}", org), 54 | content, 55 | created_at: repo.created_at, 56 | metadata: Some(json!(repo)), 57 | }); 58 | } 59 | 60 | Ok(documents) 61 | } 62 | 63 | pub async fn fetch_repo_pulls( 64 | &self, 65 | owner: &str, 66 | repo: &str, 67 | since: DateTime, 68 | ) -> Result> { 69 | let pulls = self 70 | .client 71 | .pulls(owner, repo) 72 | .list() 73 | .state(octocrab::params::State::All) 74 | .sort(octocrab::params::pulls::Sort::Updated) 75 | .direction(octocrab::params::Direction::Descending) 76 | .send() 77 | .await 78 | .context("Failed to fetch pull requests")? 
79 | .items 80 | .into_iter() 81 | .filter(|pr| pr.updated_at.map(|d| d >= since).unwrap_or(false)) 82 | .collect::>(); 83 | 84 | let mut documents = Vec::new(); 85 | for pr in pulls { 86 | let content = format!( 87 | "Pull Request: #{} - {}\nAuthor: @{}\nState: {}\nURL: {}\nCreated: {}\nLast Updated: {}\n\n{}", 88 | pr.number, 89 | pr.title.as_deref().unwrap_or_default(), 90 | pr.user.as_ref().map(|u| u.login.clone()).unwrap_or_default(), 91 | pr.state.as_ref().map(|s| format!("{:?}", s)).unwrap_or_else(|| "unknown".to_string()), 92 | pr.html_url.as_ref().map(|url| url.to_string()).unwrap_or_default(), 93 | pr.created_at.unwrap_or_default(), 94 | pr.updated_at.unwrap_or_default(), 95 | pr.body.as_deref().unwrap_or_default() 96 | ); 97 | 98 | documents.push(Document { 99 | id: format!("github:pr:{}:{}/{}", owner, repo, pr.number), 100 | source_id: format!("github:{}/{}", owner, repo), 101 | content, 102 | created_at: pr.created_at, 103 | metadata: Some(json!(pr)), 104 | }); 105 | } 106 | 107 | Ok(documents) 108 | } 109 | 110 | pub async fn fetch_repo_issues( 111 | &self, 112 | owner: &str, 113 | repo: &str, 114 | since: DateTime, 115 | ) -> Result> { 116 | let issues = self 117 | .client 118 | .issues(owner, repo) 119 | .list() 120 | .state(octocrab::params::State::All) 121 | .sort(octocrab::params::issues::Sort::Updated) 122 | .direction(octocrab::params::Direction::Descending) 123 | .send() 124 | .await 125 | .context("Failed to fetch issues")? 
126 | .items 127 | .into_iter() 128 | .filter(|issue| issue.updated_at >= since && issue.pull_request.is_none()) 129 | .collect::>(); 130 | 131 | let mut documents = Vec::new(); 132 | for issue in issues { 133 | let content = format!( 134 | "Issue: #{} - {}\nAuthor: @{}\nState: {}\nURL: {}\nCreated: {}\nLast Updated: {}\n\n{}", 135 | issue.number, 136 | issue.title, 137 | issue.user.login, 138 | format!("{:?}", issue.state), 139 | issue.html_url, 140 | issue.created_at, 141 | issue.updated_at, 142 | issue.body.as_deref().unwrap_or_default() 143 | ); 144 | 145 | documents.push(Document { 146 | id: format!("github:issue:{}:{}/{}", owner, repo, issue.number), 147 | source_id: format!("github:{}/{}", owner, repo), 148 | content, 149 | created_at: Some(issue.created_at), 150 | metadata: Some(json!(issue)), 151 | }); 152 | } 153 | 154 | Ok(documents) 155 | } 156 | 157 | pub async fn fetch_repo_commits( 158 | &self, 159 | owner: &str, 160 | repo: &str, 161 | since: DateTime, 162 | ) -> Result> { 163 | let commits = self 164 | .client 165 | .repos(owner, repo) 166 | .list_commits() 167 | .since(since) 168 | .send() 169 | .await 170 | .context("Failed to fetch commits")? 
171 | .items; 172 | 173 | let mut documents = Vec::new(); 174 | for commit in commits { 175 | let author_date = commit.commit.author.as_ref().and_then(|a| a.date); 176 | 177 | let author_name = commit 178 | .author 179 | .as_ref() 180 | .map(|a| format!("@{}", a.login)) 181 | .unwrap_or_else(|| { 182 | commit 183 | .commit 184 | .author 185 | .as_ref() 186 | .map(|a| a.name.clone()) 187 | .unwrap_or_default() 188 | }); 189 | 190 | let content = format!( 191 | "Commit: {}\nAuthor: {}\nDate: {}\nURL: {}\n\n{}", 192 | commit.sha, 193 | author_name, 194 | author_date.unwrap_or_default(), 195 | commit.html_url, 196 | commit.commit.message 197 | ); 198 | 199 | documents.push(Document { 200 | id: format!("github:commit:{}:{}/{}", owner, repo, commit.sha), 201 | source_id: format!("github:{}/{}", owner, repo), 202 | content, 203 | created_at: author_date, 204 | metadata: Some(json!(commit)), 205 | }); 206 | } 207 | 208 | Ok(documents) 209 | } 210 | 211 | pub async fn fetch_org_activity( 212 | &self, 213 | org: &str, 214 | since: DateTime, 215 | ) -> Result> { 216 | let mut documents = Vec::new(); 217 | let repos = self.fetch_org_repos(org).await?; 218 | documents.extend(repos.clone()); 219 | 220 | for repo in repos { 221 | if let Some(metadata) = repo.metadata { 222 | if let Ok(repo_obj) = serde_json::from_value::(metadata) { 223 | let repo_name = repo_obj.full_name.as_deref().unwrap_or(&repo_obj.name); 224 | let (owner, name) = repo_name.split_once('/').unwrap(); 225 | 226 | let pulls = self.fetch_repo_pulls(owner, name, since).await?; 227 | documents.extend(pulls); 228 | 229 | let issues = self.fetch_repo_issues(owner, name, since).await?; 230 | documents.extend(issues); 231 | 232 | let commits = self.fetch_repo_commits(owner, name, since).await?; 233 | documents.extend(commits); 234 | } 235 | } 236 | } 237 | 238 | Ok(documents) 239 | } 240 | } 241 | -------------------------------------------------------------------------------- /asuka-core/src/clients/discord.rs: 
-------------------------------------------------------------------------------- 1 | use rig::{ 2 | completion::{CompletionModel, Prompt}, 3 | embeddings::EmbeddingModel, 4 | }; 5 | use serenity::async_trait; 6 | use serenity::model::channel::Message; 7 | use serenity::model::gateway::GatewayIntents; 8 | use serenity::model::gateway::Ready; 9 | use serenity::prelude::*; 10 | use std::collections::HashSet; 11 | use tracing::{debug, error, info}; 12 | 13 | use crate::{agent::Agent, attention::AttentionCommand}; 14 | use crate::{ 15 | attention::{Attention, AttentionContext}, 16 | knowledge, 17 | }; 18 | 19 | const MIN_CHUNK_LENGTH: usize = 100; 20 | const MAX_MESSAGE_LENGTH: usize = 1500; 21 | const MAX_HISTORY_MESSAGES: i64 = 10; 22 | 23 | #[derive(Clone)] 24 | pub struct DiscordClient { 25 | agent: Agent, 26 | attention: Attention, 27 | } 28 | 29 | impl DiscordClient { 30 | pub fn new(agent: Agent, attention: Attention) -> Self { 31 | Self { agent, attention } 32 | } 33 | 34 | pub async fn start(&self, token: &str) -> Result<(), serenity::Error> { 35 | let intents = GatewayIntents::GUILD_MESSAGES 36 | | GatewayIntents::DIRECT_MESSAGES 37 | | GatewayIntents::MESSAGE_CONTENT; 38 | 39 | let mut client = Client::builder(token, intents) 40 | .event_handler(self.clone()) 41 | .await?; 42 | 43 | info!("Starting discord bot"); 44 | client.start().await 45 | } 46 | } 47 | 48 | impl From for knowledge::Message { 49 | fn from(msg: Message) -> Self { 50 | Self { 51 | id: msg.id.to_string(), 52 | source: knowledge::Source::Discord, 53 | source_id: msg.author.id.to_string(), 54 | channel_type: if msg.guild_id.is_none() { 55 | knowledge::ChannelType::DirectMessage 56 | } else { 57 | knowledge::ChannelType::Text 58 | }, 59 | channel_id: msg.channel_id.to_string(), 60 | account_id: msg.author.id.to_string(), 61 | role: "user".to_string(), 62 | content: msg.content.clone(), 63 | created_at: Some(*msg.timestamp), 64 | } 65 | } 66 | } 67 | 68 | #[async_trait] 69 | impl EventHandler 70 
| for DiscordClient 71 | { 72 | async fn message(&self, ctx: Context, msg: Message) { 73 | if msg.author.bot { 74 | return; 75 | } 76 | 77 | let knowledge = self.agent.knowledge(); 78 | let knowledge_msg = knowledge::Message::from(msg.clone()); 79 | 80 | if let Err(err) = knowledge 81 | .clone() 82 | .create_message(knowledge_msg.clone()) 83 | .await 84 | { 85 | error!(?err, "Failed to store message"); 86 | return; 87 | } 88 | 89 | debug!("Fetching message history for channel {}", msg.channel_id); 90 | let history = match knowledge 91 | .channel_messages(&msg.channel_id.to_string(), MAX_HISTORY_MESSAGES) 92 | .await 93 | { 94 | Ok(messages) => { 95 | debug!(message_count = messages.len(), "Retrieved message history"); 96 | messages 97 | } 98 | Err(err) => { 99 | error!(?err, "Failed to fetch recent messages"); 100 | return; 101 | } 102 | }; 103 | 104 | let mentioned_names: HashSet = 105 | msg.mentions.iter().map(|user| user.name.clone()).collect(); 106 | debug!( 107 | mentioned_names = ?mentioned_names, 108 | "Mentioned names in message" 109 | ); 110 | 111 | let context = AttentionContext { 112 | message_content: msg.content.clone(), 113 | mentioned_names, 114 | history, 115 | channel_type: knowledge_msg.channel_type, 116 | source: knowledge_msg.source, 117 | }; 118 | 119 | debug!(?context, "Attention context"); 120 | 121 | match self.attention.should_reply(&context).await { 122 | AttentionCommand::Respond => {} 123 | _ => { 124 | debug!("Bot decided not to reply to message"); 125 | return; 126 | } 127 | } 128 | 129 | let agent = self 130 | .agent 131 | .builder() 132 | .context(&format!( 133 | "Current time: {}", 134 | chrono::Local::now().format("%I:%M:%S %p, %Y-%m-%d") 135 | )) 136 | .context("Please keep your responses concise and under 2000 characters when possible.") 137 | .build(); 138 | 139 | let response = match agent.prompt(&msg.content).await { 140 | Ok(response) => response, 141 | Err(err) => { 142 | error!(?err, "Failed to generate response"); 143 | 
return; 144 | } 145 | }; 146 | 147 | debug!(response = %response, "Generated response"); 148 | 149 | let chunks = chunk_message(&response, MAX_MESSAGE_LENGTH, MIN_CHUNK_LENGTH); 150 | 151 | for chunk in chunks { 152 | if let Err(why) = msg.channel_id.say(&ctx.http, chunk).await { 153 | error!(?why, "Failed to send message"); 154 | } 155 | } 156 | } 157 | 158 | async fn ready(&self, _: Context, ready: Ready) { 159 | info!(name = self.agent.character.name, "Bot connected"); 160 | info!(guild_count = ready.guilds.len(), "Serving guilds"); 161 | } 162 | } 163 | 164 | pub fn chunk_message(text: &str, max_length: usize, min_chunk_length: usize) -> Vec { 165 | // Base case: if text is shorter than min_chunk_length, return as single chunk 166 | if text.len() <= min_chunk_length { 167 | return vec![text.to_string()]; 168 | } 169 | 170 | let mut chunks = Vec::new(); 171 | 172 | // Find split point for current chunk 173 | let mut split_index = text.len(); 174 | let mut in_heading = false; 175 | 176 | for (i, line) in text.lines().enumerate() { 177 | let line = line.trim(); 178 | if line.is_empty() { 179 | continue; 180 | } 181 | 182 | // Start new chunk on headings 183 | if line.starts_with('#') && i > 0 { 184 | split_index = text.find(line).unwrap_or(text.len()); 185 | in_heading = true; 186 | break; 187 | } 188 | 189 | // Check if adding this line would exceed max_length 190 | let line_start = text.find(line).unwrap_or(text.len()); 191 | if line_start + line.len() > max_length && i > 0 { 192 | split_index = line_start; 193 | break; 194 | } 195 | } 196 | 197 | // Split text and recurse 198 | if split_index < text.len() { 199 | let (chunk, rest) = text.split_at(split_index); 200 | let mut chunk = chunk.trim().to_string(); 201 | 202 | // Add newline after chunk if we're not splitting on a heading 203 | if !in_heading && !rest.trim().starts_with('#') { 204 | chunk.push('\n'); 205 | } 206 | 207 | // Strip trailing newline if it's the last character 208 | if chunk.ends_with('\n') 
{ 209 | chunk.pop(); 210 | } 211 | 212 | chunks.push(chunk); 213 | chunks.extend(chunk_message(rest.trim(), max_length, min_chunk_length)); 214 | } else { 215 | chunks.push(text.trim().to_string()); 216 | } 217 | 218 | chunks 219 | } 220 | 221 | #[cfg(test)] 222 | mod tests { 223 | use super::*; 224 | 225 | #[test] 226 | fn test_chunk_message_single_chunk() { 227 | let text = "This is a short message"; 228 | let chunks = chunk_message(text, 100, 1000); 229 | assert_eq!(chunks.len(), 1); 230 | assert_eq!(chunks[0], text); 231 | } 232 | 233 | #[test] 234 | fn test_chunk_message_multiple_chunks() { 235 | let text = "Line 1\nLine 2\nLine 3"; 236 | let chunks = chunk_message(text, 10, 5); 237 | assert_eq!(chunks.len(), 3); 238 | assert_eq!(chunks[0], "Line 1"); 239 | assert_eq!(chunks[1], "Line 2"); 240 | assert_eq!(chunks[2], "Line 3"); 241 | } 242 | 243 | #[test] 244 | fn test_chunk_message_empty_lines() { 245 | let text = "Line 1\n\n\nLine 2"; 246 | let chunks = chunk_message(text, 100, 1000); 247 | assert_eq!(chunks.len(), 1); 248 | assert_eq!(chunks[0], "Line 1\n\n\nLine 2"); 249 | } 250 | 251 | #[test] 252 | fn test_chunk_message_markdown() { 253 | let text = "# Heading 1\nSome text under heading 1\n## Heading 2\nMore text\n# Heading 3\nFinal text"; 254 | let chunks = chunk_message(text, 100, 50); 255 | assert_eq!(chunks.len(), 2); 256 | assert_eq!(chunks[0], "# Heading 1\nSome text under heading 1"); 257 | assert_eq!( 258 | chunks[1], 259 | "## Heading 2\nMore text\n# Heading 3\nFinal text" 260 | ); 261 | } 262 | 263 | #[test] 264 | fn test_no_chunking_under_min_length() { 265 | let text = "This is a message that won't be chunked because it's under the minimum length"; 266 | let chunks = chunk_message(text, 10, 1000); 267 | assert_eq!(chunks.len(), 1); 268 | assert_eq!(chunks[0], text); 269 | } 270 | } 271 | -------------------------------------------------------------------------------- /asuka-core/src/knowledge/models.rs: 
-------------------------------------------------------------------------------- 1 | use std::str::FromStr; 2 | 3 | use super::types::{ChannelType, Source}; 4 | use chrono::{DateTime, NaiveDateTime, Utc}; 5 | use rig::Embed; 6 | use rig_sqlite::{Column, ColumnValue, SqliteVectorStoreTable}; 7 | use rusqlite::Row; 8 | use serde::{Deserialize, Deserializer}; 9 | use serde_json::Value; 10 | 11 | #[derive(Embed, Clone, Debug)] 12 | pub struct Document { 13 | pub id: String, 14 | pub source_id: String, 15 | #[embed] 16 | pub content: String, 17 | pub created_at: Option>, 18 | pub metadata: Option, 19 | } 20 | 21 | #[derive(Debug, serde::Deserialize)] 22 | pub struct Account { 23 | pub id: i64, 24 | pub source_id: String, 25 | pub name: String, 26 | pub source: String, 27 | pub created_at: Option>, 28 | pub updated_at: Option>, 29 | } 30 | 31 | #[derive(Debug, serde::Deserialize)] 32 | pub struct Conversation { 33 | pub id: String, 34 | pub user_id: String, 35 | pub title: String, 36 | pub created_at: Option>, 37 | pub updated_at: Option>, 38 | } 39 | 40 | #[derive(Embed, Clone, Debug, serde::Deserialize)] 41 | pub struct Message { 42 | pub id: String, 43 | pub source: Source, 44 | pub source_id: String, 45 | pub channel_type: ChannelType, 46 | pub channel_id: String, 47 | pub account_id: String, 48 | pub role: String, 49 | #[embed] 50 | pub content: String, 51 | #[serde(deserialize_with = "deserialize_datetime")] 52 | pub created_at: Option>, 53 | } 54 | 55 | fn deserialize_datetime<'de, D>(deserializer: D) -> Result>, D::Error> 56 | where 57 | D: Deserializer<'de>, 58 | { 59 | let s = Option::::deserialize(deserializer)?; 60 | s.map(|date_str| { 61 | NaiveDateTime::parse_from_str(&date_str, "%Y-%m-%d %H:%M:%S") 62 | .map(|naive_dt| DateTime::::from_naive_utc_and_offset(naive_dt, Utc)) 63 | .map_err(serde::de::Error::custom) 64 | }) 65 | .transpose() 66 | } 67 | #[derive(Debug, Clone, serde::Deserialize)] 68 | pub struct Channel { 69 | pub id: String, 70 | pub 
channel_id: String, 71 | pub channel_type: String, 72 | pub source: String, 73 | pub name: String, 74 | pub created_at: Option>, 75 | pub updated_at: Option>, 76 | } 77 | 78 | // Implement the table traits 79 | impl SqliteVectorStoreTable for Document { 80 | fn name() -> &'static str { 81 | "documents" 82 | } 83 | 84 | fn schema() -> Vec { 85 | vec![ 86 | Column::new("id", "TEXT PRIMARY KEY"), 87 | Column::new("source_id", "TEXT").indexed(), 88 | Column::new("content", "TEXT"), 89 | Column::new("created_at", "TIMESTAMP DEFAULT CURRENT_TIMESTAMP"), 90 | Column::new("metadata", "TEXT"), 91 | ] 92 | } 93 | 94 | fn id(&self) -> String { 95 | self.id.clone() 96 | } 97 | 98 | fn column_values(&self) -> Vec<(&'static str, Box)> { 99 | vec![ 100 | ("id", Box::new(self.id.clone())), 101 | ("source_id", Box::new(self.source_id.clone())), 102 | ("content", Box::new(self.content.clone())), 103 | ( 104 | "metadata", 105 | Box::new( 106 | self.metadata 107 | .as_ref() 108 | .map(|m| serde_json::to_string(m).unwrap_or_default()) 109 | .unwrap_or_default(), 110 | ), 111 | ), 112 | ] 113 | } 114 | } 115 | 116 | impl SqliteVectorStoreTable for Message { 117 | fn name() -> &'static str { 118 | "messages" 119 | } 120 | 121 | fn schema() -> Vec { 122 | vec![ 123 | Column::new("id", "TEXT PRIMARY KEY"), 124 | Column::new("source", "TEXT"), 125 | Column::new("source_id", "TEXT").indexed(), 126 | Column::new("channel_type", "TEXT"), 127 | Column::new("channel_id", "TEXT").indexed(), 128 | Column::new("account_id", "TEXT").indexed(), 129 | Column::new("role", "TEXT"), 130 | Column::new("content", "TEXT"), 131 | Column::new("created_at", "TIMESTAMP DEFAULT CURRENT_TIMESTAMP"), 132 | ] 133 | } 134 | 135 | fn id(&self) -> String { 136 | self.id.clone() 137 | } 138 | 139 | fn column_values(&self) -> Vec<(&'static str, Box)> { 140 | vec![ 141 | ("id", Box::new(self.id.clone())), 142 | ("source", Box::new(self.source.as_str().to_string())), 143 | ("source_id", Box::new(self.source_id.clone())), 
144 | ( 145 | "channel_type", 146 | Box::new(self.channel_type.as_str().to_string()), 147 | ), 148 | ("channel_id", Box::new(self.channel_id.clone())), 149 | ("account_id", Box::new(self.account_id.clone())), 150 | ("role", Box::new(self.role.clone())), 151 | ("content", Box::new(self.content.clone())), 152 | ] 153 | } 154 | } 155 | 156 | impl TryFrom<&Row<'_>> for Document { 157 | type Error = rusqlite::Error; 158 | 159 | fn try_from(row: &Row) -> Result { 160 | let metadata_str: Option = row.get(4)?; 161 | let metadata = metadata_str 162 | .map(|s| serde_json::from_str(&s)) 163 | .transpose() 164 | .map_err(|e| { 165 | rusqlite::Error::FromSqlConversionFailure( 166 | 4, 167 | rusqlite::types::Type::Text, 168 | Box::new(e), 169 | ) 170 | })?; 171 | 172 | Ok(Document { 173 | id: row.get(0)?, 174 | source_id: row.get(1)?, 175 | content: row.get(2)?, 176 | created_at: row.get(3)?, 177 | metadata, 178 | }) 179 | } 180 | } 181 | 182 | impl TryFrom<&Row<'_>> for Account { 183 | type Error = rusqlite::Error; 184 | 185 | fn try_from(row: &Row) -> Result { 186 | Ok(Account { 187 | id: row.get(0)?, 188 | name: row.get(1)?, 189 | source_id: row.get(2)?, 190 | source: row.get(3)?, 191 | created_at: row.get(4)?, 192 | updated_at: row.get(5)?, 193 | }) 194 | } 195 | } 196 | 197 | impl TryFrom<&Row<'_>> for Conversation { 198 | type Error = rusqlite::Error; 199 | 200 | fn try_from(row: &Row) -> Result { 201 | Ok(Conversation { 202 | id: row.get(0)?, 203 | user_id: row.get(1)?, 204 | title: row.get(2)?, 205 | created_at: row.get(3)?, 206 | updated_at: row.get(4)?, 207 | }) 208 | } 209 | } 210 | 211 | impl TryFrom<&Row<'_>> for Message { 212 | type Error = rusqlite::Error; 213 | 214 | fn try_from(row: &Row) -> Result { 215 | Ok(Message { 216 | id: row.get(0)?, 217 | source: Source::from_str(&row.get::<_, String>(1)?).map_err(|_| { 218 | rusqlite::Error::FromSqlConversionFailure( 219 | 1, 220 | rusqlite::types::Type::Text, 221 | Box::new(super::error::ConversionError("Invalid 
source".to_string())), 222 | ) 223 | })?, 224 | source_id: row.get(2)?, 225 | channel_type: ChannelType::from_str(&row.get::<_, String>(3)?).map_err(|_| { 226 | rusqlite::Error::FromSqlConversionFailure( 227 | 3, 228 | rusqlite::types::Type::Text, 229 | Box::new(super::error::ConversionError( 230 | "Invalid channel type".to_string(), 231 | )), 232 | ) 233 | })?, 234 | channel_id: row.get(4)?, 235 | account_id: row.get(5)?, 236 | role: row.get(6)?, 237 | content: row.get(7)?, 238 | created_at: row.get(8)?, 239 | }) 240 | } 241 | } 242 | 243 | impl TryFrom<&Row<'_>> for Channel { 244 | type Error = rusqlite::Error; 245 | 246 | fn try_from(row: &Row) -> Result { 247 | Ok(Channel { 248 | id: row.get(0)?, 249 | channel_id: row.get(1)?, 250 | channel_type: row.get(2)?, 251 | source: row.get(3)?, 252 | name: row.get(4)?, 253 | created_at: row.get(5)?, 254 | updated_at: row.get(6)?, 255 | }) 256 | } 257 | } 258 | 259 | impl SqliteVectorStoreTable for Channel { 260 | fn name() -> &'static str { 261 | "channels" 262 | } 263 | 264 | fn schema() -> Vec { 265 | vec![ 266 | Column::new("id", "TEXT PRIMARY KEY"), 267 | Column::new("name", "TEXT"), 268 | Column::new("source", "TEXT"), 269 | Column::new("created_at", "TIMESTAMP DEFAULT CURRENT_TIMESTAMP"), 270 | Column::new("updated_at", "TIMESTAMP DEFAULT CURRENT_TIMESTAMP"), 271 | ] 272 | } 273 | 274 | fn id(&self) -> String { 275 | self.id.clone() 276 | } 277 | 278 | fn column_values(&self) -> Vec<(&'static str, Box)> { 279 | vec![ 280 | ("id", Box::new(self.id.clone())), 281 | ("name", Box::new(self.name.clone())), 282 | ("source", Box::new(self.source.clone())), 283 | ] 284 | } 285 | } 286 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. 
Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 Tarrence van As 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /asuka-core/src/knowledge/store.rs: -------------------------------------------------------------------------------- 1 | use rig::{ 2 | embeddings::{EmbeddingModel, EmbeddingsBuilder}, 3 | vector_store::VectorStoreError, 4 | }; 5 | use tokio_rusqlite::Connection; 6 | use tracing::{debug, info}; 7 | 8 | use super::models::{Account, Channel, Document, Message}; 9 | use rig_sqlite::{SqliteError, SqliteVectorIndex, SqliteVectorStore}; 10 | use rusqlite::OptionalExtension; 11 | 12 | #[derive(Clone)] 13 | pub struct KnowledgeBase { 14 | pub conn: Connection, 15 | document_store: SqliteVectorStore, 16 | message_store: SqliteVectorStore, 17 | embedding_model: E, 18 | } 19 | 20 | impl KnowledgeBase { 21 | pub async fn new(conn: Connection, embedding_model: E) -> Result { 22 | let document_store = SqliteVectorStore::new(conn.clone(), &embedding_model).await?; 23 | let message_store = SqliteVectorStore::new(conn.clone(), &embedding_model).await?; 24 | 25 | conn.call(|conn| { 26 | conn.execute_batch( 27 | "BEGIN; 28 | 29 | -- User management tables 30 | CREATE TABLE 
IF NOT EXISTS accounts ( 31 | id INTEGER PRIMARY KEY AUTOINCREMENT, 32 | name TEXT NOT NULL, 33 | source_id TEXT NOT NULL UNIQUE, 34 | source TEXT NOT NULL, 35 | created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, 36 | updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP 37 | ); 38 | CREATE INDEX IF NOT EXISTS idx_source_id_source ON accounts(source_id, source); 39 | 40 | -- Channel tables 41 | CREATE TABLE IF NOT EXISTS channels ( 42 | id INTEGER PRIMARY KEY AUTOINCREMENT, 43 | channel_id TEXT NOT NULL UNIQUE, 44 | channel_type TEXT NOT NULL, 45 | source TEXT NOT NULL, 46 | name TEXT, 47 | created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, 48 | updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP 49 | ); 50 | CREATE INDEX IF NOT EXISTS idx_channel_id_type ON channels(channel_id, channel_type); 51 | 52 | COMMIT;" 53 | ) 54 | .map_err(tokio_rusqlite::Error::from) 55 | }) 56 | .await 57 | .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; 58 | 59 | Ok(Self { 60 | conn, 61 | document_store, 62 | message_store, 63 | embedding_model, 64 | }) 65 | } 66 | 67 | pub async fn create_user( 68 | &self, 69 | name: String, 70 | source: String, 71 | source_id: String, 72 | ) -> Result { 73 | self.conn 74 | .call(move |conn| { 75 | conn.query_row( 76 | "INSERT INTO accounts (name, source, created_at, updated_at, source_id) 77 | VALUES (?1, ?2, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, ?3) 78 | ON CONFLICT(source_id) DO UPDATE SET 79 | updated_at = CURRENT_TIMESTAMP 80 | RETURNING id", 81 | rusqlite::params![name, source, source_id], 82 | |row| row.get(0), 83 | ) 84 | .map_err(tokio_rusqlite::Error::from) 85 | }) 86 | .await 87 | .map_err(|e| SqliteError::DatabaseError(Box::new(e))) 88 | } 89 | 90 | pub fn document_index(&self) -> SqliteVectorIndex { 91 | SqliteVectorIndex::new(self.embedding_model.clone(), self.document_store.clone()) 92 | } 93 | 94 | pub fn message_index(&self) -> SqliteVectorIndex { 95 | SqliteVectorIndex::new(self.embedding_model.clone(), self.message_store.clone()) 96 | } 
97 | 98 | pub async fn get_user_by_source(&self, source: String) -> Result, SqliteError> { 99 | self.conn 100 | .call(move |conn| { 101 | let mut stmt = conn.prepare("SELECT * FROM accounts WHERE source = ?1")?; 102 | 103 | let account = stmt 104 | .query_row(rusqlite::params![source], |row| { 105 | Account::try_from(row).map_err(rusqlite::Error::from) 106 | }) 107 | .optional()?; 108 | 109 | Ok(account) 110 | }) 111 | .await 112 | .map_err(|e| SqliteError::DatabaseError(Box::new(e))) 113 | } 114 | 115 | pub async fn get_account_by_account_id( 116 | &self, 117 | account_id: String, 118 | ) -> Result, SqliteError> { 119 | self.conn 120 | .call(move |conn| { 121 | Ok(conn 122 | .query_row( 123 | "SELECT * FROM accounts WHERE source_id = ?1", 124 | rusqlite::params![account_id], 125 | |row| Account::try_from(row), 126 | ) 127 | .optional()?) 128 | }) 129 | .await 130 | .map_err(|e| SqliteError::DatabaseError(Box::new(e))) 131 | } 132 | 133 | pub async fn create_channel( 134 | &self, 135 | channel_id: String, 136 | channel_type: String, 137 | name: Option, 138 | source: String, 139 | ) -> Result { 140 | self.conn 141 | .call(move |conn| { 142 | conn.query_row( 143 | "INSERT INTO channels (channel_id, channel_type, source, name, created_at, updated_at) 144 | VALUES (?1, ?2, ?3, ?4, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) 145 | ON CONFLICT(channel_id) DO UPDATE SET 146 | name = COALESCE(?4, name), 147 | updated_at = CURRENT_TIMESTAMP 148 | RETURNING id", 149 | rusqlite::params![channel_id, channel_type, source, name], 150 | |row| row.get(0), 151 | ) 152 | .map_err(tokio_rusqlite::Error::from) 153 | }) 154 | .await 155 | .map_err(|e| SqliteError::DatabaseError(Box::new(e))) 156 | } 157 | 158 | pub async fn get_channel(&self, id: i64) -> Result, SqliteError> { 159 | self.conn 160 | .call(move |conn| { 161 | let mut stmt = conn.prepare( 162 | "SELECT id, channel_id, channel_type, source, name, created_at, updated_at FROM channels WHERE id = ?1", 163 | )?; 164 | 165 | let 
channel = stmt 166 | .query_row(rusqlite::params![id], |row| Channel::try_from(row)) 167 | .optional()?; 168 | 169 | Ok(channel) 170 | }) 171 | .await 172 | .map_err(|e| SqliteError::DatabaseError(Box::new(e))) 173 | } 174 | 175 | pub async fn get_channel_by_channel_id( 176 | &self, 177 | channel_id: &str, 178 | ) -> Result, SqliteError> { 179 | let channel_id = channel_id.to_string(); 180 | 181 | self.conn 182 | .call(move |conn| { 183 | let result = conn 184 | .prepare("SELECT * FROM channels WHERE channel_id = ?1")? 185 | .query_row(rusqlite::params![channel_id], |row| Channel::try_from(row)) 186 | .optional()?; 187 | 188 | Ok(result) 189 | }) 190 | .await 191 | .map_err(|e| SqliteError::DatabaseError(Box::new(e))) 192 | } 193 | 194 | pub async fn get_channels_by_source( 195 | &self, 196 | source: String, 197 | ) -> Result, SqliteError> { 198 | self.conn 199 | .call(move |conn| { 200 | let mut stmt = conn.prepare( 201 | "SELECT id, name, source, created_at, updated_at FROM channels WHERE source = ?1" 202 | )?; 203 | 204 | let channels = stmt.query_map(rusqlite::params![source], |row| { 205 | Channel::try_from(row) 206 | }).and_then(|mapped_rows| { 207 | mapped_rows.collect::, _>>() 208 | })?; 209 | 210 | Ok(channels) 211 | }) 212 | .await 213 | .map_err(|e| SqliteError::DatabaseError(Box::new(e))) 214 | } 215 | 216 | pub async fn create_message_without_embeddings(&self, msg: Message) -> Result<(), SqliteError> { 217 | self.conn 218 | .call(move |conn| { 219 | conn.execute( 220 | "INSERT INTO messages (id, source, source_id, channel_type, channel_id, account_id, content, role, created_at) 221 | VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, CURRENT_TIMESTAMP) 222 | ON CONFLICT (id) DO UPDATE SET 223 | content = ?7", 224 | rusqlite::params![ 225 | msg.id, 226 | msg.source.as_str(), 227 | msg.source_id, 228 | msg.channel_type.as_str(), 229 | msg.channel_id, 230 | msg.account_id, 231 | msg.content, 232 | msg.role 233 | ], 234 | ) 235 | .map_err(tokio_rusqlite::Error::from) 
236 | }) 237 | .await 238 | .map_err(|e| SqliteError::DatabaseError(Box::new(e)))?; 239 | Ok(()) 240 | } 241 | 242 | pub async fn create_message(&self, msg: Message) -> anyhow::Result { 243 | let embeddings = EmbeddingsBuilder::new(self.embedding_model.clone()) 244 | .documents(vec![msg.clone()])? 245 | .build() 246 | .await?; 247 | 248 | let store = self.message_store.clone(); 249 | 250 | self.conn 251 | .call(move |conn| { 252 | let tx = conn.transaction()?; 253 | 254 | tx.execute( 255 | "INSERT INTO messages (id, channel_id, account_id, content, role, created_at) 256 | VALUES (?1, ?2, ?3, ?4, ?5, CURRENT_TIMESTAMP) 257 | ON CONFLICT (id) DO UPDATE SET 258 | channel_id = ?2, 259 | account_id = ?3, 260 | content = ?4, 261 | role = ?5, 262 | created_at = CURRENT_TIMESTAMP", 263 | [ 264 | &msg.id, 265 | &msg.channel_id, 266 | &msg.account_id, 267 | &msg.content, 268 | &msg.role, 269 | ], 270 | )?; 271 | 272 | let id = store.add_rows_with_txn(&tx, embeddings)?; 273 | 274 | tx.commit()?; 275 | 276 | Ok(id) 277 | }) 278 | .await 279 | .map_err(|e| anyhow::anyhow!(e)) 280 | } 281 | 282 | pub async fn get_message(&self, id: i64) -> Result, SqliteError> { 283 | self.conn 284 | .call(move |conn| { 285 | Ok(conn.prepare("SELECT id, source, source_id, channel_type, channel_id, account_id, role, content, created_at FROM messages WHERE id = ?1")? 286 | .query_row(rusqlite::params![id], |row| { 287 | Message::try_from(row) 288 | }).optional()?) 
289 | }) 290 | .await 291 | .map_err(|e| SqliteError::DatabaseError(Box::new(e))) 292 | } 293 | 294 | pub async fn get_recent_messages_in_channel( 295 | &self, 296 | channel_id: String, 297 | limit: usize, 298 | ) -> Result, SqliteError> { 299 | self.conn 300 | .call(move |conn| { 301 | let mut stmt = conn.prepare( 302 | "SELECT id, source, source_id, channel_type, channel_id, account_id, role, content, created_at 303 | FROM messages 304 | WHERE channel_id = ?1 305 | ORDER BY created_at DESC 306 | LIMIT ?2", 307 | )?; 308 | 309 | let messages = stmt 310 | .query_map(rusqlite::params![channel_id, limit], |row| { 311 | Message::try_from(row) 312 | })? 313 | .collect::, _>>()?; 314 | 315 | Ok(messages) 316 | }) 317 | .await 318 | .map_err(|e| SqliteError::DatabaseError(Box::new(e))) 319 | } 320 | 321 | pub async fn get_recent_messages(&self, limit: usize) -> Result, SqliteError> { 322 | self.conn 323 | .call(move |conn| { 324 | let mut stmt = conn.prepare( 325 | "SELECT id, source, source_id, channel_type, channel_id, account_id, role, content, created_at 326 | FROM messages 327 | ORDER BY created_at DESC 328 | LIMIT ?1", 329 | )?; 330 | 331 | let messages = stmt 332 | .query_map(rusqlite::params![limit], |row| { 333 | Message::try_from(row) 334 | })? 335 | .collect::, _>>()?; 336 | 337 | Ok(messages) 338 | }) 339 | .await 340 | .map_err(|e| SqliteError::DatabaseError(Box::new(e))) 341 | } 342 | 343 | pub async fn channel_messages( 344 | &self, 345 | channel_id: &str, 346 | limit: i64, 347 | ) -> anyhow::Result> { 348 | let channel_id = channel_id.to_string(); 349 | 350 | self.conn 351 | .call(move |conn| { 352 | let mut stmt = conn.prepare( 353 | "SELECT source_id, content 354 | FROM messages 355 | WHERE channel_id = ?1 356 | ORDER BY created_at DESC 357 | LIMIT ?2", 358 | )?; 359 | let messages = stmt 360 | .query_map([&channel_id, &limit.to_string()], |row| { 361 | Ok((row.get(0)?, row.get(1)?)) 362 | })? 
363 | .collect::, _>>()?; 364 | Ok(messages) 365 | }) 366 | .await 367 | .map_err(|e| anyhow::anyhow!(e)) 368 | } 369 | 370 | pub async fn add_message_embeddings(&self, msg: Message) -> anyhow::Result<()> { 371 | let embeddings = EmbeddingsBuilder::new(self.embedding_model.clone()) 372 | .documents(vec![msg.clone()])? 373 | .build() 374 | .await?; 375 | 376 | self.message_store.add_rows(embeddings).await?; 377 | 378 | Ok(()) 379 | } 380 | 381 | pub async fn add_documents<'a, I>(&mut self, documents: I) -> anyhow::Result<()> 382 | where 383 | I: IntoIterator, 384 | { 385 | info!("Adding documents to KnowledgeBase"); 386 | let embeddings = EmbeddingsBuilder::new(self.embedding_model.clone()) 387 | .documents(documents)? 388 | .build() 389 | .await?; 390 | 391 | debug!("Adding embeddings to document store"); 392 | self.document_store.add_rows(embeddings).await?; 393 | 394 | info!("Successfully added documents to KnowledgeBase"); 395 | Ok(()) 396 | } 397 | } 398 | --------------------------------------------------------------------------------