├── .gitignore ├── README.md ├── Cargo.toml ├── test.py └── src └── lib.rs /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | Cargo.lock 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # rust-sqlite-ext-example 2 | 3 | See: https://ricardoanderegg.com/posts/extending-sqlite-with-rust/ 4 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "sqlite-regex-ext" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | 7 | # [features] 8 | # default = [] 9 | # build_extension = [ 10 | # "rusqlite/bundled", 11 | # "rusqlite/functions", 12 | # "rusqlite/loadable_extension", 13 | # ] 14 | 15 | [lib] 16 | crate-type = ["cdylib"] 17 | 18 | [dependencies] 19 | 20 | # once_cell = "1.9.0" 21 | regex = "1.5.4" 22 | log = "0.4.14" 23 | env_logger = "0.9.0" 24 | anyhow = "1.0.54" 25 | 26 | 27 | [dependencies.rusqlite] 28 | package = "rusqlite" 29 | git = "https://github.com/litements/rusqlite/" 30 | branch = "loadable-extensions-release-2" 31 | # path = "../rusqlite/loadable-extensions-release-2" 32 | default-features = false 33 | features = [ 34 | "loadable_extension", 35 | "vtab", 36 | "functions", 37 | "bundled", 38 | "modern_sqlite", 39 | "buildtime_bindgen", 40 | ] 41 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sqlite3 4 | 5 | conn = sqlite3.connect("test.db", isolation_level=None) 6 | 7 | print(f"Loading SQLite extension in connection: {conn}") 8 | conn.enable_load_extension(True) 9 | conn.execute( 10 | "SELECT load_extension('target/release/libsqlite_regex_ext.dylib', 'sqlite3_regex_init');" 11 | ) 12 | 13 | print("Running tests...") 14 | 15 | print("Testing pattern 'x(ab)' WITHOUT capture group") 16 | row = conn.execute("SELECT regex_extract('x(ab)', 'xxabaa')").fetchone() 17 | assert row[0] == "xab", row[0] 18 | 19 | print("Testing pattern 'x(ab)' WITH capture group = 1") 20 | row = conn.execute("SELECT regex_extract('x(ab)', 'xxabaa', 1)").fetchone() 21 | assert row[0] == "ab", row[0] 22 | 23 | print("Testing pattern 'x(ab)' WITH capture group = 0") 24 | row = conn.execute("SELECT regex_extract('x(ab)', 'xxabaa', 0)").fetchone() 25 | assert row[0] == "xab", row[0] 26 | 27 | print("Testing pattern 'g(oog)+le' WITHOUT capture group") 28 | row = conn.execute("SELECT regex_extract('g(oog)+le', 'googoogoogle')").fetchone() 29 | assert row[0] == "googoogoogle", row[0] 30 | 31 | print("Testing pattern 'g(oog)+le' WITH capture group = 1") 32 | row = conn.execute("SELECT regex_extract('g(oog)+le', 'googoogoogle', 1)").fetchone() 33 | assert row[0] == "oog", row[0] 34 | 35 | print("Testing pattern '[Cc]at' WITHOUT capture group") 36 | row = conn.execute("SELECT regex_extract('[Cc]at', 'cat')").fetchone() 37 | assert row[0] == "cat", row[0] 38 | 39 | print("Testing pattern '[Cc]at' WITHOUT capture group, expecting empty return") 40 | row = conn.execute("SELECT regex_extract('[Cc]at', 'hello')").fetchone() 41 | assert row[0] is None, row[0] 42 | 43 | conn.close() 44 | 45 | 46 | conn2 = sqlite3.connect("test.db", isolation_level=None) 47 | print(f"Testing connection 2: {conn2}") 48 | row = conn2.execute("SELECT regex_extract('x(ab)', 'xxabaa')").fetchone() 49 | assert row[0] == "xab", row[0] 50 | 51 | 52 | print("All tests passed") 53 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | // #![allow( 2 | // dead_code, 3 | // unused_imports, 4 | // unused_variables, 5 | // clippy::missing_safety_doc 6 | // )] 7 | #![allow(clippy::missing_safety_doc)] 8 | 9 | use crate::ffi::loadable_extension_init; 10 | use crate::ffi::sqlite3_auto_extension; 11 | use anyhow::Context as ACtxt; 12 | use log::LevelFilter; 13 | use regex::bytes::Regex; 14 | use rusqlite::ffi; 15 | use rusqlite::functions::{Context, FunctionFlags}; 16 | use rusqlite::types::{ToSqlOutput, Value, ValueRef}; 17 | use rusqlite::Connection; 18 | use std::os::raw::c_int; 19 | 20 | fn ah(e: anyhow::Error) -> rusqlite::Error { 21 | rusqlite::Error::UserFunctionError(format!("{:?}", e).into()) 22 | } 23 | 24 | fn init_logging(default_level: LevelFilter) { 25 | let lib_log_env = "SQLITE_REGEX_LOG"; 26 | if std::env::var(lib_log_env).is_err() { 27 | std::env::set_var(lib_log_env, format!("{}", default_level)) 28 | } 29 | 30 | let logger_env = env_logger::Env::new().filter(lib_log_env); 31 | 32 | env_logger::try_init_from_env(logger_env).ok(); 33 | } 34 | 35 | // Will use with ffi:sqlite3_auto_extension(arg1) 36 | // https://www.sqlite.org/c3ref/auto_extension.html 37 | // Example: https://sqlite.org/src/file/ext/misc/vfsstat.c 38 | // https://www.sqlite.org/loadext.html 39 | // #[no_mangle] 40 | // pub unsafe extern "C" fn regex_register( 41 | // db: *mut ffi::sqlite3, 42 | // _pz_err_msg: &mut &mut std::os::raw::c_char, 43 | // p_api: *mut ffi::sqlite3_api_routines, 44 | // ) -> c_int {} 45 | 46 | #[no_mangle] 47 | pub unsafe extern "C" fn sqlite3_regex_init_internal( 48 | db: *mut ffi::sqlite3, 49 | _pz_err_msg: &mut &mut std::os::raw::c_char, 50 | p_api: *mut ffi::sqlite3_api_routines, 51 | ) -> c_int { 52 | // https://www.sqlite.org/loadext.html 53 | // https://github.com/jgallagher/rusqlite/issues/524#issuecomment-507787350 54 | // SQLITE_EXTENSION_INIT2 equivalent 55 | loadable_extension_init(p_api); 56 | /* Insert here calls to 57 | ** sqlite3_create_function_v2(), 58 | ** sqlite3_create_collation_v2(), 59 | ** sqlite3_create_module_v2(), and/or 60 | ** sqlite3_vfs_register() 61 | ** to register the new features that your extension adds. 62 | */ 63 | match init(db) { 64 | Ok(()) => { 65 | log::info!("[regex-extension] init ok"); 66 | // ffi::SQLITE_OK 67 | ffi::SQLITE_OK_LOAD_PERMANENTLY 68 | } 69 | 70 | Err(e) => { 71 | log::error!("[regex-extension] init error: {:?}", e); 72 | ffi::SQLITE_ERROR 73 | } 74 | } 75 | } 76 | 77 | #[no_mangle] 78 | pub unsafe extern "C" fn sqlite3_regex_init( 79 | db: *mut ffi::sqlite3, 80 | _pz_err_msg: &mut &mut std::os::raw::c_char, 81 | p_api: *mut ffi::sqlite3_api_routines, 82 | ) -> c_int { 83 | loadable_extension_init(p_api); 84 | let ptr = sqlite3_regex_init_internal 85 | as unsafe extern "C" fn( 86 | *mut ffi::sqlite3, 87 | &mut &mut std::os::raw::c_char, 88 | *mut ffi::sqlite3_api_routines, 89 | ) -> c_int; 90 | 91 | sqlite3_auto_extension(Some(std::mem::transmute(ptr))); 92 | match init(db) { 93 | Ok(()) => { 94 | log::info!("[regex-extension] init ok"); 95 | ffi::SQLITE_OK_LOAD_PERMANENTLY 96 | } 97 | 98 | Err(e) => { 99 | log::error!("[regex-extension] init error: {:?}", e); 100 | ffi::SQLITE_ERROR 101 | } 102 | } 103 | } 104 | 105 | fn init(db_handle: *mut ffi::sqlite3) -> anyhow::Result<()> { 106 | let db = unsafe { rusqlite::Connection::from_handle(db_handle)? }; 107 | load(&db)?; 108 | Ok(()) 109 | } 110 | 111 | fn load(c: &Connection) -> anyhow::Result<()> { 112 | load_with_loglevel(c, LevelFilter::Info) 113 | } 114 | 115 | fn load_with_loglevel(c: &Connection, default_log_level: LevelFilter) -> anyhow::Result<()> { 116 | init_logging(default_log_level); 117 | add_functions(c) 118 | } 119 | 120 | fn add_functions(c: &Connection) -> anyhow::Result<()> { 121 | let deterministic = FunctionFlags::SQLITE_DETERMINISTIC | FunctionFlags::SQLITE_UTF8; 122 | // | FunctionFlags::SQLITE_INNOCUOUS; 123 | 124 | c.create_scalar_function("regex_extract", 2, deterministic, |ctx: &Context| { 125 | regex_extract(ctx).map_err(ah) 126 | })?; 127 | 128 | c.create_scalar_function("regex_extract", 3, deterministic, |ctx: &Context| { 129 | regex_extract(ctx).map_err(ah) 130 | })?; 131 | 132 | Ok(()) 133 | } 134 | 135 | fn regex_extract<'a>(ctx: &Context) -> anyhow::Result> { 136 | let arg_pat = 0; 137 | let arg_input_data = 1; 138 | let arg_cap_group = 2; 139 | 140 | let empty_return = Ok(ToSqlOutput::Owned(Value::Null)); 141 | 142 | let pattern = match ctx.get_raw(arg_pat) { 143 | ValueRef::Text(t) => t, 144 | e => anyhow::bail!("regex pattern must be text, got {}", e.data_type()), 145 | }; 146 | 147 | let re = Regex::new(std::str::from_utf8(pattern)?)?; 148 | 149 | let input_value = match ctx.get_raw(arg_input_data) { 150 | ValueRef::Text(t) => t, 151 | ValueRef::Null => return empty_return, 152 | e => anyhow::bail!("regex expects text as input, got {}", e.data_type()), 153 | }; 154 | 155 | let cap_group: usize = if ctx.len() <= arg_cap_group { 156 | // no capture group, use default 157 | 0 158 | } else { 159 | ctx.get(arg_cap_group).context("capture group")? 160 | }; 161 | 162 | // let mut caploc = re.capture_locations(); 163 | // re.captures_read(&mut caploc, input_value); 164 | if let Some(cap) = re.captures(input_value) { 165 | match cap.get(cap_group) { 166 | None => empty_return, 167 | // String::from_utf8_lossy 168 | Some(t) => { 169 | let value = String::from_utf8_lossy(t.as_bytes()); 170 | return Ok(ToSqlOutput::Owned(Value::Text(value.to_string()))); 171 | } 172 | } 173 | } else { 174 | empty_return 175 | } 176 | } 177 | --------------------------------------------------------------------------------