├── .gitignore ├── valgrind.suppress ├── .github ├── dependabot.yml └── workflows │ ├── auto-assignee.yml │ ├── sast.yml │ ├── sca.yml │ ├── lint.yml │ └── tests.yml ├── .luacheckrc ├── cbindgen.toml ├── src ├── lib.rs ├── ffi │ ├── schema.rs │ ├── mod.rs │ ├── context.rs │ ├── router.rs │ └── expression.rs ├── schema.rs ├── atc_grammar.pest ├── context.rs ├── router.rs ├── semantics.rs ├── parser.rs ├── ast.rs └── interpreter.rs ├── Cargo.toml ├── t ├── 02-gc.t ├── 09-not.t ├── 03-contains.t ├── 07-in_notin.t ├── 05-equals.t ├── 02-bugs.t ├── 08-equals.t ├── 04-rawstr.t ├── 06-validate.t └── 01-sanity.t ├── benches ├── build.rs └── match_mix.rs ├── lib └── resty │ └── router │ ├── schema.lua │ ├── router.lua │ ├── context.lua │ └── cdefs.lua ├── Makefile ├── README.md └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | /target/ 4 | 5 | # These are backup files generated by rustfmt 6 | **/*.rs.bk 7 | 8 | t/servroot 9 | Cargo.lock 10 | -------------------------------------------------------------------------------- /valgrind.suppress: -------------------------------------------------------------------------------- 1 | { 2 | 3 | Memcheck:Leak 4 | match-leak-kinds: definite 5 | fun:malloc 6 | fun:ngx_alloc 7 | fun:ngx_set_environment 8 | fun:ngx_single_process_cycle 9 | fun:main 10 | } 11 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | interval: "daily" 8 | 9 | - package-ecosystem: "cargo" 10 | directory: "/" 11 | schedule: 12 | interval: "daily" 13 | -------------------------------------------------------------------------------- /.luacheckrc: 
-------------------------------------------------------------------------------- 1 | std = "ngx_lua" 2 | unused_args = false 3 | redefined = false 4 | max_line_length = false 5 | 6 | 7 | not_globals = { 8 | "string.len", 9 | "table.getn", 10 | } 11 | 12 | 13 | ignore = { 14 | "6.", -- ignore whitespace warnings 15 | } 16 | -------------------------------------------------------------------------------- /.github/workflows/auto-assignee.yml: -------------------------------------------------------------------------------- 1 | name: Add assignee to PRs 2 | on: 3 | pull_request: 4 | types: [ opened, reopened ] 5 | permissions: 6 | pull-requests: write 7 | jobs: 8 | assign-author: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: toshimaru/auto-author-assign@16f0022cf3d7970c106d8d1105f75a1165edb516 12 | 13 | -------------------------------------------------------------------------------- /cbindgen.toml: -------------------------------------------------------------------------------- 1 | language = "C" 2 | header = "/* Generated by cbindgen. Do NOT edit. 
*/" 3 | 4 | [enum] 5 | prefix_with_name = true 6 | 7 | [defines] 8 | "feature = ffi" = "DEFINE_ATC_ROUTER_FFI" 9 | 10 | [macro_expansion] 11 | bitflags = true 12 | 13 | [export] 14 | include = [ 15 | "BinaryOperatorFlags", 16 | "ATC_ROUTER_EXPRESSION_VALIDATE_OK", 17 | "ATC_ROUTER_EXPRESSION_VALIDATE_FAILED", 18 | "ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL" 19 | ] -------------------------------------------------------------------------------- /.github/workflows/sast.yml: -------------------------------------------------------------------------------- 1 | name: SAST 2 | 3 | on: 4 | pull_request: {} 5 | push: 6 | branches: 7 | - master 8 | - main 9 | workflow_dispatch: {} 10 | 11 | 12 | jobs: 13 | semgrep: 14 | name: Semgrep SAST 15 | runs-on: ubuntu-latest 16 | permissions: 17 | # required for all workflows 18 | security-events: write 19 | # only required for workflows in private repositories 20 | actions: read 21 | contents: read 22 | 23 | if: (github.actor != 'dependabot[bot]') 24 | 25 | steps: 26 | - uses: actions/checkout@v4 27 | - uses: Kong/public-shared-actions/security-actions/semgrep@0ccacffed804d85da3f938a1b78c12831935f992 # v2 28 | with: 29 | additional_config: '--config p/rust' 30 | 31 | -------------------------------------------------------------------------------- /.github/workflows/sca.yml: -------------------------------------------------------------------------------- 1 | name: SCA 2 | 3 | on: 4 | pull_request: {} 5 | workflow_dispatch: {} 6 | push: 7 | branches: 8 | - main 9 | 10 | concurrency: 11 | group: ${{ github.workflow }}-${{ github.ref }} 12 | cancel-in-progress: ${{ github.event_name == 'pull_request' }} 13 | 14 | jobs: 15 | 16 | rust-sca: 17 | name: Rust SCA 18 | runs-on: ubuntu-latest 19 | 20 | permissions: 21 | # required for all workflows 22 | security-events: write 23 | checks: write 24 | pull-requests: write 25 | actions: read 26 | contents: read 27 | 28 | if: (github.actor != 'dependabot[bot]') 29 | 30 | steps: 31 | - name: 
Checkout source code 32 | uses: actions/checkout@v4 33 | 34 | - name: Rust SCA 35 | uses: Kong/public-shared-actions/security-actions/sca@916a6f6221b7eab6f5ae53d061274d588c965ae6 # 5.1.1 36 | with: 37 | asset_prefix: 'atc-router' 38 | codeql_upload: 'true' 39 | dir: '.' 40 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![deny(warnings, missing_debug_implementations)] 2 | /*! 3 | This crate provides a powerful rule based matching engine that can match a set of routes 4 | against dynamic input value efficiently. 5 | It is mainly used inside the [Kong Gateway](https://github.com/Kong/kong) 6 | for performing route matching against incoming requests and is used as an FFI binding 7 | for LuaJIT. 8 | 9 | Please see the [repository README.md](https://github.com/Kong/atc-router/blob/main/README.md) 10 | for more detailed explanations of the concepts and APIs. 11 | 12 | # Crate features 13 | 14 | * **ffi** - 15 | Builds the FFI based interface which is suitable for use by a foreign language such as 16 | C or LuaJIT. This feature is on by default. 17 | * **serde** - 18 | Enable serde integration which allows data structures to be serializable/deserializable. 
19 | */ 20 | 21 | pub mod ast; 22 | pub mod context; 23 | pub mod interpreter; 24 | pub mod parser; 25 | pub mod router; 26 | pub mod schema; 27 | pub mod semantics; 28 | 29 | #[cfg(feature = "ffi")] 30 | pub mod ffi; 31 | 32 | #[macro_use] 33 | extern crate pest_derive; 34 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "atc-router" 3 | version = "1.7.1" 4 | edition = "2021" 5 | license = "Apache-2.0" 6 | authors = ["Datong Sun ", "Kong Contributors"] 7 | description = """ 8 | Versatile DSL based rule matching engine used by the Kong API Gateway 9 | """ 10 | repository = "https://github.com/Kong/atc-router" 11 | keywords = ["dsl", "atc", "router", "rule", "engine"] 12 | categories = ["compilers"] 13 | 14 | [dependencies] 15 | pest = "2.7" 16 | pest_derive = "2.7" 17 | cidr = "0.3" 18 | lazy_static = "1.5" 19 | uuid = "1.8" 20 | regex = "1" 21 | serde = { version = "1.0", features = ["derive"], optional = true } 22 | serde_regex = { version = "1.1", optional = true } 23 | fnv = "1" 24 | bitflags = { version = "2.6", optional = true } 25 | 26 | [dev-dependencies] 27 | criterion = "0" 28 | 29 | [lib] 30 | crate-type = ["lib", "cdylib", "staticlib"] 31 | 32 | [features] 33 | default = ["ffi"] 34 | ffi = ["dep:bitflags"] 35 | serde = ["cidr/serde", "dep:serde", "dep:serde_regex"] 36 | 37 | 38 | [[bench]] 39 | name = "build" 40 | harness = false 41 | 42 | [[bench]] 43 | name = "match_mix" 44 | harness = false 45 | -------------------------------------------------------------------------------- /t/02-gc.t: -------------------------------------------------------------------------------- 1 | # vim:set ft= ts=4 sw=4 et: 2 | 3 | use Test::Nginx::Socket::Lua; 4 | use Cwd qw(cwd); 5 | 6 | repeat_each(1); 7 | 8 | plan tests => repeat_each() * blocks() * 5; 9 | 10 | my $pwd = cwd(); 11 | 12 | our $HttpConfig = qq{ 13 | 
/// Builds a UUID-shaped string whose last group is `a`.
///
/// The first four groups are fixed; `a` is written in decimal and left-padded
/// with zeros to twelve digits, so distinct inputs yield distinct identifiers.
fn make_uuid(a: usize) -> String {
    format_args!("8cb2a7d0-c775-4ed9-989f-{:012}", a).to_string()
}
schema = Schema::default(); 33 | schema.add_field("a", Type::Int); 34 | 35 | c.bench_function("Build Router", |b| { 36 | b.iter_with_large_drop(|| { 37 | let mut router = Router::new(&schema); 38 | for v in &data { 39 | router.add_matcher(v.0, v.1, &v.2).unwrap(); 40 | } 41 | router 42 | }); 43 | }); 44 | } 45 | 46 | criterion_group!(benches, criterion_benchmark); 47 | criterion_main!(benches); 48 | -------------------------------------------------------------------------------- /src/ffi/schema.rs: -------------------------------------------------------------------------------- 1 | use crate::ast::Type; 2 | use crate::schema::Schema; 3 | use std::ffi; 4 | use std::os::raw::c_char; 5 | 6 | #[no_mangle] 7 | pub extern "C" fn schema_new() -> *mut Schema { 8 | Box::into_raw(Box::default()) 9 | } 10 | 11 | /// Deallocate the schema object. 12 | /// 13 | /// # Errors 14 | /// 15 | /// This function never fails. 16 | /// 17 | /// # Safety 18 | /// 19 | /// Violating any of the following constraints will result in undefined behavior: 20 | /// 21 | /// - `schema` must be a valid pointer returned by [`schema_new`]. 22 | #[no_mangle] 23 | pub unsafe extern "C" fn schema_free(schema: *mut Schema) { 24 | drop(Box::from_raw(schema)); 25 | } 26 | 27 | /// Add a new field with the specified type to the schema. 28 | /// 29 | /// # Arguments 30 | /// 31 | /// - `schema`: a valid pointer to the [`Schema`] object returned by [`schema_new`]. 32 | /// - `field`: the C-style string representing the field name. 33 | /// - `typ`: the type of the field. 34 | /// 35 | /// # Panics 36 | /// 37 | /// This function will panic if the C-style string 38 | /// pointed by `field` is not a valid UTF-8 string. 39 | /// 40 | /// # Safety 41 | /// 42 | /// Violating any of the following constraints will result in undefined behavior: 43 | /// 44 | /// - `schema` must be a valid pointer returned by [`schema_new`]. 
45 | /// - `field` must be a valid pointer to a C-style string, must be properly aligned, 46 | /// and must not have '\0' in the middle. 47 | #[no_mangle] 48 | pub unsafe extern "C" fn schema_add_field(schema: &mut Schema, field: *const i8, typ: Type) { 49 | let field = ffi::CStr::from_ptr(field as *const c_char) 50 | .to_str() 51 | .unwrap(); 52 | 53 | schema.add_field(field, typ) 54 | } 55 | -------------------------------------------------------------------------------- /src/schema.rs: -------------------------------------------------------------------------------- 1 | use crate::ast::Type; 2 | use std::collections::HashMap; 3 | 4 | #[derive(Debug, Default)] 5 | pub struct Schema { 6 | fields: HashMap, 7 | } 8 | 9 | impl Schema { 10 | pub fn type_of(&self, field: &str) -> Option<&Type> { 11 | self.fields.get(field).or_else(|| { 12 | self.fields 13 | .get(&format!("{}.*", &field[..field.rfind('.')?])) 14 | }) 15 | } 16 | 17 | pub fn add_field(&mut self, field: &str, typ: Type) { 18 | self.fields.insert(field.to_string(), typ); 19 | } 20 | } 21 | 22 | #[cfg(test)] 23 | mod tests { 24 | use super::*; 25 | 26 | #[test] 27 | fn normal_fields() { 28 | let mut s = Schema::default(); 29 | 30 | s.add_field("str", Type::String); 31 | s.add_field("ip", Type::IpAddr); 32 | s.add_field("cidr", Type::IpCidr); 33 | s.add_field("r", Type::Regex); 34 | s.add_field("i", Type::Int); 35 | 36 | assert_eq!(s.type_of("str"), Some(&Type::String)); 37 | assert_eq!(s.type_of("ip"), Some(&Type::IpAddr)); 38 | assert_eq!(s.type_of("cidr"), Some(&Type::IpCidr)); 39 | assert_eq!(s.type_of("r"), Some(&Type::Regex)); 40 | assert_eq!(s.type_of("i"), Some(&Type::Int)); 41 | 42 | assert_eq!(s.type_of("unknown"), None); 43 | } 44 | 45 | #[test] 46 | fn wildcard_fields() { 47 | let mut s = Schema::default(); 48 | 49 | s.add_field("a.*", Type::String); 50 | 51 | assert_eq!(s.type_of("a.b"), Some(&Type::String)); 52 | assert_eq!(s.type_of("a.xxx"), Some(&Type::String)); 53 | 54 | 
assert_eq!(s.type_of("aa.xxx"), None); 55 | assert_eq!(s.type_of("a.x.y"), None); 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /src/atc_grammar.pest: -------------------------------------------------------------------------------- 1 | WHITESPACE = _{ " " | "\t" | "\r" | "\n" } 2 | ident = @{ ASCII_ALPHA ~ (ASCII_ALPHANUMERIC | "_" | ".")* } 3 | rhs = { str_literal | rawstr_literal | ip_literal | int_literal } 4 | transform_func = { ident ~ "(" ~ lhs ~ ")" } 5 | lhs = { transform_func | ident } 6 | 7 | 8 | int_literal = ${ "-"? ~ digits } 9 | digits = _{ oct_digits | ( "0x" ~ hex_digits ) | dec_digits } 10 | hex_digits = { ASCII_HEX_DIGIT+ } 11 | oct_digits = { "0" ~ ASCII_OCT_DIGIT+ } 12 | dec_digits = { ASCII_DIGIT+ } 13 | 14 | 15 | str_literal = ${ "\"" ~ str_inner ~ "\"" } 16 | str_inner = _{ (str_esc | str_char)* } 17 | str_char = { !("\"" | "\\") ~ ANY } 18 | str_esc = { "\\" ~ ("\"" | "\\" | "n" | "r" | "t") } 19 | 20 | rawstr_literal = ${ "r#\"" ~ rawstr_char* ~ "\"#" } 21 | rawstr_char = { !"\"#" ~ ANY } 22 | 23 | ipv4_literal = @{ ASCII_DIGIT{1,3} ~ ( "." ~ ASCII_DIGIT{1,3} ){3} } 24 | ipv6_literal = @{ 25 | ( ":" | ASCII_HEX_DIGIT{1,4} ) ~ ":" ~ ( ASCII_HEX_DIGIT{1,4} | ":" )* 26 | } 27 | ipv4_cidr_literal = @{ ipv4_literal ~ "/" ~ ASCII_DIGIT{1,2} } 28 | ipv6_cidr_literal = @{ ipv6_literal ~ "/" ~ ASCII_DIGIT{1,3} } 29 | ip_literal = _{ ipv4_cidr_literal | ipv6_cidr_literal | ipv4_literal | ipv6_literal } 30 | 31 | 32 | binary_operator = { "==" | "!=" | "~" | "^=" | "=^" | ">=" | 33 | ">" | "<=" | "<" | "in" | "not" ~ "in" | "contains" } 34 | logical_operator = _{ and_op | or_op } 35 | and_op = { "&&" } 36 | or_op = { "||" } 37 | 38 | not_op = { "!" } 39 | 40 | 41 | predicate = { lhs ~ binary_operator ~ rhs } 42 | parenthesised_expression = { not_op? 
-- Return the type name registered for `field`.
--
-- An exact entry in `self.field_types` wins. Otherwise the last dotted
-- segment of `field` is replaced by "*" and the wildcard entry is tried,
-- so "http.headers.foo" resolves through "http.headers.*".
--
-- Returns the type name on success, or `nil` plus an error message when
-- the field is unknown (with or without a dot in its name).
function _M:get_field_type(field)
    local field_types = self.field_types

    local typ = field_types[field]
    if typ then
        return typ
    end

    -- greedy capture: everything before the last "." that is followed by
    -- at least one character
    local name = field:match("(.+)%..+")
    if name then
        typ = field_types[name .. ".*"]
        if typ then
            return typ
        end
    end

    return nil, "field " .. field .. " unknown"
end
router; 4 | pub mod schema; 5 | 6 | use crate::ast::Value; 7 | use cidr::IpCidr; 8 | use std::convert::TryFrom; 9 | use std::ffi; 10 | use std::net::IpAddr; 11 | use std::os::raw::c_char; 12 | use std::slice::from_raw_parts; 13 | 14 | pub const ERR_BUF_MAX_LEN: usize = 4096; 15 | 16 | #[derive(Debug)] 17 | #[repr(C)] 18 | pub enum CValue { 19 | Str(*const u8, usize), 20 | IpCidr(*const u8), 21 | IpAddr(*const u8), 22 | Int(i64), 23 | } 24 | 25 | impl TryFrom<&CValue> for Value { 26 | type Error = String; 27 | 28 | fn try_from(v: &CValue) -> Result { 29 | Ok(match v { 30 | CValue::Str(s, len) => Self::String(unsafe { 31 | std::str::from_utf8(from_raw_parts(*s, *len)) 32 | .map_err(|e| e.to_string())? 33 | .to_string() 34 | }), 35 | CValue::IpCidr(s) => Self::IpCidr( 36 | unsafe { 37 | ffi::CStr::from_ptr(*s as *const c_char) 38 | .to_str() 39 | .map_err(|e| e.to_string())? 40 | .to_string() 41 | } 42 | .parse::() 43 | .map_err(|e| e.to_string())?, 44 | ), 45 | CValue::IpAddr(s) => Self::IpAddr( 46 | unsafe { 47 | ffi::CStr::from_ptr(*s as *const c_char) 48 | .to_str() 49 | .map_err(|e| e.to_string())? 
50 | .to_string() 51 | } 52 | .parse::() 53 | .map_err(|e| e.to_string())?, 54 | ), 55 | CValue::Int(i) => Self::Int(*i), 56 | }) 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | OS=$(shell uname -s) 2 | 3 | ifeq ($(OS), Darwin) 4 | SHLIB_EXT=dylib 5 | else 6 | SHLIB_EXT=so 7 | endif 8 | 9 | OPENRESTY_PREFIX=/usr/local/openresty 10 | 11 | #LUA_VERSION := 5.1 12 | PREFIX ?= /usr/local 13 | LUA_INCLUDE_DIR ?= $(PREFIX)/include 14 | LUA_LIB_DIR ?= $(PREFIX)/lib/lua/$(LUA_VERSION) 15 | INSTALL ?= install 16 | RELEASE_FOLDER = target/$(CARGO_BUILD_TARGET)/release 17 | DEBUG_RELEASE_FOLDER = target/$(CARGO_BUILD_TARGET)/debug 18 | 19 | .PHONY: all test install build clean 20 | 21 | all: ; 22 | 23 | build: $(RELEASE_FOLDER)/libatc_router.$(SHLIB_EXT) $(RELEASE_FOLDER)/libatc_router.a 24 | 25 | $(RELEASE_FOLDER)/libatc_router.%: src/*.rs 26 | cargo build --release 27 | 28 | $(DEBUG_RELEASE_FOLDER)/libatc_router.%: src/*.rs 29 | cargo build 30 | 31 | install-lualib: 32 | $(INSTALL) -d $(DESTDIR)$(LUA_LIB_DIR)/resty/router/ 33 | $(INSTALL) -m 664 lib/resty/router/*.lua $(DESTDIR)$(LUA_LIB_DIR)/resty/router/ 34 | 35 | install: build install-lualib 36 | $(INSTALL) -m 775 $(RELEASE_FOLDER)/libatc_router.$(SHLIB_EXT) $(DESTDIR)$(LUA_LIB_DIR)/libatc_router.$(SHLIB_EXT) 37 | 38 | install-debug: $(DEBUG_RELEASE_FOLDER)/libatc_router.% install-lualib 39 | $(INSTALL) -m 775 $(DEBUG_RELEASE_FOLDER)/libatc_router.$(SHLIB_EXT) $(DESTDIR)$(LUA_LIB_DIR)/libatc_router.$(SHLIB_EXT) 40 | 41 | test: $(DEBUG_RELEASE_FOLDER)/libatc_router.% 42 | PATH="$(OPENRESTY_PREFIX)/nginx/sbin:$$PATH" \ 43 | LUA_PATH="$(realpath lib)/?.lua;$(realpath lib)/?/init.lua;$$LUA_PATH" \ 44 | LUA_CPATH="$(realpath $(DEBUG_RELEASE_FOLDER))/?.so;$$LUA_CPATH" \ 45 | prove -r t/ 46 | 47 | valgrind: $(DEBUG_RELEASE_FOLDER)/libatc_router.% 48 | 
(PATH="$(OPENRESTY_PREFIX)/nginx/sbin:$$PATH" \ 49 | LUA_PATH="$(realpath lib)/?.lua;$(realpath lib)/?/init.lua;$$LUA_PATH" \ 50 | LUA_CPATH="$(realpath $(DEBUG_RELEASE_FOLDER))/?.so;$$LUA_CPATH" \ 51 | prove -r t/) 2>&1 | tee /dev/stderr | grep -q "match-leak-kinds: definite" && exit 1 || exit 0 52 | 53 | clean: 54 | rm -rf target 55 | -------------------------------------------------------------------------------- /t/09-not.t: -------------------------------------------------------------------------------- 1 | # vim:set ft= ts=4 sw=4 et: 2 | 3 | use Test::Nginx::Socket::Lua; 4 | use Cwd qw(cwd); 5 | 6 | repeat_each(2); 7 | 8 | plan tests => repeat_each() * blocks() * 5; 9 | 10 | my $pwd = cwd(); 11 | 12 | our $HttpConfig = qq{ 13 | lua_package_path "$pwd/lib/?.lua;;"; 14 | lua_package_cpath "$pwd/target/debug/?.so;;"; 15 | }; 16 | 17 | no_long_string(); 18 | no_diff(); 19 | 20 | run_tests(); 21 | 22 | __DATA__ 23 | 24 | === TEST 1: not operator negates result from inside expression 25 | --- http_config eval: $::HttpConfig 26 | --- config 27 | location = /t { 28 | content_by_lua_block { 29 | local schema = require("resty.router.schema") 30 | local router = require("resty.router.router") 31 | local context = require("resty.router.context") 32 | 33 | local s = schema.new() 34 | 35 | s:add_field("http.path", "String") 36 | 37 | local r = router.new(s) 38 | assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", 39 | [[!(http.path ^= "/abc")]])) 40 | 41 | local c = context.new(s) 42 | c:add_value("http.path", "/abc/d") 43 | 44 | local matched = r:execute(c) 45 | ngx.say(matched) 46 | 47 | c:reset() 48 | 49 | c:add_value("http.path", "/abb/d") 50 | 51 | local matched = r:execute(c) 52 | ngx.say(matched) 53 | 54 | assert(r:remove_matcher("a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c")) 55 | assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", 56 | [[!(http.path =^ "/")]])) 57 | 58 | c:reset() 59 | 60 | c:add_value("http.path", "/abb/d/") 61 | local matched = 
/// Builds a UUID-shaped string whose last group is `a`.
///
/// The first four groups are fixed; `a` is written in decimal and left-padded
/// with zeros to twelve digits, giving each matcher its own identifier.
fn make_uuid(a: usize) -> String {
    format_args!("8cb2a7d0-c775-4ed9-989f-{:012}", a).to_string()
}
b.iter(|| { 48 | let is_match = router.execute(&mut ctx); 49 | assert!(is_match); 50 | }); 51 | }); 52 | 53 | ctx.reset(); 54 | 55 | // not match benchmark 56 | ctx.add_value("http.path", Value::String("hello49999".to_string())); 57 | ctx.add_value("http.version", Value::String("1.1".to_string())); 58 | ctx.add_value("a", Value::Int(5_i64)); // not match 59 | 60 | c.bench_function("Doesn't Match", |b| { 61 | b.iter(|| { 62 | let not_match = !router.execute(&mut ctx); 63 | assert!(not_match); 64 | }); 65 | }); 66 | } 67 | 68 | criterion_group!(benches, criterion_benchmark); 69 | criterion_main!(benches); 70 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | pull_request: {} 5 | workflow_dispatch: {} 6 | push: 7 | branches: 8 | - main 9 | 10 | concurrency: 11 | group: ${{ github.workflow }}-${{ github.ref }} 12 | cancel-in-progress: ${{ github.event_name == 'pull_request' }} 13 | 14 | jobs: 15 | lua-check: 16 | name: Lua Check 17 | runs-on: ubuntu-latest 18 | permissions: 19 | contents: read 20 | issues: read 21 | checks: write 22 | pull-requests: write 23 | if: (github.actor != 'dependabot[bot]') 24 | 25 | steps: 26 | - name: Checkout source code 27 | uses: actions/checkout@v4 28 | 29 | # Optional step to run on only changed files 30 | - name: Get changed files 31 | id: changed-files 32 | uses: kong/changed-files@4edd678ac3f81e2dc578756871e4d00c19191daf 33 | with: 34 | files: | 35 | **.lua 36 | 37 | - name: Lua Check 38 | if: steps.changed-files.outputs.any_changed == 'true' 39 | uses: Kong/public-shared-actions/code-check-actions/lua-lint@0ccacffed804d85da3f938a1b78c12831935f992 # v2 40 | with: 41 | additional_args: '--no-default-config --config .luacheckrc' 42 | files: ${{ steps.changed-files.outputs.all_changed_files }} 43 | 44 | rust-fmt: 45 | name: Rust Fmt 46 | runs-on: ubuntu-latest 47 | 
if: (github.actor != 'dependabot[bot]') 48 | 49 | steps: 50 | - name: Checkout source code 51 | uses: actions/checkout@v4 52 | 53 | - name: Run Rust Fmt 54 | run: cargo fmt --all -- --check # only check, don't format 55 | 56 | rust-clippy: 57 | name: Rust Clippy 58 | runs-on: ubuntu-latest 59 | 60 | permissions: 61 | # required for all workflows 62 | security-events: write 63 | checks: write 64 | pull-requests: write 65 | # only required for workflows in private repositories 66 | actions: read 67 | contents: read 68 | 69 | if: (github.actor != 'dependabot[bot]') 70 | 71 | steps: 72 | - name: Checkout source code 73 | uses: actions/checkout@v4 74 | 75 | # Optional step to run on only changed files 76 | - name: Get changed files 77 | id: changed-files 78 | uses: kong/changed-files@4edd678ac3f81e2dc578756871e4d00c19191daf 79 | with: 80 | files: | 81 | **.rs 82 | 83 | - name: Rust Clippy 84 | if: steps.changed-files.outputs.any_changed == 'true' 85 | uses: Kong/public-shared-actions/code-check-actions/rust-lint@0ccacffed804d85da3f938a1b78c12831935f992 # v2 86 | with: 87 | token: ${{ secrets.GITHUB_TOKEN }} 88 | 89 | -------------------------------------------------------------------------------- /t/03-contains.t: -------------------------------------------------------------------------------- 1 | # vim:set ft= ts=4 sw=4 et: 2 | 3 | use Test::Nginx::Socket::Lua; 4 | use Cwd qw(cwd); 5 | 6 | repeat_each(2); 7 | 8 | plan tests => repeat_each() * blocks() * 5; 9 | 10 | my $pwd = cwd(); 11 | 12 | our $HttpConfig = qq{ 13 | lua_package_path "$pwd/lib/?.lua;;"; 14 | lua_package_cpath "$pwd/target/debug/?.so;;"; 15 | }; 16 | 17 | no_long_string(); 18 | no_diff(); 19 | 20 | run_tests(); 21 | 22 | __DATA__ 23 | 24 | === TEST 1: contains operator 25 | --- http_config eval: $::HttpConfig 26 | --- config 27 | location = /t { 28 | content_by_lua_block { 29 | local schema = require("resty.router.schema") 30 | local router = require("resty.router.router") 31 | local context = 
require("resty.router.context") 32 | 33 | local s = schema.new() 34 | 35 | s:add_field("http.path", "String") 36 | s:add_field("tcp.port", "Int") 37 | 38 | local r = router.new(s) 39 | assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", 40 | "http.path contains \"keyword\" && tcp.port == 80")) 41 | 42 | local c = context.new(s) 43 | c:add_value("http.path", "/foo/keyword/bar") 44 | c:add_value("tcp.port", 80) 45 | 46 | local matched = r:execute(c) 47 | ngx.say(matched) 48 | 49 | local uuid, matched_value = c:get_result("http.path") 50 | ngx.say(uuid) 51 | ngx.say(matched_value) 52 | } 53 | } 54 | --- request 55 | GET /t 56 | --- response_body 57 | true 58 | a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c 59 | nil 60 | --- no_error_log 61 | [error] 62 | [warn] 63 | [crit] 64 | 65 | 66 | 67 | 68 | === TEST 2: contains operator should mismatch 69 | --- http_config eval: $::HttpConfig 70 | --- config 71 | location = /t { 72 | content_by_lua_block { 73 | local schema = require("resty.router.schema") 74 | local router = require("resty.router.router") 75 | local context = require("resty.router.context") 76 | 77 | local s = schema.new() 78 | 79 | s:add_field("http.path", "String") 80 | s:add_field("tcp.port", "Int") 81 | 82 | local r = router.new(s) 83 | assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", 84 | "http.path contains \"keyword\" && tcp.port == 80")) 85 | 86 | local c = context.new(s) 87 | c:add_value("http.path", "/foo/bar") 88 | c:add_value("tcp.port", 80) 89 | 90 | local matched = r:execute(c) 91 | ngx.say(matched) 92 | 93 | local uuid, matched_value = c:get_result("http.path") 94 | ngx.say(uuid) 95 | ngx.say(matched_value) 96 | } 97 | } 98 | --- request 99 | GET /t 100 | --- response_body 101 | false 102 | nil 103 | nil 104 | --- no_error_log 105 | [error] 106 | [warn] 107 | [crit] 108 | 109 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: 
-------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | pull_request: {} 5 | push: 6 | branches: 7 | - main 8 | 9 | # cancel previous runs if new commits are pushed to the PR, but run for each commit on master 10 | concurrency: 11 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 12 | cancel-in-progress: true 13 | 14 | jobs: 15 | tests: 16 | name: Tests 17 | runs-on: ubuntu-22.04 18 | 19 | strategy: 20 | matrix: 21 | openresty: 22 | - '1.27.1.2' 23 | - '1.27.1.1' 24 | - '1.25.3.2' 25 | - '1.25.3.1' 26 | - '1.21.4.4' 27 | - '1.21.4.3' 28 | - '1.21.4.2' 29 | - '1.21.4.1' 30 | env: 31 | JOBS: 1 32 | 33 | OPENRESTY: ${{ matrix.openresty }} 34 | CODE_PATH: ${{ github.workspace }} 35 | BASE_PATH: /home/runner/work/cache 36 | 37 | steps: 38 | - name: Checkout source code 39 | uses: actions/checkout@v4 40 | with: 41 | submodules: recursive 42 | token: ${{ secrets.GHA_KONG_BOT_READ_TOKEN }} 43 | 44 | - name: Make sure Cargo can clone private repositories 45 | run: | 46 | git config --global url."https://${{ secrets.GHA_KONG_BOT_READ_TOKEN }}@github.com".insteadOf https://github.com 47 | 48 | - name: Setup cache 49 | uses: actions/cache@v4 50 | id: cache-deps 51 | with: 52 | path: | 53 | ${{ env.BASE_PATH }} 54 | key: ${{ runner.os }}-${{ hashFiles('Makefile') }}-${{ hashFiles('**/tests.yml') }}-openresty-${{ matrix.openresty }} 55 | 56 | - name: Install packages 57 | run: | 58 | sudo apt update 59 | sudo apt-get install -qq -y wget cpanminus net-tools libpcre3-dev build-essential valgrind 60 | if [ ! 
-e perl ]; then sudo cpanm --notest Test::Nginx > build.log 2>&1 || (cat build.log && exit 1); cp -r /usr/local/share/perl/ .; else sudo cp -r perl /usr/local/share; fi 61 | 62 | - name: Download OpenResty 63 | if: steps.cache-deps.outputs.cache-hit != 'true' 64 | run: | 65 | wget https://openresty.org/download/openresty-${OPENRESTY}.tar.gz 66 | mkdir -p ${BASE_PATH} 67 | tar xfz openresty-${OPENRESTY}.tar.gz -C ${BASE_PATH} 68 | 69 | - name: Setup tools 70 | if: steps.cache-deps.outputs.cache-hit != 'true' 71 | run: | 72 | cd ${BASE_PATH}/openresty-${OPENRESTY} 73 | ./configure --prefix=${BASE_PATH}/openresty --with-debug 74 | sudo make -j$(nproc) && make install -j$(nproc) 75 | 76 | - name: Run Test 77 | run: | 78 | export PATH=${BASE_PATH}/openresty/bin:$PATH 79 | openresty -V 80 | make test OPENRESTY_PREFIX=${BASE_PATH}/openresty 81 | 82 | - name: Run Valgrind 83 | run: | 84 | export PATH=${BASE_PATH}/openresty/bin:$PATH 85 | export TEST_NGINX_VALGRIND='--num-callers=100 -q --tool=memcheck --leak-check=full --show-possibly-lost=no --gen-suppressions=all --suppressions=valgrind.suppress --track-origins=yes' TEST_NGINX_TIMEOUT=120 TEST_NGINX_SLEEP=1 86 | export TEST_NGINX_USE_VALGRIND=1 87 | openresty -V 88 | make valgrind OPENRESTY_PREFIX=${BASE_PATH}/openresty 89 | -------------------------------------------------------------------------------- /t/07-in_notin.t: -------------------------------------------------------------------------------- 1 | # vim:set ft= ts=4 sw=4 et: 2 | 3 | use Test::Nginx::Socket::Lua; 4 | use Cwd qw(cwd); 5 | 6 | repeat_each(2); 7 | 8 | plan tests => repeat_each() * blocks() * 5; 9 | 10 | my $pwd = cwd(); 11 | 12 | our $HttpConfig = qq{ 13 | lua_package_path "$pwd/lib/?.lua;;"; 14 | lua_package_cpath "$pwd/target/debug/?.so;;"; 15 | }; 16 | 17 | no_long_string(); 18 | no_diff(); 19 | 20 | run_tests(); 21 | 22 | __DATA__ 23 | 24 | === TEST 1: in operator has correct type check 25 | --- http_config eval: $::HttpConfig 26 | --- config 27 
| location = /t { 28 | content_by_lua_block { 29 | local schema = require("resty.router.schema") 30 | local router = require("resty.router.router") 31 | local context = require("resty.router.context") 32 | 33 | local s = schema.new() 34 | 35 | s:add_field("http.path", "String") 36 | s:add_field("tcp.port", "Int") 37 | 38 | local r = router.new(s) 39 | ngx.say(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", 40 | "tcp.port in 80")) 41 | 42 | ngx.say(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", 43 | "http.path in 80")) 44 | 45 | ngx.say(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", 46 | "http.path in \"foo\"")) 47 | } 48 | } 49 | --- request 50 | GET /t 51 | --- response_body 52 | nilIn/NotIn operators only supports IP in CIDR 53 | nilIn/NotIn operators only supports IP in CIDR 54 | nilIn/NotIn operators only supports IP in CIDR 55 | --- no_error_log 56 | [error] 57 | [warn] 58 | [crit] 59 | 60 | 61 | 62 | === TEST 2: in operator works with IPAddr and IpCidr operands 63 | --- http_config eval: $::HttpConfig 64 | --- config 65 | location = /t { 66 | content_by_lua_block { 67 | local schema = require("resty.router.schema") 68 | local router = require("resty.router.router") 69 | local context = require("resty.router.context") 70 | 71 | local s = schema.new() 72 | 73 | s:add_field("l3.ip", "IpAddr") 74 | 75 | local r = router.new(s) 76 | assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", 77 | "l3.ip in 192.168.12.0/24")) 78 | 79 | local c = context.new(s) 80 | c:add_value("l3.ip", "192.168.12.1") 81 | 82 | local matched = r:execute(c) 83 | ngx.say(matched) 84 | 85 | c = context.new(s) 86 | c:add_value("l3.ip", "192.168.1.1") 87 | 88 | local matched = r:execute(c) 89 | ngx.say(matched) 90 | 91 | assert(r:remove_matcher("a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c")) 92 | assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", 93 | "l3.ip not in 192.168.12.0/24")) 94 | local matched = r:execute(c) 95 | ngx.say(matched) 96 | } 97 
--- LuaJIT FFI binding around the native ATC `Router` object.
-- A router is created from a schema, holds a set of prioritised matchers
-- (ATC expressions identified by UUID) and can be executed against a context.
local _M = {}
local _MT = { __index = _M, }


local ffi = require("ffi")
local base = require("resty.core.base")
local cdefs = require("resty.router.cdefs")
local tb_new = require("table.new")
local get_string_buf = base.get_string_buf
local get_size_ptr = base.get_size_ptr
local ffi_string = ffi.string
local ffi_new = ffi.new
local ffi_gc = ffi.gc
local assert = assert
local tonumber = tonumber
local setmetatable = setmetatable


local ERR_BUF_MAX_LEN = cdefs.ERR_BUF_MAX_LEN
local clib = cdefs.clib
local router_free = cdefs.router_free


--- Creates a new router bound to `schema`.
-- @param schema   schema object; its `.schema` handle is handed to `router_new`
-- @param routes_n optional size hint for the Lua-side uuid -> priority table
--                 (defaults to 10); it only pre-sizes that table
-- @return the new router object
function _M.new(schema, routes_n)
    local router = clib.router_new(schema.schema)
    -- Note on this weird looking finalizer:
    --
    -- You may be tempted to change it to ffi_gc(router, clib.router_free)
    -- This isn't 100% safe, particularly with `busted` clearing the global
    -- environment between runs. `clib` could be GC'ed before this entity,
    -- causing instruction fetch faults because the `router` finalizer will
    -- attempt to execute from an unmapped memory region
    local r = setmetatable({
        router = ffi_gc(router, router_free),
        -- kept so that execute() can check the context was built from the
        -- same schema object
        schema = schema,
        -- uuid -> priority; remove_matcher() needs the priority to address
        -- the matcher on the native side
        priorities = tb_new(0, routes_n or 10),
    }, _MT)

    return r
end


--- Adds a matcher to the router.
-- @param priority numeric priority passed to the native router
-- @param uuid     UUID string identifying the matcher
-- @param atc      ATC expression source
-- @return `true` on success, or `nil, err` when the native call reports failure
function _M:add_matcher(priority, uuid, atc)
    local errbuf = get_string_buf(ERR_BUF_MAX_LEN)
    local errbuf_len = get_size_ptr()
    -- in: capacity of errbuf; out: length of the error message written
    errbuf_len[0] = ERR_BUF_MAX_LEN

    if clib.router_add_matcher(self.router, priority, uuid, atc, errbuf, errbuf_len) == false then
        return nil, ffi_string(errbuf, errbuf_len[0])
    end

    self.priorities[uuid] = priority

    return true
end


--- Removes a matcher previously added through this object.
-- @param uuid UUID string of the matcher
-- @return `false` when the UUID was never recorded by add_matcher(),
--         otherwise whether the native removal succeeded
function _M:remove_matcher(uuid)
    local priority = self.priorities[uuid]
    if not priority then
        return false
    end

    self.priorities[uuid] = nil

    return clib.router_remove_matcher(self.router, priority, uuid) == true
end


--- Executes the router against `context`.
-- Raises (via assert) when the context was created from a different schema
-- object than this router.
-- @param context context object created with the same schema
-- @return boolean result of `router_execute`
function _M:execute(context)
    assert(context.schema == self.schema)
    return clib.router_execute(self.router, context.context) == true
end


--- Returns the field names reported by `router_get_fields` as an array of
-- Lua strings (empty table when there are none).
function _M:get_fields()
    local router = self.router

    -- First call with NULL buffers only asks for the number of fields.
    local total = tonumber(clib.router_get_fields(router, nil, nil))
    if total == 0 then
        return {}
    end

    local fields = ffi_new("const uint8_t *[?]", total)
    local fields_len = ffi_new("size_t [?]", total)
    -- presumably read by the native side as the capacity of both arrays
    -- before it overwrites the slot with the first field's length
    fields_len[0] = total

    clib.router_get_fields(router, fields, fields_len)

    -- the final size is known, so allocate the array part once
    local out = tb_new(total, 0)
    for i = 0, total - 1 do
        out[i + 1] = ffi_string(fields[i], fields_len[i])
    end

    return out
end


do
    -- One scratch router per schema, reused across validate() calls.
    -- NOTE(review): every cached router keeps a strong reference to its
    -- schema (the key of this table), so the weak-key mode may never release
    -- an entry on runtimes without ephemeron semantics -- confirm.
    local ROUTERS = setmetatable({}, { __mode = "k" })
    local DEFAULT_UUID = "00000000-0000-0000-0000-000000000000"
    local DEFAULT_PRIORITY = 0

    --- Validates an expression against a schema.
    -- The expression is temporarily added to a scratch router for the schema
    -- and removed again before returning.
    -- @param schema the schema to validate against
    -- @param expr the expression to validate
    -- @return the array returned by get_fields() while the expression was
    --         loaded if it is valid, `nil, err` otherwise
    function _M.validate(schema, expr)
        local r = ROUTERS[schema]

        if not r then
            r = _M.new(schema, 1)
            ROUTERS[schema] = r
        end

        local ok, err = r:add_matcher(DEFAULT_PRIORITY, DEFAULT_UUID, expr)
        if not ok then
            return nil, err
        end

        local fields = r:get_fields()

        -- leave the scratch router empty for the next validation
        assert(r:remove_matcher(DEFAULT_UUID))

        return fields
    end
end


return _M
r:execute(c) 47 | ngx.say(matched) 48 | 49 | local uuid, prefix = c:get_result("http.headers.foo") 50 | ngx.say(uuid) 51 | ngx.say(prefix) 52 | } 53 | } 54 | --- request 55 | GET /t 56 | --- response_body 57 | true 58 | a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c 59 | bar 60 | --- no_error_log 61 | [error] 62 | [warn] 63 | [crit] 64 | 65 | 66 | 67 | 68 | === TEST 2: multi value field expect mismatch 69 | --- http_config eval: $::HttpConfig 70 | --- config 71 | location = /t { 72 | content_by_lua_block { 73 | local schema = require("resty.router.schema") 74 | local router = require("resty.router.router") 75 | local context = require("resty.router.context") 76 | 77 | local s = schema.new() 78 | 79 | s:add_field("http.headers.foo", "String") 80 | 81 | local r = router.new(s) 82 | assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", 83 | "http.headers.foo == \"bar\"")) 84 | 85 | local c = context.new(s) 86 | c:add_value("http.headers.foo", "bar") 87 | c:add_value("http.headers.foo", "bar") 88 | c:add_value("http.headers.foo", "barX") 89 | 90 | local matched = r:execute(c) 91 | ngx.say(matched) 92 | 93 | local uuid, prefix = c:get_result("http.headers.foo") 94 | ngx.say(uuid) 95 | ngx.say(prefix) 96 | } 97 | } 98 | --- request 99 | GET /t 100 | --- response_body 101 | false 102 | nil 103 | nil 104 | --- no_error_log 105 | [error] 106 | [warn] 107 | [crit] 108 | 109 | 110 | === TEST 3: empty value 111 | --- http_config eval: $::HttpConfig 112 | --- config 113 | location = /t { 114 | content_by_lua_block { 115 | local schema = require("resty.router.schema") 116 | local router = require("resty.router.router") 117 | local context = require("resty.router.context") 118 | 119 | local s = schema.new() 120 | 121 | s:add_field("http.headers.foo", "String") 122 | 123 | local r = router.new(s) 124 | assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", 125 | "http.headers.foo == \"bar\"")) 126 | 127 | local c = context.new(s) 128 | 129 | local matched = r:execute(c) 
--- LuaJIT FFI binding around the native ATC `Context` object.
-- A context carries the field values a router is executed against and, after
-- execution, the match result.
local _M = {}
local _MT = { __index = _M, }


local ffi = require("ffi")
local base = require("resty.core.base")
local cdefs = require("resty.router.cdefs")


local ffi_new = ffi.new
local ffi_gc = ffi.gc
local get_string_buf = base.get_string_buf
local get_size_ptr = base.get_size_ptr
local ffi_string = ffi.string
local tonumber = tonumber
local tostring = tostring
local setmetatable = setmetatable
local new_tab = require("table.new")
local C = ffi.C


local UUID_LEN = 36 -- hexadecimal representation of UUID
-- Module-level scratch objects reused by every call to avoid allocating
-- cdata per request.
local CACHED_VALUE = ffi_new("CValue[1]")
local UUID_BUF = ffi_new("uint8_t[?]", UUID_LEN)
local ERR_BUF_MAX_LEN = cdefs.ERR_BUF_MAX_LEN
local clib = cdefs.clib
local context_free = cdefs.context_free


--- Creates a new context for `schema`.
-- @param schema schema object; its `.schema` handle is handed to `context_new`
-- @return the new context object
function _M.new(schema)
    local context = clib.context_new(schema.schema)
    local c = setmetatable({
        -- finalizer goes through the cdefs wrapper rather than clib directly
        context = ffi_gc(context, context_free),
        schema = schema,
    }, _MT)

    return c
end


--- Adds a value for `field` to the context.
-- @param field field name as declared in the schema
-- @param value Lua value for the field; `nil`/`false` is a no-op
-- @return `true` on success (or when there was nothing to add), `nil, err`
--         when the field type cannot be resolved, is not supported here, or
--         the native call reports failure
function _M:add_value(field, value)
    if not value then
        return true
    end

    local typ, err = self.schema:get_field_type(field)
    if not typ then
        return nil, err
    end

    if typ == "String" then
        CACHED_VALUE[0].tag = C.CValue_Str
        CACHED_VALUE[0].str._0 = value
        -- explicit length: the string may contain embedded NUL bytes
        CACHED_VALUE[0].str._1 = #value

    elseif typ == "IpAddr" then
        CACHED_VALUE[0].tag = C.CValue_IpAddr
        CACHED_VALUE[0].ip_addr = value

    elseif typ == "Int" then
        CACHED_VALUE[0].tag = C.CValue_Int
        CACHED_VALUE[0].int_ = value

    else
        -- CACHED_VALUE is shared between calls: without this branch the tag
        -- and payload left by a previous call (or a zeroed Str with a NULL
        -- pointer on first use) would be submitted for this field.
        return nil, "unsupported field type: " .. tostring(typ)
    end

    local errbuf = get_string_buf(ERR_BUF_MAX_LEN)
    local errbuf_len = get_size_ptr()
    -- in: capacity of errbuf; out: length of the error message written
    errbuf_len[0] = ERR_BUF_MAX_LEN

    if clib.context_add_value(self.context, field, CACHED_VALUE, errbuf, errbuf_len) == false then
        return nil, ffi_string(errbuf, errbuf_len[0])
    end

    return true
end


--- Fetches the result stored in the context by the last router execution.
-- @param matched_field optional field name whose matched value is wanted
-- @return `nil` when the native side reports -1 (presumably: no match);
--         otherwise `uuid, matched_value, captures` where `matched_value` is
--         `nil` if no field was requested or the reported length is 0, and
--         `captures` is `nil` when there are no captures. Capture names that
--         parse as base-10 numbers are stored under numeric keys.
function _M:get_result(matched_field)
    -- Probe call with NULL buffers: only the capture count is returned.
    local captures_len = tonumber(clib.context_get_result(
        self.context, nil, nil, nil, nil, nil, nil, nil, nil))
    if captures_len == -1 then
        return nil
    end

    local matched_value_buf, matched_value_len
    if matched_field then
        matched_value_buf = ffi_new("const uint8_t *[1]")
        matched_value_len = ffi_new("size_t [1]")
    end

    local capture_names, capture_names_len, capture_values, capture_values_len
    if captures_len > 0 then
        capture_names = ffi_new("const uint8_t *[?]", captures_len)
        capture_names_len = ffi_new("size_t [?]", captures_len)
        capture_values = ffi_new("const uint8_t *[?]", captures_len)
        capture_values_len = ffi_new("size_t [?]", captures_len)

        -- slot 0 carries the array capacity into the native call
        capture_names_len[0] = captures_len
        capture_values_len[0] = captures_len
    end

    clib.context_get_result(self.context, UUID_BUF, matched_field,
                            matched_value_buf, matched_value_len,
                            capture_names, capture_names_len, capture_values,
                            capture_values_len)

    -- ffi_string copies, so reusing the shared UUID_BUF afterwards is safe
    local uuid = ffi_string(UUID_BUF, UUID_LEN)
    local matched_value
    if matched_field then
        matched_value = matched_value_len[0] > 0 and
                        ffi_string(matched_value_buf[0], matched_value_len[0]) or
                        nil
    end

    local captures

    if captures_len > 0 then
        captures = new_tab(0, captures_len)

        for i = 0, captures_len - 1 do
            local name = ffi_string(capture_names[i], capture_names_len[i])
            local value = ffi_string(capture_values[i], capture_values_len[i])

            -- positional captures are exposed under numeric keys
            local num = tonumber(name, 10)
            if num then
                name = num
            end

            captures[name] = value
        end
    end

    return uuid, matched_value, captures
end


--- Resets the native context via `context_reset` so the object can be reused.
function _M:reset()
    clib.context_reset(self.context)
end

return _M
-- From: https://github.com/openresty/lua-resty-signal/blob/master/lib/resty/signal.lua
local load_shared_lib
do
    local string_gmatch = string.gmatch
    local string_match = string.match
    local io_open = io.open
    local io_close = io.close
    local table_new = require("table.new")

    -- captured once at module load time
    local cpath = package.cpath

    --- Looks for `so_name` in the directory of every `package.cpath` entry
    -- and loads the first file that can be opened.
    -- @param so_name file name of the shared library (e.g. "libfoo.so")
    -- @return the `ffi.load` result for the first existing candidate, or
    --         `nil, tried_paths` (array of the candidate paths probed)
    function load_shared_lib(so_name)
        local tried_paths = table_new(32, 0)
        local i = 1

        for entry in string_gmatch(cpath, "[^;]+") do
            -- Directory part of the entry, up to and including the last "/".
            -- An entry without any "/" (e.g. "?.so") has no directory part;
            -- use an empty prefix so the candidate is the bare file name
            -- instead of one prefixed with the literal string "nil".
            local dir = string_match(entry, "(.*/)") or ""
            local fpath = dir .. so_name
            -- Don't get me wrong, the only way to know if a file exist is
            -- trying to open it.
            local f = io_open(fpath)
            if f ~= nil then
                io_close(f)
                return ffi.load(fpath)
            end

            tried_paths[i] = fpath
            i = i + 1
        end

        return nil, tried_paths
    end -- function
end -- do
ipairs(BAD_UTF8) do 44 | local ok, err = c:add_value("http.path", v) 45 | ngx.say(err) 46 | end 47 | } 48 | } 49 | --- request 50 | GET /t 51 | --- response_body 52 | invalid utf-8 sequence of 1 bytes from index 0 53 | invalid utf-8 sequence of 1 bytes from index 0 54 | invalid utf-8 sequence of 1 bytes from index 0 55 | --- no_error_log 56 | [error] 57 | [warn] 58 | [crit] 59 | 60 | 61 | 62 | === TEST 2: NULL bytes does not cause UTF-8 issues (it is valid UTF-8) 63 | --- http_config eval: $::HttpConfig 64 | --- config 65 | location = /t { 66 | content_by_lua_block { 67 | local schema = require("resty.router.schema") 68 | local context = require("resty.router.context") 69 | 70 | local s = schema.new() 71 | 72 | s:add_field("http.path", "String") 73 | 74 | local c = context.new(s) 75 | assert(c:add_value("http.path", "\x00")) 76 | ngx.say("ok") 77 | } 78 | } 79 | --- request 80 | GET /t 81 | --- response_body 82 | ok 83 | --- no_error_log 84 | [error] 85 | [warn] 86 | [crit] 87 | 88 | 89 | 90 | === TEST 3: long strings don't cause a panic when parsing fails 91 | --- http_config eval: $::HttpConfig 92 | --- config 93 | location = /t { 94 | content_by_lua_block { 95 | local schema = require("resty.router.schema") 96 | local router = require("resty.router.router") 97 | 98 | local s = schema.new() 99 | s:add_field("http.path", "String") 100 | 101 | local r = router.new(s) 102 | local uuid = "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c" 103 | 104 | for _, len in ipairs({ 105 | 128, 106 | 256, 107 | 512, 108 | 1024, 109 | 2048, 110 | 4096, 111 | }) do 112 | local input = string.rep("a", len) 113 | local ok, err = r:add_matcher(0, uuid, input) 114 | assert(not ok, "expected add_matcher() to fail") 115 | assert(type(err) == "string", "expected an error string") 116 | end 117 | 118 | ngx.say("ok") 119 | } 120 | } 121 | --- request 122 | GET /t 123 | --- response_body 124 | ok 125 | --- no_error_log 126 | [error] 127 | [warn] 128 | [crit] 129 | 130 | 131 | 132 | === TEST 4: able to 
parse and handle string with NULL bytes inside 133 | --- http_config eval: $::HttpConfig 134 | --- config 135 | location = /t { 136 | content_by_lua_block { 137 | local schema = require("resty.router.schema") 138 | local router = require("resty.router.router") 139 | local context = require("resty.router.context") 140 | 141 | local s = schema.new() 142 | 143 | s:add_field("http.body", "String") 144 | 145 | local r = router.new(s) 146 | assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", 147 | "http.body =^ \"world\"")) 148 | 149 | local c = context.new(s) 150 | c:add_value("http.body", "hello\x00world") 151 | 152 | local matched = r:execute(c) 153 | ngx.say(matched) 154 | 155 | local uuid = c:get_result("http.body") 156 | ngx.say(uuid) 157 | 158 | c:reset() 159 | c:add_value("http.body", "world\x00hello") 160 | 161 | local matched = r:execute(c) 162 | ngx.say(matched) 163 | } 164 | } 165 | --- request 166 | GET /t 167 | --- response_body 168 | true 169 | a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c 170 | false 171 | --- no_error_log 172 | [error] 173 | [warn] 174 | [crit] 175 | -------------------------------------------------------------------------------- /t/08-equals.t: -------------------------------------------------------------------------------- 1 | # vim:set ft= ts=4 sw=4 et: 2 | 3 | use Test::Nginx::Socket::Lua; 4 | use Cwd qw(cwd); 5 | 6 | repeat_each(2); 7 | 8 | plan tests => repeat_each() * blocks() * 5; 9 | 10 | my $pwd = cwd(); 11 | 12 | our $HttpConfig = qq{ 13 | lua_package_path "$pwd/lib/?.lua;;"; 14 | lua_package_cpath "$pwd/target/debug/?.so;;"; 15 | }; 16 | 17 | no_long_string(); 18 | no_diff(); 19 | 20 | run_tests(); 21 | 22 | __DATA__ 23 | 24 | === TEST 1: Equals/NotEquals works Int 25 | --- http_config eval: $::HttpConfig 26 | --- config 27 | location = /t { 28 | content_by_lua_block { 29 | local schema = require("resty.router.schema") 30 | local router = require("resty.router.router") 31 | local context = require("resty.router.context") 
32 | 33 | local s = schema.new() 34 | 35 | s:add_field("net.port", "Int") 36 | 37 | local r = router.new(s) 38 | assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", 39 | "net.port == 8000")) 40 | assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-8aa5583d150c", 41 | "net.port != 8000")) 42 | 43 | local c = context.new(s) 44 | c:add_value("net.port", 8000) 45 | 46 | local matched = r:execute(c) 47 | ngx.say(matched) 48 | ngx.say(c:get_result()) 49 | 50 | c = context.new(s) 51 | c:add_value("net.port", 8001) 52 | 53 | matched = r:execute(c) 54 | ngx.say(matched) 55 | ngx.say(c:get_result()) 56 | } 57 | } 58 | --- request 59 | GET /t 60 | --- response_body 61 | true 62 | a921a9aa-ec0e-4cf3-a6cc-1aa5583d150cnilnil 63 | true 64 | a921a9aa-ec0e-4cf3-a6cc-8aa5583d150cnilnil 65 | --- no_error_log 66 | [error] 67 | [warn] 68 | [crit] 69 | 70 | 71 | 72 | === TEST 2: Equals/NotEquals works String 73 | --- http_config eval: $::HttpConfig 74 | --- config 75 | location = /t { 76 | content_by_lua_block { 77 | local schema = require("resty.router.schema") 78 | local router = require("resty.router.router") 79 | local context = require("resty.router.context") 80 | 81 | local s = schema.new() 82 | 83 | s:add_field("http.path", "String") 84 | 85 | local r = router.new(s) 86 | assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", 87 | "http.path == \"/foo\"")) 88 | assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-8aa5583d150c", 89 | "http.path != \"/foo\"")) 90 | 91 | local c = context.new(s) 92 | c:add_value("http.path", "/foo") 93 | 94 | local matched = r:execute(c) 95 | ngx.say(matched) 96 | ngx.say(c:get_result()) 97 | 98 | c = context.new(s) 99 | c:add_value("http.path", "/foo1") 100 | 101 | matched = r:execute(c) 102 | ngx.say(matched) 103 | ngx.say(c:get_result()) 104 | } 105 | } 106 | --- request 107 | GET /t 108 | --- response_body 109 | true 110 | a921a9aa-ec0e-4cf3-a6cc-1aa5583d150cnilnil 111 | true 112 | a921a9aa-ec0e-4cf3-a6cc-8aa5583d150cnilnil 113 | 
--- no_error_log 114 | [error] 115 | [warn] 116 | [crit] 117 | 118 | 119 | 120 | === TEST 3: Equals/NotEquals works IpAddr 121 | --- http_config eval: $::HttpConfig 122 | --- config 123 | location = /t { 124 | content_by_lua_block { 125 | local schema = require("resty.router.schema") 126 | local router = require("resty.router.router") 127 | local context = require("resty.router.context") 128 | 129 | local s = schema.new() 130 | 131 | s:add_field("net.ip", "IpAddr") 132 | 133 | local r = router.new(s) 134 | assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", 135 | "net.ip == 192.168.1.1")) 136 | assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-8aa5583d150c", 137 | "net.ip != 192.168.1.1")) 138 | 139 | local c = context.new(s) 140 | c:add_value("net.ip", "192.168.1.1") 141 | 142 | local matched = r:execute(c) 143 | ngx.say(matched) 144 | ngx.say(c:get_result()) 145 | 146 | c = context.new(s) 147 | c:add_value("net.ip", "192.168.1.2") 148 | 149 | matched = r:execute(c) 150 | ngx.say(matched) 151 | ngx.say(c:get_result()) 152 | } 153 | } 154 | --- request 155 | GET /t 156 | --- response_body 157 | true 158 | a921a9aa-ec0e-4cf3-a6cc-1aa5583d150cnilnil 159 | true 160 | a921a9aa-ec0e-4cf3-a6cc-8aa5583d150cnilnil 161 | --- no_error_log 162 | [error] 163 | [warn] 164 | [crit] 165 | -------------------------------------------------------------------------------- /t/04-rawstr.t: -------------------------------------------------------------------------------- 1 | # vim:set ft= ts=4 sw=4 et: 2 | 3 | use Test::Nginx::Socket::Lua; 4 | use Cwd qw(cwd); 5 | 6 | repeat_each(2); 7 | 8 | plan tests => repeat_each() * blocks() * 5; 9 | 10 | my $pwd = cwd(); 11 | 12 | our $HttpConfig = qq{ 13 | lua_package_path "$pwd/lib/?.lua;;"; 14 | lua_package_cpath "$pwd/target/debug/?.so;;"; 15 | }; 16 | 17 | no_long_string(); 18 | no_diff(); 19 | 20 | run_tests(); 21 | 22 | __DATA__ 23 | 24 | === TEST 1: rawstr 25 | --- http_config eval: $::HttpConfig 26 | --- config 27 | location = 
/t { 28 | content_by_lua_block { 29 | local schema = require("resty.router.schema") 30 | local router = require("resty.router.router") 31 | local context = require("resty.router.context") 32 | 33 | local s = schema.new() 34 | 35 | s:add_field("http.path", "String") 36 | s:add_field("tcp.port", "Int") 37 | 38 | local r = router.new(s) 39 | assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", 40 | "http.path ^= r#\"/foo\"# && tcp.port == 80")) 41 | 42 | local c = context.new(s) 43 | c:add_value("http.path", "/foo/bar") 44 | c:add_value("tcp.port", 80) 45 | 46 | local matched = r:execute(c) 47 | ngx.say(matched) 48 | 49 | local uuid, prefix = c:get_result("http.path") 50 | ngx.say(uuid) 51 | ngx.say(prefix) 52 | } 53 | } 54 | --- request 55 | GET /t 56 | --- response_body 57 | true 58 | a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c 59 | /foo 60 | --- no_error_log 61 | [error] 62 | [warn] 63 | [crit] 64 | 65 | 66 | 67 | === TEST 2: rawstr with quote inside 68 | --- http_config eval: $::HttpConfig 69 | --- config 70 | location = /t { 71 | content_by_lua_block { 72 | local schema = require("resty.router.schema") 73 | local router = require("resty.router.router") 74 | local context = require("resty.router.context") 75 | 76 | local s = schema.new() 77 | 78 | s:add_field("http.path", "String") 79 | s:add_field("tcp.port", "Int") 80 | 81 | local r = router.new(s) 82 | assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", 83 | "http.path ^= r#\"/foo\"\'\"# && tcp.port == 80")) 84 | 85 | local c = context.new(s) 86 | c:add_value("http.path", "/foo\"\'/bar") 87 | c:add_value("tcp.port", 80) 88 | 89 | local matched = r:execute(c) 90 | ngx.say(matched) 91 | 92 | local uuid, prefix = c:get_result("http.path") 93 | ngx.say(uuid) 94 | ngx.say(prefix) 95 | } 96 | } 97 | --- request 98 | GET /t 99 | --- response_body 100 | true 101 | a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c 102 | /foo"' 103 | --- no_error_log 104 | [error] 105 | [warn] 106 | [crit] 107 | 108 | 109 | 110 | 111 
| === TEST 3: rawstr with regex inside 112 | --- http_config eval: $::HttpConfig 113 | --- config 114 | location = /t { 115 | content_by_lua_block { 116 | local schema = require("resty.router.schema") 117 | local router = require("resty.router.router") 118 | local context = require("resty.router.context") 119 | 120 | local s = schema.new() 121 | 122 | s:add_field("http.path", "String") 123 | s:add_field("tcp.port", "Int") 124 | 125 | local r = router.new(s) 126 | assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", 127 | "http.path ~ r#\"^/\\d+/test$\"# && tcp.port == 80")) 128 | 129 | local c = context.new(s) 130 | c:add_value("http.path", "/123/test") 131 | c:add_value("tcp.port", 80) 132 | 133 | local matched = r:execute(c) 134 | ngx.say(matched) 135 | 136 | local uuid, prefix = c:get_result("http.path") 137 | ngx.say(uuid) 138 | ngx.say(prefix) 139 | } 140 | } 141 | --- request 142 | GET /t 143 | --- response_body 144 | true 145 | a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c 146 | /123/test 147 | --- no_error_log 148 | [error] 149 | [warn] 150 | [crit] 151 | 152 | 153 | 154 | === TEST 4: rawstr with regex inside expect mismatch 155 | --- http_config eval: $::HttpConfig 156 | --- config 157 | location = /t { 158 | content_by_lua_block { 159 | local schema = require("resty.router.schema") 160 | local router = require("resty.router.router") 161 | local context = require("resty.router.context") 162 | 163 | local s = schema.new() 164 | 165 | s:add_field("http.path", "String") 166 | s:add_field("tcp.port", "Int") 167 | 168 | local r = router.new(s) 169 | assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", 170 | "http.path ~ r#\"^/\\D+/test$\"# && tcp.port == 80")) 171 | 172 | local c = context.new(s) 173 | c:add_value("http.path", "/123/test") 174 | c:add_value("tcp.port", 80) 175 | 176 | local matched = r:execute(c) 177 | ngx.say(matched) 178 | 179 | local uuid, prefix = c:get_result("http.path") 180 | ngx.say(uuid) 181 | ngx.say(prefix) 182 | } 183 | 
} 184 | --- request 185 | GET /t 186 | --- response_body 187 | false 188 | nil 189 | nil 190 | --- no_error_log 191 | [error] 192 | [warn] 193 | [crit] 194 | 195 | 196 | -------------------------------------------------------------------------------- /t/06-validate.t: -------------------------------------------------------------------------------- 1 | # vim:set ft= ts=4 sw=4 et: 2 | 3 | use Test::Nginx::Socket::Lua; 4 | use Cwd qw(cwd); 5 | 6 | repeat_each(2); 7 | 8 | plan tests => repeat_each() * blocks() * 5; 9 | 10 | my $pwd = cwd(); 11 | 12 | our $HttpConfig = qq{ 13 | lua_package_path "$pwd/lib/?.lua;;"; 14 | lua_package_cpath "$pwd/target/debug/?.so;;"; 15 | }; 16 | 17 | no_long_string(); 18 | no_diff(); 19 | 20 | run_tests(); 21 | 22 | __DATA__ 23 | 24 | === TEST 1: test valid schema + expr 25 | --- http_config eval: $::HttpConfig 26 | --- config 27 | location = /t { 28 | content_by_lua_block { 29 | local schema = require("resty.router.schema") 30 | local router = require("resty.router.router") 31 | 32 | local s = schema.new() 33 | s:add_field("http.headers.foo", "String") 34 | 35 | local expr = "http.headers.foo == \"bar\"" 36 | local r, err = router.validate(s, expr) 37 | 38 | ngx.say(type(r)) 39 | ngx.say(err) 40 | ngx.say(#r) 41 | ngx.say(r[1]) 42 | } 43 | } 44 | --- request 45 | GET /t 46 | --- response_body 47 | table 48 | nil 49 | 1 50 | http.headers.foo 51 | --- no_error_log 52 | [error] 53 | [warn] 54 | [crit] 55 | 56 | 57 | === TEST 2: test type inconsistency (schema is String, expr is Int) 58 | --- http_config eval: $::HttpConfig 59 | --- config 60 | location = /t { 61 | content_by_lua_block { 62 | local schema = require("resty.router.schema") 63 | local router = require("resty.router.router") 64 | 65 | local s = schema.new() 66 | s:add_field("http.headers.foo", "String") 67 | 68 | local expr = "http.headers.foo == 123" 69 | local r, err = router.validate(s, expr) 70 | 71 | ngx.say(r) 72 | ngx.say(err) 73 | } 74 | } 75 | --- request 76 | GET /t 77 
| --- response_body_like 78 | nil 79 | Type mismatch between the LHS and RHS values of predicate 80 | --- no_error_log 81 | [error] 82 | [warn] 83 | [crit] 84 | 85 | 86 | === TEST 3: test invalid schema + invalid expr 87 | --- http_config eval: $::HttpConfig 88 | --- config 89 | location = /t { 90 | content_by_lua_block { 91 | local schema = require("resty.router.schema") 92 | local router = require("resty.router.router") 93 | 94 | local s = schema.new() 95 | s:add_field("http.headers.foo", "String") 96 | 97 | local expr = "== 123" 98 | local r, err = router.validate(s, expr) 99 | 100 | ngx.say(r) 101 | ngx.say(err) 102 | } 103 | } 104 | --- request 105 | GET /t 106 | --- response_body 107 | nil 108 | --> 1:1 109 | | 110 | 1 | == 123 111 | | ^--- 112 | | 113 | = expected term 114 | 115 | --- no_error_log 116 | [error] 117 | [warn] 118 | [crit] 119 | 120 | 121 | === TEST 4: test valid schema + invalid expr 122 | --- http_config eval: $::HttpConfig 123 | --- config 124 | location = /t { 125 | content_by_lua_block { 126 | local schema = require("resty.router.schema") 127 | local router = require("resty.router.router") 128 | 129 | local s = schema.new() 130 | s:add_field("http.headers.foo", "String") 131 | 132 | local expr = "== \"bar\"" 133 | local r, err = router.validate(s, expr) 134 | 135 | ngx.say(r) 136 | ngx.say(err) 137 | } 138 | } 139 | --- request 140 | GET /t 141 | --- response_body 142 | nil 143 | --> 1:1 144 | | 145 | 1 | == "bar" 146 | | ^--- 147 | | 148 | = expected term 149 | 150 | --- no_error_log 151 | [error] 152 | [warn] 153 | [crit] 154 | 155 | 156 | === TEST 5: valid regex 157 | --- http_config eval: $::HttpConfig 158 | --- config 159 | location = /t { 160 | content_by_lua_block { 161 | local schema = require("resty.router.schema") 162 | local router = require("resty.router.router") 163 | 164 | local s = schema.new() 165 | s:add_field("http.headers.foo", "String") 166 | 167 | local expr = [[http.headers.foo ~ "/\\\\/*user$"]] 168 | local r, err = 
router.validate(s, expr) 169 | ngx.say(type(r)) 170 | ngx.say(err) 171 | } 172 | } 173 | --- request 174 | GET /t 175 | --- response_body 176 | table 177 | nil 178 | --- no_error_log 179 | [error] 180 | [warn] 181 | [crit] 182 | 183 | 184 | === TEST 6: invalid regex 185 | --- http_config eval: $::HttpConfig 186 | --- config 187 | location = /t { 188 | content_by_lua_block { 189 | local schema = require("resty.router.schema") 190 | local router = require("resty.router.router") 191 | 192 | local s = schema.new() 193 | s:add_field("http.headers.foo", "String") 194 | 195 | local expr = [[http.headers.foo ~ "([."]] 196 | local r, err = router.validate(s, expr) 197 | ngx.say(r) 198 | ngx.say(err) 199 | } 200 | } 201 | --- request 202 | GET /t 203 | --- response_body 204 | nil 205 | --> 1:20 206 | | 207 | 1 | http.headers.foo ~ "([." 208 | | ^---^ 209 | | 210 | = regex parse error: 211 | ([. 212 | ^ 213 | error: unclosed character class 214 | 215 | --- no_error_log 216 | [error] 217 | [warn] 218 | [crit] 219 | 220 | 221 | === TEST 7: Rust regex 1.8.x will not think the regex is invalid 222 | --- http_config eval: $::HttpConfig 223 | --- config 224 | location = /t { 225 | content_by_lua_block { 226 | local schema = require("resty.router.schema") 227 | local router = require("resty.router.router") 228 | 229 | local s = schema.new() 230 | s:add_field("http.headers.foo", "String") 231 | 232 | local expr = [[http.headers.foo ~ "/\\/*user$"]] 233 | local r, err = router.validate(s, expr) 234 | ngx.say(type(r)) 235 | ngx.say(err) 236 | } 237 | } 238 | --- request 239 | GET /t 240 | --- response_body 241 | table 242 | nil 243 | --- no_error_log 244 | [error] 245 | [warn] 246 | [crit] 247 | 248 | 249 | 250 | === TEST 8: pratt parser propagates parser error 251 | --- http_config eval: $::HttpConfig 252 | --- config 253 | location = /t { 254 | content_by_lua_block { 255 | local schema = require("resty.router.schema") 256 | local router = require("resty.router.router") 257 | 258 | 
local s = schema.new() 259 | s:add_field("http.headers.foo", "String") 260 | 261 | local expr = [[http.headers.foo == "a" && http.headers.foo ~ "([."]] 262 | local r, err = router.validate(s, expr) 263 | ngx.say(r) 264 | ngx.say(err) 265 | } 266 | } 267 | --- request 268 | GET /t 269 | --- response_body 270 | nil 271 | --> 1:47 272 | | 273 | 1 | http.headers.foo == "a" && http.headers.foo ~ "([." 274 | | ^---^ 275 | | 276 | = regex parse error: 277 | ([. 278 | ^ 279 | error: unclosed character class 280 | 281 | --- no_error_log 282 | [error] 283 | [warn] 284 | [crit] 285 | -------------------------------------------------------------------------------- /t/01-sanity.t: -------------------------------------------------------------------------------- 1 | # vim:set ft= ts=4 sw=4 et: 2 | 3 | use Test::Nginx::Socket::Lua; 4 | use Cwd qw(cwd); 5 | 6 | repeat_each(2); 7 | 8 | plan tests => repeat_each() * blocks() * 5; 9 | 10 | my $pwd = cwd(); 11 | 12 | our $HttpConfig = qq{ 13 | lua_package_path "$pwd/lib/?.lua;;"; 14 | lua_package_cpath "$pwd/target/debug/?.so;;"; 15 | }; 16 | 17 | no_long_string(); 18 | no_diff(); 19 | 20 | run_tests(); 21 | 22 | __DATA__ 23 | 24 | === TEST 1: create schema, router, context 25 | --- http_config eval: $::HttpConfig 26 | --- config 27 | location = /t { 28 | content_by_lua_block { 29 | local schema = require("resty.router.schema") 30 | local router = require("resty.router.router") 31 | local context = require("resty.router.context") 32 | 33 | local s = schema.new() 34 | 35 | s:add_field("http.path", "String") 36 | s:add_field("tcp.port", "Int") 37 | 38 | local r = router.new(s) 39 | assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", 40 | "http.path ^= \"/foo\" && tcp.port == 80")) 41 | 42 | local c = context.new(s) 43 | c:add_value("http.path", "/foo/bar") 44 | c:add_value("tcp.port", 80) 45 | 46 | local matched = r:execute(c) 47 | ngx.say(matched) 48 | 49 | local uuid, prefix = c:get_result("http.path") 50 | ngx.say(uuid) 51 
| ngx.say(prefix) 52 | } 53 | } 54 | --- request 55 | GET /t 56 | --- response_body 57 | true 58 | a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c 59 | /foo 60 | --- no_error_log 61 | [error] 62 | [warn] 63 | [crit] 64 | 65 | 66 | 67 | === TEST 2: multiple routes, different priority 68 | --- http_config eval: $::HttpConfig 69 | --- config 70 | location = /t { 71 | content_by_lua_block { 72 | local schema = require("resty.router.schema") 73 | local router = require("resty.router.router") 74 | local context = require("resty.router.context") 75 | 76 | local s = schema.new() 77 | 78 | s:add_field("http.path", "String") 79 | s:add_field("tcp.port", "Int") 80 | 81 | local r = router.new(s) 82 | assert(r:add_matcher(1, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", 83 | "http.path ^= \"/foo\" && tcp.port == 80")) 84 | assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150d", 85 | "http.path ^= \"/\"")) 86 | 87 | local c = context.new(s) 88 | c:add_value("http.path", "/foo/bar") 89 | c:add_value("tcp.port", 80) 90 | 91 | local matched = r:execute(c) 92 | ngx.say(matched) 93 | 94 | 95 | local uuid, prefix = c:get_result("http.path") 96 | ngx.say("uuid = " .. uuid .. " prefix = " .. 
prefix) 97 | } 98 | } 99 | --- request 100 | GET /t 101 | --- response_body 102 | true 103 | uuid = a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c prefix = /foo 104 | --- no_error_log 105 | [error] 106 | [warn] 107 | [crit] 108 | 109 | 110 | 111 | === TEST 3: remove matcher 112 | --- http_config eval: $::HttpConfig 113 | --- config 114 | location = /t { 115 | content_by_lua_block { 116 | local schema = require("resty.router.schema") 117 | local router = require("resty.router.router") 118 | local context = require("resty.router.context") 119 | 120 | local s = schema.new() 121 | 122 | s:add_field("http.path", "String") 123 | s:add_field("tcp.port", "Int") 124 | 125 | local r = router.new(s) 126 | assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", 127 | "http.path ^= \"/foo\" && tcp.port == 80")) 128 | 129 | local c = context.new(s) 130 | c:add_value("http.path", "/foo/bar") 131 | c:add_value("tcp.port", 80) 132 | 133 | local matched = r:execute(c) 134 | ngx.say(matched) 135 | 136 | local uuid, prefix = c:get_result("http.path") 137 | ngx.say(uuid) 138 | ngx.say(prefix) 139 | 140 | assert(r:remove_matcher("a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c")) 141 | 142 | c = context.new(s) 143 | c:add_value("http.path", "/foo/bar") 144 | c:add_value("tcp.port", 80) 145 | 146 | matched = r:execute(c) 147 | ngx.say(matched) 148 | } 149 | } 150 | --- request 151 | GET /t 152 | --- response_body 153 | true 154 | a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c 155 | /foo 156 | false 157 | --- no_error_log 158 | [error] 159 | [warn] 160 | [crit] 161 | 162 | 163 | 164 | === TEST 4: invalid ATC DSL 165 | --- http_config eval: $::HttpConfig 166 | --- config 167 | location = /t { 168 | content_by_lua_block { 169 | local schema = require("resty.router.schema") 170 | local router = require("resty.router.router") 171 | local context = require("resty.router.context") 172 | 173 | local s = schema.new() 174 | 175 | s:add_field("http.path", "String") 176 | s:add_field("tcp.port", "Int") 177 | 178 | local r 
= router.new(s) 179 | ngx.say(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", 180 | "http.path = \"/foo\" && tcp.port == 80")) 181 | } 182 | } 183 | --- request 184 | GET /t 185 | --- response_body 186 | nil --> 1:11 187 | | 188 | 1 | http.path = "/foo" && tcp.port == 80 189 | | ^--- 190 | | 191 | = expected binary_operator 192 | --- no_error_log 193 | [error] 194 | [warn] 195 | [crit] 196 | 197 | 198 | 199 | === TEST 5: context:reset() 200 | --- http_config eval: $::HttpConfig 201 | --- config 202 | location = /t { 203 | content_by_lua_block { 204 | local schema = require("resty.router.schema") 205 | local router = require("resty.router.router") 206 | local context = require("resty.router.context") 207 | 208 | local s = schema.new() 209 | 210 | s:add_field("http.path", "String") 211 | s:add_field("tcp.port", "Int") 212 | 213 | local r = router.new(s) 214 | assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", 215 | "http.path ^= \"/foo\" && tcp.port == 80")) 216 | 217 | local c = context.new(s) 218 | c:add_value("http.path", "/foo/bar") 219 | c:add_value("tcp.port", 80) 220 | 221 | local matched = r:execute(c) 222 | ngx.say(matched) 223 | 224 | local uuid, prefix = c:get_result("http.path") 225 | ngx.say(uuid) 226 | ngx.say(prefix) 227 | 228 | c:reset() 229 | 230 | local uuid, prefix = c:get_result("http.path") 231 | ngx.say(uuid) 232 | ngx.say(prefix) 233 | 234 | local matched = r:execute(c) 235 | ngx.say(matched) 236 | } 237 | } 238 | --- request 239 | GET /t 240 | --- response_body 241 | true 242 | a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c 243 | /foo 244 | nil 245 | nil 246 | false 247 | --- no_error_log 248 | [error] 249 | [warn] 250 | [crit] 251 | -------------------------------------------------------------------------------- /src/router.rs: -------------------------------------------------------------------------------- 1 | use crate::ast::Expression; 2 | use crate::context::{Context, Match}; 3 | use crate::interpreter::Execute; 4 | use 
crate::parser::parse; 5 | use crate::schema::Schema; 6 | use crate::semantics::{FieldCounter, Validate}; 7 | use std::borrow::Borrow; 8 | use std::collections::{BTreeMap, HashMap}; 9 | use uuid::Uuid; 10 | 11 | #[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] 12 | struct MatcherKey(usize, Uuid); 13 | 14 | #[derive(Debug)] 15 | pub struct Router { 16 | schema: S, 17 | matchers: BTreeMap, 18 | pub fields: HashMap, 19 | } 20 | 21 | impl Router 22 | where 23 | S: Borrow, 24 | { 25 | /// Creates a new [`Router`] that holds [`Borrow`]<[`Schema`]>. 26 | /// 27 | /// This provides flexibility to use different types of schema providers. 28 | pub fn new(schema: S) -> Self { 29 | Self { 30 | schema, 31 | matchers: BTreeMap::new(), 32 | fields: HashMap::new(), 33 | } 34 | } 35 | 36 | /// Returns a reference to the [`Schema`] used by this router. 37 | /// 38 | /// Especially useful when the router owns or wraps the schema, 39 | /// and you need to pass a reference to other components like [`Context`]. 
40 | pub fn schema(&self) -> &Schema { 41 | self.schema.borrow() 42 | } 43 | 44 | pub fn add_matcher(&mut self, priority: usize, uuid: Uuid, atc: &str) -> Result<(), String> { 45 | let expr = parse(atc).map_err(|e| e.to_string())?; 46 | 47 | self.add_matcher_expr(priority, uuid, expr) 48 | } 49 | 50 | pub fn add_matcher_expr( 51 | &mut self, 52 | priority: usize, 53 | uuid: Uuid, 54 | expr: Expression, 55 | ) -> Result<(), String> { 56 | let key = MatcherKey(priority, uuid); 57 | 58 | if self.matchers.contains_key(&key) { 59 | return Err("UUID already exists".to_string()); 60 | } 61 | 62 | expr.validate(self.schema())?; 63 | expr.add_to_counter(&mut self.fields); 64 | 65 | assert!(self.matchers.insert(key, expr).is_none()); 66 | 67 | Ok(()) 68 | } 69 | 70 | pub fn remove_matcher(&mut self, priority: usize, uuid: Uuid) -> bool { 71 | let key = MatcherKey(priority, uuid); 72 | 73 | let Some(ast) = self.matchers.remove(&key) else { 74 | return false; 75 | }; 76 | 77 | ast.remove_from_counter(&mut self.fields); 78 | true 79 | } 80 | 81 | pub fn execute(&self, context: &mut Context) -> bool { 82 | let Some(m) = self.try_match(context) else { 83 | return false; 84 | }; 85 | 86 | context.result = Some(m); 87 | true 88 | } 89 | 90 | /// Note that unlike `execute`, this doesn't set `Context.result` 91 | /// but it also doesn't need a `&mut Context`. 
92 | pub fn try_match(&self, context: &Context) -> Option { 93 | let mut mat = Match::new(); 94 | 95 | for (MatcherKey(_, id), m) in self.matchers.iter().rev() { 96 | if m.execute(context, &mut mat) { 97 | mat.uuid = *id; 98 | return Some(mat); 99 | } 100 | 101 | mat.reset(); 102 | } 103 | 104 | None 105 | } 106 | } 107 | 108 | #[cfg(test)] 109 | mod tests { 110 | use uuid::Uuid; 111 | 112 | use crate::{ast::Type, context::Context, schema::Schema}; 113 | 114 | use super::Router; 115 | 116 | use std::sync::Arc; 117 | 118 | #[test] 119 | fn execute_succeeds() { 120 | let mut schema = Schema::default(); 121 | schema.add_field("http.path", Type::String); 122 | 123 | let mut router = Router::new(&schema); 124 | router 125 | .add_matcher(0, Uuid::default(), "http.path == \"/dev\"") 126 | .expect("should add"); 127 | 128 | let mut ctx = Context::new(&schema); 129 | ctx.add_value("http.path", "/dev".to_owned().into()); 130 | assert!(router.execute(&mut ctx)); 131 | } 132 | 133 | #[test] 134 | fn execute_fails() { 135 | let mut schema = Schema::default(); 136 | schema.add_field("http.path", Type::String); 137 | 138 | let mut router = Router::new(&schema); 139 | router 140 | .add_matcher(0, Uuid::default(), "http.path == \"/dev\"") 141 | .expect("should add"); 142 | 143 | let mut ctx = Context::new(&schema); 144 | ctx.add_value("http.path", "/not-dev".to_owned().into()); 145 | assert!(!router.execute(&mut ctx)); 146 | } 147 | 148 | #[test] 149 | fn try_match_succeeds() { 150 | let mut schema = Schema::default(); 151 | schema.add_field("http.path", Type::String); 152 | 153 | let mut router = Router::new(&schema); 154 | router 155 | .add_matcher(0, Uuid::default(), "http.path == \"/dev\"") 156 | .expect("should add"); 157 | 158 | let mut ctx = Context::new(&schema); 159 | ctx.add_value("http.path", "/dev".to_owned().into()); 160 | router.try_match(&ctx).expect("matches"); 161 | } 162 | 163 | #[test] 164 | fn try_match_fails() { 165 | let mut schema = Schema::default(); 166 | 
schema.add_field("http.path", Type::String); 167 | 168 | let mut router = Router::new(&schema); 169 | router 170 | .add_matcher(0, Uuid::default(), "http.path == \"/dev\"") 171 | .expect("should add"); 172 | 173 | let mut ctx = Context::new(&schema); 174 | ctx.add_value("http.path", "/not-dev".to_owned().into()); 175 | router.try_match(&ctx).ok_or(()).expect_err("should fail"); 176 | } 177 | 178 | #[test] 179 | fn test_shared_schema_instantiation() { 180 | let mut schema = Schema::default(); 181 | schema.add_field("http.path", Type::String); 182 | 183 | let mut router = Router::new(&schema); 184 | router 185 | .add_matcher(0, Uuid::default(), "http.path == \"/dev\"") 186 | .expect("should add"); 187 | let mut ctx = Context::new(router.schema()); 188 | ctx.add_value("http.path", "/dev".to_owned().into()); 189 | router.try_match(&ctx).expect("matches"); 190 | } 191 | 192 | #[test] 193 | fn test_owned_schema_instantiation() { 194 | let mut schema = Schema::default(); 195 | schema.add_field("http.path", Type::String); 196 | 197 | let mut router = Router::new(schema); 198 | router 199 | .add_matcher(0, Uuid::default(), "http.path == \"/dev\"") 200 | .expect("should add"); 201 | let mut ctx = Context::new(router.schema()); 202 | ctx.add_value("http.path", "/dev".to_owned().into()); 203 | router.try_match(&ctx).expect("matches"); 204 | } 205 | 206 | #[test] 207 | fn test_arc_schema_instantiation() { 208 | let mut schema = Schema::default(); 209 | schema.add_field("http.path", Type::String); 210 | 211 | let mut router = Router::new(Arc::new(schema)); 212 | router 213 | .add_matcher(0, Uuid::default(), "http.path == \"/dev\"") 214 | .expect("should add"); 215 | let mut ctx = Context::new(router.schema()); 216 | ctx.add_value("http.path", "/dev".to_owned().into()); 217 | router.try_match(&ctx).expect("matches"); 218 | } 219 | 220 | #[test] 221 | fn test_box_schema_instantiation() { 222 | let mut schema = Schema::default(); 223 | schema.add_field("http.path", Type::String); 
224 | 225 | let mut router = Router::new(Box::new(schema)); 226 | router 227 | .add_matcher(0, Uuid::default(), "http.path == \"/dev\"") 228 | .expect("should add"); 229 | let mut ctx = Context::new(router.schema()); 230 | ctx.add_value("http.path", "/dev".to_owned().into()); 231 | router.try_match(&ctx).expect("matches"); 232 | } 233 | } 234 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Name 2 | 3 | ATC Router library for Kong. 4 | 5 | # Table of Contents 6 | 7 | * [Name](#name) 8 | * [Semantics](#semantics) 9 | * [Synopsis](#synopsis) 10 | * [APIs](#apis) 11 | * [resty.router.schema](#restyrouterschema) 12 | * [new](#new) 13 | * [add\_field](#add_field) 14 | * [get\_field\_type](#get_field_type) 15 | * [resty.router.router](#restyrouterrouter) 16 | * [new](#new) 17 | * [add\_matcher](#add_matcher) 18 | * [remove\_matcher](#remove_matcher) 19 | * [execute](#execute) 20 | * [get\_fields](#get_fields) 21 | * [validate](#validate) 22 | * [resty.router.context](#restyroutercontext) 23 | * [new](#new) 24 | * [add\_value](#add_value) 25 | * [get\_result](#get_result) 26 | * [reset](#reset) 27 | * [Copyright and license](#copyright-and-license) 28 | 29 | # Semantics 30 | 31 | At the core of the library, ATC Router is a [DSL] that supports simple predicate 32 | and logical combinations between the predicates. 
33 | 34 | [DSL]:https://en.wikipedia.org/wiki/Domain-specific_language 35 | 36 | Each piece of data referred to in the DSL has a type, which can be one of the following: 37 | 38 | * `String` - a UTF-8 string value 39 | * `IpCidr` - an IP address range in CIDR format 40 | * `IpAddr` - a single IP address that can be checked against an `IpCidr` 41 | * `Int` - a 64-bit signed integer 42 | 43 | Please refer to the [documentation](https://docs.konghq.com/gateway/latest/reference/expressions-language/) 44 | on the Kong website for how the language is used in practice. 45 | 46 | # Synopsis 47 | 48 | ``` 49 | lua_package_path '/path/to/atc-router/lib/?.lua;;'; 50 | 51 | # run `make build` to generate dynamic library 52 | 53 | lua_package_cpath '/path/to/atc-router/target/debug/?.so;;'; 54 | 55 | # A simple example creates a schema, router and context, and uses them to check if 56 | # "http.path" starts with "/foo" and if "tcp.port" equals 80. 57 | 58 | location = /simple_example { 59 | content_by_lua_block { 60 | local schema = require("resty.router.schema") 61 | local router = require("resty.router.router") 62 | local context = require("resty.router.context") 63 | 64 | local s = schema.new() 65 | 66 | s:add_field("http.path", "String") 67 | s:add_field("tcp.port", "Int") 68 | 69 | local r = router.new(s) 70 | assert(r:add_matcher(0, "a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c", 71 | "http.path ^= \"/foo\" && tcp.port == 80")) 72 | 73 | local c = context.new(s) 74 | c:add_value("http.path", "/foo/bar") 75 | c:add_value("tcp.port", 80) 76 | 77 | local matched = r:execute(c) 78 | ngx.say(matched) 79 | 80 | local uuid, prefix = c:get_result("http.path") 81 | ngx.say(uuid) 82 | ngx.say(prefix) 83 | } 84 | } 85 | ``` 86 | 87 | # APIs 88 | 89 | ## resty.router.schema 90 | 91 | ### new 92 | 93 | **syntax:** *s = schema.new()* 94 | 95 | **context:** *any* 96 | 97 | Create a new schema instance that can later be used by `router` and `context`.
98 | 99 | [Back to TOC](#table-of-contents) 100 | 101 | ### add\_field 102 | 103 | **syntax:** *res, err = s:add_field(field, field_type)* 104 | 105 | **context:** *any* 106 | 107 | Adds the field named `field` into the schema. Type can be one of the ones mentioned 108 | in the [Semantics](#semantics) section above. 109 | 110 | If an error occurred, `nil` and a string describing the error will be returned. 111 | 112 | [Back to TOC](#table-of-contents) 113 | 114 | ### get\_field\_type 115 | 116 | **syntax:** *typ, err = s:get_field_type(field)* 117 | 118 | **context:** *any* 119 | 120 | Gets the field type from the schema. 121 | 122 | If an error occurred, `nil` and a string describing the error will be returned. 123 | 124 | [Back to TOC](#table-of-contents) 125 | 126 | ## resty.router.router 127 | 128 | ### new 129 | 130 | **syntax:** *r = router.new(schema)* 131 | 132 | **context:** *any* 133 | 134 | Create a new router instance that can later be used for performing matches. `schema` 135 | must refer to an existing schema instance. 136 | 137 | [Back to TOC](#table-of-contents) 138 | 139 | ### add\_matcher 140 | 141 | **syntax:** *res, err = r:add_matcher(priority, uuid, atc)* 142 | 143 | **context:** *any* 144 | 145 | Add a matcher to the router. `priority` is a 64-bit unsigned integer that instructs 146 | the priority for which the matchers should be evaluated. `uuid` is the string 147 | representation of the UUID of the matcher which will be used later for match results. 148 | `atc` is the matcher written in ATC DSL syntax. 149 | 150 | If an error occurred or the matcher has syntax/semantics errors, 151 | `nil` and a string describing the error will be returned. 152 | 153 | [Back to TOC](#table-of-contents) 154 | 155 | ### remove\_matcher 156 | 157 | **syntax:** *res, err = r:remove_matcher(uuid)* 158 | 159 | **context:** *any* 160 | 161 | Remove matcher with `uuid` from the router. 162 | 163 | Returns `true` if the matcher has successfully been removed. 
`false` if the 164 | matcher does not exist. 165 | 166 | [Back to TOC](#table-of-contents) 167 | 168 | ### execute 169 | 170 | **syntax:** *res, err = r:execute(context)* 171 | 172 | **context:** *any* 173 | 174 | Executes the router against the values provided inside the `context` instance. 175 | 176 | `context` must use the same schema as the router, otherwise a Lua error will be thrown. 177 | 178 | Returns `true` if at least one matcher produced a valid match. `false` if 179 | none of the matchers matched. 180 | 181 | [Back to TOC](#table-of-contents) 182 | 183 | ### get\_fields 184 | 185 | **syntax:** *res = r:get_fields()* 186 | 187 | **context:** *any* 188 | 189 | Returns the currently used field names by all matchers inside the router as 190 | a Lua array. It can help avoid unnecessarily producing values that are not 191 | actually used by the user-supplied matchers. 192 | 193 | [Back to TOC](#table-of-contents) 194 | 195 | ### validate 196 | 197 | **syntax:** *fields, err = router.validate(schema, expr)* 198 | 199 | **context:** *any* 200 | 201 | Validates an expression against a given schema. 202 | 203 | Returns the fields used in the provided expression when the expression is valid. If the expression is invalid, 204 | `nil` and a string describing the reason will be returned. 205 | 206 | [Back to TOC](#table-of-contents) 207 | 208 | ## resty.router.context 209 | 210 | ### new 211 | 212 | **syntax:** *c = context.new(schema)* 213 | 214 | **context:** *any* 215 | 216 | Create a new context instance that can later be used for storing contextual information 217 | for router matches. `schema` must refer to an existing schema instance. 218 | 219 | [Back to TOC](#table-of-contents) 220 | 221 | ### add\_value 222 | 223 | **syntax:** *res, err = c:add_value(field, value)* 224 | 225 | **context:** *any* 226 | 227 | Provides `value` for `field` inside the context. 228 | 229 | Returns `true` if field exists and value has successfully been provided.
230 | 231 | If an error occurred, `nil` and a string describing the error will be returned. 232 | 233 | [Back to TOC](#table-of-contents) 234 | 235 | ### get\_result 236 | 237 | **syntax:** *uuid, matched_value, captures = c:get_result(matched_field)* 238 | 239 | **context:** *any* 240 | 241 | After a successful router match, gets the match result from the context. 242 | 243 | If `matched_field` is provided, then `matched_value` will be returned with the value 244 | matched by the specified field. If `matched_field` is `nil` or field did 245 | not match, then `nil` is returned for `matched_value`. 246 | 247 | If the context did not contain a valid match result, `nil` is returned. 248 | 249 | Otherwise, the string UUID, value matching field `matched_field` and 250 | regex captures from the matched route are returned. 251 | 252 | [Back to TOC](#table-of-contents) 253 | 254 | ### reset 255 | 256 | **syntax:** *c:reset()* 257 | 258 | **context:** *any* 259 | 260 | This resets context `c` without deallocating the underlying memory 261 | so the context can be used again as if it was just created. 262 | 263 | [Back to TOC](#table-of-contents) 264 | 265 | # Copyright and license 266 | 267 | Copyright © 2022-2023 Kong, Inc. 268 | 269 | Licensed under the [Apache License, Version 2.0](https://www.apache.org/licenses/LICENSE-2.0). 270 | 271 | Files in the project may not be copied, modified, or distributed except according to those terms. 
272 | 273 | [Back to TOC](#table-of-contents) 274 | 275 | -------------------------------------------------------------------------------- /src/semantics.rs: -------------------------------------------------------------------------------- 1 | use crate::ast::{BinaryOperator, Expression, LogicalExpression, Type, Value}; 2 | use crate::schema::Schema; 3 | use std::collections::HashMap; 4 | 5 | type ValidationResult = Result<(), String>; 6 | type ValidationHashMap = HashMap; 7 | 8 | pub trait Validate { 9 | fn validate(&self, schema: &Schema) -> ValidationResult; 10 | } 11 | 12 | pub trait FieldCounter { 13 | fn add_to_counter(&self, map: &mut ValidationHashMap); 14 | fn remove_from_counter(&self, map: &mut ValidationHashMap); 15 | } 16 | 17 | impl FieldCounter for Expression { 18 | fn add_to_counter(&self, map: &mut ValidationHashMap) { 19 | use Expression::{Logical, Predicate}; 20 | use LogicalExpression::{And, Not, Or}; 21 | 22 | match self { 23 | Logical(l) => match l.as_ref() { 24 | And(l, r) | Or(l, r) => { 25 | l.add_to_counter(map); 26 | r.add_to_counter(map); 27 | } 28 | Not(r) => { 29 | r.add_to_counter(map); 30 | } 31 | }, 32 | Predicate(p) => { 33 | *map.entry(p.lhs.var_name.clone()).or_default() += 1; 34 | } 35 | } 36 | } 37 | 38 | fn remove_from_counter(&self, map: &mut ValidationHashMap) { 39 | use Expression::{Logical, Predicate}; 40 | use LogicalExpression::{And, Not, Or}; 41 | 42 | match self { 43 | Logical(l) => match l.as_ref() { 44 | And(l, r) | Or(l, r) => { 45 | l.remove_from_counter(map); 46 | r.remove_from_counter(map); 47 | } 48 | Not(r) => { 49 | r.remove_from_counter(map); 50 | } 51 | }, 52 | Predicate(p) => { 53 | let val = map.get_mut(&p.lhs.var_name).unwrap(); 54 | *val -= 1; 55 | 56 | if *val == 0 { 57 | assert!(map.remove(&p.lhs.var_name).is_some()); 58 | } 59 | } 60 | } 61 | } 62 | } 63 | 64 | fn raise_err(msg: &str) -> ValidationResult { 65 | Err(msg.to_string()) 66 | } 67 | 68 | const MSG_UNKNOWN_LHS: &str = "Unknown LHS field"; 69 
| const MSG_TYPE_MISMATCH_LHS_RHS: &str = "Type mismatch between the LHS and RHS values of predicate"; 70 | const MSG_LOWER_ONLY_FOR_STRING: &str = 71 | "lower-case transformation function only supported with String type fields"; 72 | const MSG_REGEX_ONLY_FOR_STRING: &str = "Regex operators only supports string operands"; 73 | const MSG_PREFIX_POSTFIX_ONLY_FOR_STRING: &str = 74 | "Prefix/Postfix operators only supports string operands"; 75 | const MSG_ONLY_FOR_INT: &str = 76 | "Greater/GreaterOrEqual/Less/LessOrEqual operators only supports integer operands"; 77 | const MSG_ONLY_FOR_CIDR: &str = "In/NotIn operators only supports IP in CIDR"; 78 | const MSG_CONTAINS_ONLY_FOR_CIDR: &str = "Contains operator only supports string operands"; 79 | 80 | impl Validate for Expression { 81 | fn validate(&self, schema: &Schema) -> ValidationResult { 82 | use Expression::{Logical, Predicate}; 83 | use LogicalExpression::{And, Not, Or}; 84 | 85 | match self { 86 | Logical(l) => { 87 | match l.as_ref() { 88 | And(l, r) | Or(l, r) => { 89 | l.validate(schema)?; 90 | r.validate(schema)?; 91 | } 92 | Not(r) => { 93 | r.validate(schema)?; 94 | } 95 | } 96 | 97 | Ok(()) 98 | } 99 | Predicate(p) => { 100 | use BinaryOperator::{ 101 | Contains, Equals, Greater, GreaterOrEqual, In, Less, LessOrEqual, NotEquals, 102 | NotIn, Postfix, Prefix, Regex, 103 | }; 104 | 105 | // lhs and rhs must be the same type 106 | let Some(lhs_type) = p.lhs.my_type(schema) else { 107 | return raise_err(MSG_UNKNOWN_LHS); 108 | }; 109 | 110 | if p.op != Regex // Regex RHS is always Regex, and LHS is always String 111 | && p.op != In // In/NotIn supports IPAddr in IpCidr 112 | && p.op != NotIn 113 | && lhs_type != &p.rhs.my_type() 114 | { 115 | return raise_err(MSG_TYPE_MISMATCH_LHS_RHS); 116 | } 117 | 118 | let (lower, _any) = p.lhs.get_transformations(); 119 | 120 | // LHS transformations only makes sense with string fields 121 | if lower && lhs_type != &Type::String { 122 | return 
raise_err(MSG_LOWER_ONLY_FOR_STRING); 123 | } 124 | 125 | match p.op { 126 | Equals | NotEquals => Ok(()), 127 | Regex => { 128 | // unchecked path above 129 | match lhs_type { 130 | Type::String => Ok(()), 131 | _ => raise_err(MSG_REGEX_ONLY_FOR_STRING), 132 | } 133 | } 134 | Prefix | Postfix => match p.rhs { 135 | Value::String(_) => Ok(()), 136 | _ => raise_err(MSG_PREFIX_POSTFIX_ONLY_FOR_STRING), 137 | }, 138 | Greater | GreaterOrEqual | Less | LessOrEqual => match p.rhs { 139 | Value::Int(_) => Ok(()), 140 | _ => raise_err(MSG_ONLY_FOR_INT), 141 | }, 142 | In | NotIn => { 143 | // unchecked path above 144 | match (lhs_type, &p.rhs) { 145 | (Type::IpAddr, Value::IpCidr(_)) => Ok(()), 146 | _ => raise_err(MSG_ONLY_FOR_CIDR), 147 | } 148 | } 149 | Contains => match p.rhs { 150 | Value::String(_) => Ok(()), 151 | _ => raise_err(MSG_CONTAINS_ONLY_FOR_CIDR), 152 | }, 153 | } // match p.op 154 | } // Predicate(p) 155 | } // match self 156 | } // fn validate 157 | } 158 | 159 | #[cfg(test)] 160 | mod tests { 161 | use super::*; 162 | use crate::parser::parse; 163 | use lazy_static::lazy_static; 164 | 165 | lazy_static! 
{ 166 | static ref SCHEMA: Schema = { 167 | let mut s = Schema::default(); 168 | s.add_field("string", Type::String); 169 | s.add_field("int", Type::Int); 170 | s.add_field("ipaddr", Type::IpAddr); 171 | s 172 | }; 173 | } 174 | 175 | #[test] 176 | fn unknown_field() { 177 | let expression = parse(r#"unkn == "abc""#).unwrap(); 178 | assert_eq!( 179 | expression.validate(&SCHEMA).unwrap_err(), 180 | "Unknown LHS field" 181 | ); 182 | } 183 | 184 | #[test] 185 | fn string_lhs() { 186 | let tests = vec![ 187 | r#"string == "abc""#, 188 | r#"string != "abc""#, 189 | r#"string ~ "abc""#, 190 | r#"string ^= "abc""#, 191 | r#"string =^ "abc""#, 192 | r#"lower(string) =^ "abc""#, 193 | ]; 194 | for input in tests { 195 | let expression = parse(input).unwrap(); 196 | expression.validate(&SCHEMA).unwrap(); 197 | } 198 | 199 | let failing_tests = vec![ 200 | r#"string == 192.168.0.1"#, 201 | r#"string == 192.168.0.0/24"#, 202 | r#"string == 123"#, 203 | r#"string in "abc""#, 204 | ]; 205 | for input in failing_tests { 206 | let expression = parse(input).unwrap(); 207 | assert!(expression.validate(&SCHEMA).is_err()); 208 | } 209 | } 210 | 211 | #[test] 212 | fn ipaddr_lhs() { 213 | let tests = vec![ 214 | r#"ipaddr == 192.168.0.1"#, 215 | r#"ipaddr == fd00::1"#, 216 | r#"ipaddr in 192.168.0.0/24"#, 217 | r#"ipaddr in fd00::/64"#, 218 | r#"ipaddr not in 192.168.0.0/24"#, 219 | r#"ipaddr not in fd00::/64"#, 220 | ]; 221 | for input in tests { 222 | let expression = parse(input).unwrap(); 223 | expression.validate(&SCHEMA).unwrap(); 224 | } 225 | 226 | let failing_tests = vec![ 227 | r#"ipaddr == "abc""#, 228 | r#"ipaddr == 123"#, 229 | r#"ipaddr in 192.168.0.1"#, 230 | r#"ipaddr in fd00::1"#, 231 | r#"ipaddr == 192.168.0.0/24"#, 232 | r#"ipaddr == fd00::/64"#, 233 | r#"lower(ipaddr) == fd00::1"#, 234 | ]; 235 | for input in failing_tests { 236 | let expression = parse(input).unwrap(); 237 | assert!(expression.validate(&SCHEMA).is_err()); 238 | } 239 | } 240 | 241 | #[test] 242 | 
fn int_lhs() { 243 | let tests = vec![ 244 | r#"int == 123"#, 245 | r#"int >= 123"#, 246 | r#"int <= 123"#, 247 | r#"int > 123"#, 248 | r#"int < 123"#, 249 | ]; 250 | for input in tests { 251 | let expression = parse(input).unwrap(); 252 | expression.validate(&SCHEMA).unwrap(); 253 | } 254 | 255 | let failing_tests = vec![ 256 | r#"int == "abc""#, 257 | r#"int in 192.168.0.0/24"#, 258 | r#"lower(int) == 123"#, 259 | ]; 260 | for input in failing_tests { 261 | let expression = parse(input).unwrap(); 262 | assert!(expression.validate(&SCHEMA).is_err()); 263 | } 264 | } 265 | } 266 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 
175 | 176 | END OF TERMS AND CONDITIONS 177 | -------------------------------------------------------------------------------- /src/ffi/context.rs: --------------------------------------------------------------------------------
use crate::ast::Value;
use crate::context::Context;
use crate::ffi::{CValue, ERR_BUF_MAX_LEN};
use crate::schema::Schema;
use std::cmp::min;
use std::ffi;
use std::os::raw::c_char;
use std::slice::from_raw_parts_mut;
use uuid::fmt::Hyphenated;

/// Allocate a new context object bound to `schema` and hand ownership of it
/// to the caller as a raw pointer.
///
/// The returned pointer must eventually be released with [`context_free`].
///
/// # Errors
///
/// This function never returns an error, however, it can panic if memory allocation failed.
///
/// # Safety
///
/// Violating any of the following constraints will result in undefined behavior:
///
/// - `schema` must be a valid pointer returned by [`schema_new`].
///
/// [`schema_new`]: crate::ffi::schema::schema_new
#[no_mangle]
pub unsafe extern "C" fn context_new(schema: &Schema) -> *mut Context<'_> {
    let boxed = Box::new(Context::new(schema));
    // Ownership moves to the caller; `context_free` rebuilds the `Box` to release it.
    Box::into_raw(boxed)
}

/// Deallocate a context object previously created by [`context_new`].
///
/// # Errors
///
/// This function never fails.
///
/// # Safety
///
/// Violating any of the following constraints will result in undefined behavior:
///
/// - `context` must be a valid pointer returned by [`context_new`].
#[no_mangle]
pub unsafe extern "C" fn context_free(context: *mut Context) {
    // Re-take ownership of the allocation so that dropping the `Box` frees it.
    let owned = Box::from_raw(context);
    drop(owned);
}

/// Add a value associated with a field to the context.
/// This is useful when you want to match a value against a field in the schema.
///
/// # Arguments
///
/// - `context`: a pointer to the [`Context`] object.
/// - `field`: the C-style string representing the field name.
/// - `value`: the value to be added to the context.
53 | /// - `errbuf`: a buffer to store the error message. 54 | /// - `errbuf_len`: a pointer to the length of the error message buffer. 55 | /// 56 | /// # Returns 57 | /// 58 | /// Returns `true` if the value was added successfully, otherwise `false`, 59 | /// and the error message will be stored in the `errbuf`, 60 | /// and the length of the error message will be stored in `errbuf_len`. 61 | /// 62 | /// # Errors 63 | /// 64 | /// This function will return `false` if the value could not be added to the context, 65 | /// such as when a String value is not a valid UTF-8 string. 66 | /// 67 | /// # Panics 68 | /// 69 | /// This function will panic if the provided value does not match the schema. 70 | /// 71 | /// # Safety 72 | /// 73 | /// Violating any of the following constraints will result in undefined behavior: 74 | /// 75 | /// * `context` must be a valid pointer returned by [`context_new`]. 76 | /// * `field` must be a valid pointer to a C-style string, 77 | /// must be properply aligned, and must not have '\0' in the middle. 78 | /// * `value` must be a valid pointer to a [`CValue`]. 79 | /// * `errbuf` must be valid to read and write for `errbuf_len * size_of::()` bytes, 80 | /// and it must be properly aligned. 81 | /// * `errbuf_len` must be vlaid to read and write for `size_of::()` bytes, 82 | /// and it must be properly aligned. 
83 | #[no_mangle] 84 | pub unsafe extern "C" fn context_add_value( 85 | context: &mut Context, 86 | field: *const i8, 87 | value: &CValue, 88 | errbuf: *mut u8, 89 | errbuf_len: *mut usize, 90 | ) -> bool { 91 | let field = ffi::CStr::from_ptr(field as *const c_char) 92 | .to_str() 93 | .unwrap(); 94 | let errbuf = from_raw_parts_mut(errbuf, ERR_BUF_MAX_LEN); 95 | 96 | let value: Result = value.try_into(); 97 | if let Err(e) = value { 98 | let errlen = min(e.len(), *errbuf_len); 99 | errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]); 100 | *errbuf_len = errlen; 101 | return false; 102 | } 103 | 104 | context.add_value(field, value.unwrap()); 105 | 106 | true 107 | } 108 | 109 | /// Reset the context so that it can be reused. 110 | /// This is useful when you want to reuse the same context for multiple matches. 111 | /// This will clear all the values that were added to the context, 112 | /// but keep the memory allocated for the context. 113 | /// 114 | /// # Errors 115 | /// 116 | /// This function never fails. 117 | /// 118 | /// # Safety 119 | /// 120 | /// Violating any of the following constraints will result in undefined behavior: 121 | /// 122 | /// - `context` must be a valid pointer returned by [`context_new`]. 123 | #[no_mangle] 124 | pub unsafe extern "C" fn context_reset(context: &mut Context) { 125 | context.reset(); 126 | } 127 | 128 | /// Get the result of the context. 129 | /// 130 | /// # Arguments 131 | /// 132 | /// - `context`: a pointer to the [`Context`] object. 133 | /// - `uuid_hex`: If not `NULL`, the UUID of the matched matcher will be stored. 134 | /// - `matched_field`: If not `NULL`, the field name (C-style string) of the matched value will be stored. 135 | /// - `matched_value`: If the `matched_field` is not `NULL`, the value of the matched field will be stored. 136 | /// - `matched_value_len`: If the `matched_field` is not `NULL`, the length of the value of the matched field will be stored. 
137 | /// - `capture_names`: A pointer to an array of pointers to the capture names, each element is a non-C-style string pointer. 138 | /// - `capture_names_len`: A pointer to an array of the length of each capture name. 139 | /// - `capture_values`: A pointer to an array of pointers to the capture values, each element is a non-C-style string pointer. 140 | /// - `capture_values_len`: A pointer to an array of the length of each capture value. 141 | /// 142 | /// # Returns 143 | /// 144 | /// Returns the number of captures that are stored in the context. 145 | /// 146 | /// # Lifetimes 147 | /// 148 | /// The string pointers stored in `matched_value`, `capture_names`, and `capture_values` 149 | /// might be invalidated if any of the following operations are happened: 150 | /// 151 | /// - The `context` was deallocated. 152 | /// - The `context` was reset by [`context_reset`]. 153 | /// 154 | /// # Panics 155 | /// 156 | /// This function will panic if the `matched_field` is not a valid UTF-8 string. 157 | /// 158 | /// # Safety 159 | /// 160 | /// Violating any of the following constraints will result in undefined behavior: 161 | /// 162 | /// - `context` must be a valid pointer returned by [`context_new`], 163 | /// must be passed to [`router_execute`] before calling this function, 164 | /// and must not be reset by [`context_reset`] before calling this function. 165 | /// - If `uuid_hex` is not `NULL`, `uuid_hex` must be valid to read and write for 166 | /// `16 * size_of::()` bytes, and it must be properly aligned. 167 | /// - If `matched_field` is not `NULL`, 168 | /// `matched_field` must be a vlaid pointer to a C-style string, 169 | /// must be properly aligned, and must not have '\0' in the middle. 170 | /// - If `matched_value` is not `NULL`, 171 | /// `matched_value` must be valid to read and write for 172 | /// `mem::size_of::<*const u8>()` bytes, and it must be properly aligned. 
173 | /// - If `matched_value` is not `NULL`, `matched_value_len` must be valid to read and write for 174 | /// `size_of::()` bytes, and it must be properly aligned. 175 | /// - If `uuid_hex` is not `NULL`, `capture_names` must be valid to read and write for 176 | /// ` * size_of::<*const u8>()` bytes, and it must be properly aligned. 177 | /// - If `uuid_hex` is not `NULL`, `capture_names_len` must be valid to read and write for 178 | /// ` * size_of::()` bytes, and it must be properly aligned. 179 | /// - If `uuid_hex` is not `NULL`, `capture_values` must be valid to read and write for 180 | /// ` * size_of::<*const u8>()` bytes, and it must be properly aligned. 181 | /// - If `uuid_hex` is not `NULL`, `capture_values_len` must be valid to read and write for 182 | /// ` * size_of::()` bytes, and it must be properly aligned. 183 | /// 184 | /// Note: You should get the `` by calling this function and set every pointer 185 | /// except the `context` to `NULL` to get the number of captures. 
186 | /// 187 | /// [`router_execute`]: crate::ffi::router::router_execute 188 | #[no_mangle] 189 | pub unsafe extern "C" fn context_get_result( 190 | context: &Context, 191 | uuid_hex: *mut u8, 192 | matched_field: *const i8, 193 | matched_value: *mut *const u8, 194 | matched_value_len: *mut usize, 195 | capture_names: *mut *const u8, 196 | capture_names_len: *mut usize, 197 | capture_values: *mut *const u8, 198 | capture_values_len: *mut usize, 199 | ) -> isize { 200 | if context.result.is_none() { 201 | return -1; 202 | } 203 | 204 | if !uuid_hex.is_null() { 205 | let uuid_hex = from_raw_parts_mut(uuid_hex, Hyphenated::LENGTH); 206 | let res = context.result.as_ref().unwrap(); 207 | 208 | res.uuid.as_hyphenated().encode_lower(uuid_hex); 209 | 210 | if !matched_field.is_null() { 211 | let matched_field = ffi::CStr::from_ptr(matched_field as *const c_char) 212 | .to_str() 213 | .unwrap(); 214 | assert!(!matched_value.is_null()); 215 | assert!(!matched_value_len.is_null()); 216 | if let Some(Value::String(v)) = res.matches.get(matched_field) { 217 | *matched_value = v.as_bytes().as_ptr(); 218 | *matched_value_len = v.len(); 219 | } else { 220 | *matched_value_len = 0; 221 | } 222 | } 223 | 224 | if !context.result.as_ref().unwrap().captures.is_empty() { 225 | assert!(*capture_names_len >= res.captures.len()); 226 | assert!(*capture_names_len == *capture_values_len); 227 | assert!(!capture_names.is_null()); 228 | assert!(!capture_names_len.is_null()); 229 | assert!(!capture_values.is_null()); 230 | assert!(!capture_values_len.is_null()); 231 | 232 | let capture_names = from_raw_parts_mut(capture_names, *capture_names_len); 233 | let capture_names_len = from_raw_parts_mut(capture_names_len, *capture_names_len); 234 | let capture_values = from_raw_parts_mut(capture_values, *capture_values_len); 235 | let capture_values_len = from_raw_parts_mut(capture_values_len, *capture_values_len); 236 | 237 | for (i, (k, v)) in res.captures.iter().enumerate() { 238 | 
capture_names[i] = k.as_bytes().as_ptr(); 239 | capture_names_len[i] = k.len(); 240 | 241 | capture_values[i] = v.as_bytes().as_ptr(); 242 | capture_values_len[i] = v.len(); 243 | } 244 | } 245 | } 246 | 247 | context 248 | .result 249 | .as_ref() 250 | .unwrap() 251 | .captures 252 | .len() 253 | .try_into() 254 | .unwrap() 255 | } 256 | -------------------------------------------------------------------------------- /src/ffi/router.rs: -------------------------------------------------------------------------------- 1 | use crate::context::Context; 2 | use crate::ffi::ERR_BUF_MAX_LEN; 3 | use crate::router::Router; 4 | use crate::schema::Schema; 5 | use std::cmp::min; 6 | use std::ffi; 7 | use std::os::raw::c_char; 8 | use std::slice::from_raw_parts_mut; 9 | use uuid::Uuid; 10 | 11 | /// Create a new router object associated with the schema. 12 | /// 13 | /// # Arguments 14 | /// 15 | /// - `schema`: a valid pointer to the [`Schema`] object returned by [`schema_new`]. 16 | /// 17 | /// # Errors 18 | /// 19 | /// This function never fails. 20 | /// 21 | /// # Safety 22 | /// 23 | /// Violating any of the following constraints will result in undefined behavior: 24 | /// 25 | /// - `schema` must be a valid pointer returned by [`schema_new`]. 26 | /// 27 | /// [`schema_new`]: crate::ffi::schema::schema_new 28 | #[no_mangle] 29 | pub unsafe extern "C" fn router_new(schema: &Schema) -> *mut Router<&Schema> { 30 | Box::into_raw(Box::new(Router::new(schema))) 31 | } 32 | 33 | /// Deallocate the router object. 34 | /// 35 | /// # Errors 36 | /// 37 | /// This function never fails. 38 | /// 39 | /// # Safety 40 | /// 41 | /// Violating any of the following constraints will result in undefined behavior: 42 | /// 43 | /// - `router` must be a valid pointer returned by [`router_new`]. 44 | #[no_mangle] 45 | pub unsafe extern "C" fn router_free(router: *mut Router<&Schema>) { 46 | drop(Box::from_raw(router)); 47 | } 48 | 49 | /// Add a new matcher to the router. 
50 | /// 51 | /// # Arguments 52 | /// 53 | /// - `router`: a pointer to the [`Router`] object returned by [`router_new`]. 54 | /// - `priority`: the priority of the matcher, higher value means higher priority, 55 | /// and the matcher with the highest priority will be executed first. 56 | /// - `uuid`: the C-style string representing the UUID of the matcher. 57 | /// - `atc`: the C-style string representing the ATC expression. 58 | /// - `errbuf`: a buffer to store the error message. 59 | /// - `errbuf_len`: a pointer to the length of the error message buffer. 60 | /// 61 | /// # Returns 62 | /// 63 | /// Returns `true` if the matcher was added successfully, otherwise `false`, 64 | /// and the error message will be stored in the `errbuf`, 65 | /// and the length of the error message will be stored in `errbuf_len`. 66 | /// 67 | /// # Errors 68 | /// 69 | /// This function will return `false` if the matcher could not be added to the router, 70 | /// such as duplicate UUID, and invalid ATC expression. 71 | /// 72 | /// # Panics 73 | /// 74 | /// This function will panic when: 75 | /// 76 | /// - `uuid` doesn't point to a ASCII sequence representing a valid 128-bit UUID. 77 | /// - `atc` doesn't point to a valid C-style string. 78 | /// 79 | /// # Safety 80 | /// 81 | /// Violating any of the following constraints will result in undefined behavior: 82 | /// 83 | /// - `router` must be a valid pointer returned by [`router_new`]. 84 | /// - `uuid` must be a valid pointer to a C-style string, must be properly aligned, 85 | /// and must not have '\0' in the middle. 86 | /// - `atc` must be a valid pointer to a C-style string, must be properly aligned, 87 | /// and must not have '\0' in the middle. 88 | /// - `errbuf` must be valid to read and write for `errbuf_len * size_of::()` bytes, 89 | /// and it must be properly aligned. 90 | /// - `errbuf_len` must be valid to read and write for `size_of::()` bytes, 91 | /// and it must be properly aligned. 
92 | #[no_mangle] 93 | pub unsafe extern "C" fn router_add_matcher( 94 | router: &mut Router<&Schema>, 95 | priority: usize, 96 | uuid: *const i8, 97 | atc: *const i8, 98 | errbuf: *mut u8, 99 | errbuf_len: *mut usize, 100 | ) -> bool { 101 | let uuid = ffi::CStr::from_ptr(uuid as *const c_char).to_str().unwrap(); 102 | let atc = ffi::CStr::from_ptr(atc as *const c_char).to_str().unwrap(); 103 | let errbuf = from_raw_parts_mut(errbuf, ERR_BUF_MAX_LEN); 104 | 105 | let uuid = Uuid::try_parse(uuid).expect("invalid UUID format"); 106 | 107 | if let Err(e) = router.add_matcher(priority, uuid, atc) { 108 | let errlen = min(e.len(), *errbuf_len); 109 | errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]); 110 | *errbuf_len = errlen; 111 | return false; 112 | } 113 | 114 | true 115 | } 116 | 117 | /// Remove a matcher from the router. 118 | /// 119 | /// # Arguments 120 | /// - `router`: a pointer to the [`Router`] object returned by [`router_new`]. 121 | /// - `priority`: the priority of the matcher to be removed. 122 | /// - `uuid`: the C-style string representing the UUID of the matcher to be removed. 123 | /// 124 | /// # Returns 125 | /// 126 | /// Returns `true` if the matcher was removed successfully, otherwise `false`, 127 | /// such as when the matcher with the specified UUID doesn't exist or 128 | /// the priority doesn't match the UUID. 129 | /// 130 | /// # Panics 131 | /// 132 | /// This function will panic when `uuid` doesn't point to a ASCII sequence 133 | /// 134 | /// # Safety 135 | /// 136 | /// Violating any of the following constraints will result in undefined behavior: 137 | /// 138 | /// - `router` must be a valid pointer returned by [`router_new`]. 139 | /// - `uuid` must be a valid pointer to a C-style string, must be properly aligned, 140 | /// and must not have '\0' in the middle. 
141 | #[no_mangle] 142 | pub unsafe extern "C" fn router_remove_matcher( 143 | router: &mut Router<&Schema>, 144 | priority: usize, 145 | uuid: *const i8, 146 | ) -> bool { 147 | let uuid = ffi::CStr::from_ptr(uuid as *const c_char).to_str().unwrap(); 148 | let uuid = Uuid::try_parse(uuid).expect("invalid UUID format"); 149 | 150 | router.remove_matcher(priority, uuid) 151 | } 152 | 153 | /// Execute the router with the context. 154 | /// 155 | /// # Arguments 156 | /// 157 | /// - `router`: a pointer to the [`Router`] object returned by [`router_new`]. 158 | /// - `context`: a pointer to the [`Context`] object. 159 | /// 160 | /// # Returns 161 | /// 162 | /// Returns `true` if found a match, `false` means no match found. 163 | /// 164 | /// # Safety 165 | /// 166 | /// Violating any of the following constraints will result in undefined behavior: 167 | /// 168 | /// - `router` must be a valid pointer returned by [`router_new`]. 169 | /// - `context` must be a valid pointer returned by [`context_new`], 170 | /// and must be reset by [`context_reset`] before calling this function 171 | /// if you want to reuse the same context for multiple matches. 172 | /// 173 | /// [`context_new`]: crate::ffi::context::context_new 174 | /// [`context_reset`]: crate::ffi::context::context_reset 175 | #[no_mangle] 176 | pub unsafe extern "C" fn router_execute(router: &Router<&Schema>, context: &mut Context) -> bool { 177 | router.execute(context) 178 | } 179 | 180 | /// Get the de-duplicated fields that are actually used in the router. 181 | /// This is useful when you want to know what fields are actually used in the router, 182 | /// so you can generate their values on-demand. 183 | /// 184 | /// # Arguments 185 | /// 186 | /// - `router`: a pointer to the [`Router`] object returned by [`router_new`]. 187 | /// - `fields`: a pointer to an array of pointers to the field names 188 | /// (NOT C-style strings) that are actually used in the router, which will be filled in. 
189 | /// if `fields` is `NULL`, this function will only return the number of fields used 190 | /// in the router. 191 | /// - `fields_len`: a pointer to an array of the length of each field name. 192 | /// 193 | /// # Lifetimes 194 | /// 195 | /// The string pointers stored in `fields` might be invalidated if any of the following 196 | /// operations are happened: 197 | /// 198 | /// - The `router` was deallocated. 199 | /// - A new matcher was added to the `router`. 200 | /// - A matcher was removed from the `router`. 201 | /// 202 | /// # Returns 203 | /// 204 | /// Returns the number of fields that are actually used in the router. 205 | /// 206 | /// # Errors 207 | /// 208 | /// This function never fails. 209 | /// 210 | /// # Safety 211 | /// 212 | /// Violating any of the following constraints will result in undefined behavior: 213 | /// 214 | /// - `router` must be a valid pointer returned by [`router_new`]. 215 | /// - If `fields` is not `NULL`, `fields` must be valid to read and write for 216 | /// `fields_len * size_of::<*const u8>()` bytes, and it must be properly aligned. 217 | /// - If `fields` is not `NULL`, `fields_len` must be valid to read and write for 218 | /// `size_of::()` bytes, and it must be properly aligned. 219 | /// - DO NOT write the memory pointed by the elements of `fields`. 220 | /// - DO NOT access the memory pointed by the elements of `fields` 221 | /// after it becomes invalid, see the `Lifetimes` section. 
222 | #[no_mangle] 223 | pub unsafe extern "C" fn router_get_fields( 224 | router: &Router<&Schema>, 225 | fields: *mut *const u8, 226 | fields_len: *mut usize, 227 | ) -> usize { 228 | if !fields.is_null() { 229 | assert!(!fields_len.is_null()); 230 | assert!(*fields_len >= router.fields.len()); 231 | 232 | let fields = from_raw_parts_mut(fields, *fields_len); 233 | let fields_len = from_raw_parts_mut(fields_len, *fields_len); 234 | 235 | for (i, k) in router.fields.keys().enumerate() { 236 | fields[i] = k.as_bytes().as_ptr(); 237 | fields_len[i] = k.len() 238 | } 239 | } 240 | 241 | router.fields.len() 242 | } 243 | 244 | #[cfg(test)] 245 | mod tests { 246 | use super::*; 247 | 248 | #[test] 249 | fn test_long_error_message() { 250 | unsafe { 251 | let schema = Schema::default(); 252 | let mut router = Router::new(&schema); 253 | let uuid = ffi::CString::new("a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c").unwrap(); 254 | let junk = ffi::CString::new(vec![b'a'; ERR_BUF_MAX_LEN * 2]).unwrap(); 255 | let mut errbuf = vec![b'X'; ERR_BUF_MAX_LEN]; 256 | let mut errbuf_len = ERR_BUF_MAX_LEN; 257 | 258 | let result = router_add_matcher( 259 | &mut router, 260 | 1, 261 | uuid.as_ptr() as *const i8, 262 | junk.as_ptr() as *const i8, 263 | errbuf.as_mut_ptr(), 264 | &mut errbuf_len, 265 | ); 266 | assert!(!result); 267 | assert_eq!(errbuf_len, ERR_BUF_MAX_LEN); 268 | } 269 | } 270 | 271 | #[test] 272 | fn test_short_error_message() { 273 | unsafe { 274 | let schema = Schema::default(); 275 | let mut router = Router::new(&schema); 276 | let uuid = ffi::CString::new("a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c").unwrap(); 277 | let junk = ffi::CString::new("aaaa").unwrap(); 278 | let mut errbuf = vec![b'X'; ERR_BUF_MAX_LEN]; 279 | let mut errbuf_len = ERR_BUF_MAX_LEN; 280 | 281 | let result = router_add_matcher( 282 | &mut router, 283 | 1, 284 | uuid.as_ptr() as *const i8, 285 | junk.as_ptr() as *const i8, 286 | errbuf.as_mut_ptr(), 287 | &mut errbuf_len, 288 | ); 289 | assert!(!result); 
290 | assert!(errbuf_len < ERR_BUF_MAX_LEN); 291 | } 292 | } 293 | } 294 | -------------------------------------------------------------------------------- /src/parser.rs: -------------------------------------------------------------------------------- 1 | extern crate pest; 2 | 3 | use crate::ast::{ 4 | BinaryOperator, Expression, Lhs, LhsTransformations, LogicalExpression, Predicate, Value, 5 | }; 6 | use cidr::{IpCidr, Ipv4Cidr, Ipv6Cidr}; 7 | use pest::error::Error as ParseError; 8 | use pest::error::ErrorVariant; 9 | use pest::iterators::Pair; 10 | use pest::pratt_parser::Assoc as AssocNew; 11 | use pest::pratt_parser::{Op, PrattParser}; 12 | use pest::Parser; 13 | use regex::Regex; 14 | use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; 15 | 16 | type ParseResult = Result>; 17 | 18 | /// cbindgen:ignore 19 | // Bug: https://github.com/eqrion/cbindgen/issues/286 20 | trait IntoParseResult { 21 | #[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path 22 | fn into_parse_result(self, pair: &Pair) -> ParseResult; 23 | } 24 | 25 | impl IntoParseResult for Result 26 | where 27 | E: ToString, 28 | { 29 | fn into_parse_result(self, pair: &Pair) -> ParseResult { 30 | self.map_err(|e| { 31 | let span = pair.as_span(); 32 | 33 | let err_var = ErrorVariant::CustomError { 34 | message: e.to_string(), 35 | }; 36 | 37 | ParseError::new_from_span(err_var, span) 38 | }) 39 | } 40 | } 41 | 42 | #[derive(Parser)] 43 | #[grammar = "atc_grammar.pest"] 44 | struct ATCParser { 45 | pratt_parser: PrattParser, 46 | } 47 | 48 | macro_rules! 
parse_num { 49 | ($node:expr, $ty:ident, $radix:expr) => { 50 | $ty::from_str_radix($node.as_str(), $radix).into_parse_result(&$node) 51 | }; 52 | } 53 | 54 | impl ATCParser { 55 | fn new() -> Self { 56 | Self { 57 | pratt_parser: PrattParser::new() 58 | .op(Op::infix(Rule::and_op, AssocNew::Left)) 59 | .op(Op::infix(Rule::or_op, AssocNew::Left)), 60 | } 61 | } 62 | // matcher = { SOI ~ expression ~ EOI } 63 | #[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path 64 | fn parse_matcher(&mut self, source: &str) -> ParseResult { 65 | let pairs = ATCParser::parse(Rule::matcher, source)?; 66 | let expr_pair = pairs.peek().unwrap().into_inner().peek().unwrap(); 67 | let rule = expr_pair.as_rule(); 68 | match rule { 69 | Rule::expression => parse_expression(expr_pair, &self.pratt_parser), 70 | _ => unreachable!(), 71 | } 72 | } 73 | } 74 | 75 | #[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path 76 | fn parse_ident(pair: Pair) -> ParseResult { 77 | Ok(pair.as_str().into()) 78 | } 79 | 80 | #[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path 81 | fn parse_lhs(pair: Pair) -> ParseResult { 82 | let pairs = pair.into_inner(); 83 | let pair = pairs.peek().unwrap(); 84 | let rule = pair.as_rule(); 85 | Ok(match rule { 86 | Rule::transform_func => parse_transform_func(pair)?, 87 | Rule::ident => { 88 | let var = parse_ident(pair)?; 89 | Lhs { 90 | var_name: var, 91 | transformations: Vec::new(), 92 | } 93 | } 94 | _ => unreachable!(), 95 | }) 96 | } 97 | 98 | // rhs = { str_literal | ip_literal | int_literal } 99 | #[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path 100 | fn parse_rhs(pair: Pair) -> ParseResult { 101 | let pairs = pair.into_inner(); 102 | let pair = pairs.peek().unwrap(); 103 | let rule = pair.as_rule(); 104 | Ok(match rule { 105 | Rule::str_literal => Value::String(parse_str_literal(pair)?), 106 | Rule::rawstr_literal => 
Value::String(parse_rawstr_literal(pair)?), 107 | Rule::ipv4_cidr_literal => Value::IpCidr(IpCidr::V4(parse_ipv4_cidr_literal(pair)?)), 108 | Rule::ipv6_cidr_literal => Value::IpCidr(IpCidr::V6(parse_ipv6_cidr_literal(pair)?)), 109 | Rule::ipv4_literal => Value::IpAddr(IpAddr::V4(parse_ipv4_literal(pair)?)), 110 | Rule::ipv6_literal => Value::IpAddr(IpAddr::V6(parse_ipv6_literal(pair)?)), 111 | Rule::int_literal => Value::Int(parse_int_literal(pair)?), 112 | _ => unreachable!(), 113 | }) 114 | } 115 | 116 | // str_literal = ${ "\"" ~ str_inner ~ "\"" } 117 | #[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path 118 | fn parse_str_literal(pair: Pair) -> ParseResult { 119 | let char_pairs = pair.into_inner(); 120 | let mut s = String::new(); 121 | for char_pair in char_pairs { 122 | let rule = char_pair.as_rule(); 123 | match rule { 124 | Rule::str_esc => s.push(parse_str_esc(char_pair)), 125 | Rule::str_char => s.push(parse_str_char(char_pair)), 126 | _ => unreachable!(), 127 | } 128 | } 129 | Ok(s) 130 | } 131 | 132 | // rawstr_literal = ${ "r#\"" ~ rawstr_char* ~ "\"#" } 133 | // rawstr_char = { !"\"#" ~ ANY } 134 | #[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path 135 | fn parse_rawstr_literal(pair: Pair) -> ParseResult { 136 | let char_pairs = pair.into_inner(); 137 | let mut s = String::new(); 138 | for char_pair in char_pairs { 139 | let rule = char_pair.as_rule(); 140 | match rule { 141 | Rule::rawstr_char => s.push(parse_str_char(char_pair)), 142 | _ => unreachable!(), 143 | } 144 | } 145 | Ok(s) 146 | } 147 | 148 | fn parse_str_esc(pair: Pair) -> char { 149 | match pair.as_str() { 150 | r#"\""# => '"', 151 | r#"\\"# => '\\', 152 | r#"\n"# => '\n', 153 | r#"\r"# => '\r', 154 | r#"\t"# => '\t', 155 | 156 | _ => unreachable!(), 157 | } 158 | } 159 | fn parse_str_char(pair: Pair) -> char { 160 | pair.as_str().chars().next().unwrap() 161 | } 162 | 163 | #[allow(clippy::result_large_err)] // it's fine as 
parsing is not the hot path 164 | fn parse_ipv4_cidr_literal(pair: Pair) -> ParseResult { 165 | pair.as_str().parse().into_parse_result(&pair) 166 | } 167 | 168 | #[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path 169 | fn parse_ipv6_cidr_literal(pair: Pair) -> ParseResult { 170 | pair.as_str().parse().into_parse_result(&pair) 171 | } 172 | 173 | #[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path 174 | fn parse_ipv4_literal(pair: Pair) -> ParseResult { 175 | pair.as_str().parse().into_parse_result(&pair) 176 | } 177 | 178 | #[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path 179 | fn parse_ipv6_literal(pair: Pair) -> ParseResult { 180 | pair.as_str().parse().into_parse_result(&pair) 181 | } 182 | 183 | #[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path 184 | fn parse_int_literal(pair: Pair) -> ParseResult { 185 | let is_neg = pair.as_str().starts_with('-'); 186 | let pairs = pair.into_inner(); 187 | let pair = pairs.peek().unwrap(); // digits 188 | let rule = pair.as_rule(); 189 | let radix = match rule { 190 | Rule::hex_digits => 16, 191 | Rule::oct_digits => 8, 192 | Rule::dec_digits => 10, 193 | _ => unreachable!(), 194 | }; 195 | 196 | let mut num = parse_num!(pair, i64, radix)?; 197 | 198 | if is_neg { 199 | num = -num; 200 | } 201 | 202 | Ok(num) 203 | } 204 | 205 | // predicate = { lhs ~ binary_operator ~ rhs } 206 | #[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path 207 | fn parse_predicate(pair: Pair) -> ParseResult { 208 | let mut pairs = pair.into_inner(); 209 | let lhs = parse_lhs(pairs.next().unwrap())?; 210 | let op = parse_binary_operator(pairs.next().unwrap()); 211 | let rhs_pair = pairs.next().unwrap(); 212 | let rhs = parse_rhs(rhs_pair.clone())?; 213 | Ok(Predicate { 214 | lhs, 215 | rhs: if op == BinaryOperator::Regex { 216 | let Value::String(s) = rhs else { 217 | return Err(ParseError::new_from_span( 218 | 
ErrorVariant::CustomError { 219 | message: "regex operator can only be used with String operands".to_string(), 220 | }, 221 | rhs_pair.as_span(), 222 | )); 223 | }; 224 | 225 | let r = Regex::new(&s).map_err(|e| { 226 | ParseError::new_from_span( 227 | ErrorVariant::CustomError { 228 | message: e.to_string(), 229 | }, 230 | rhs_pair.as_span(), 231 | ) 232 | })?; 233 | 234 | Value::Regex(r) 235 | } else { 236 | rhs 237 | }, 238 | op, 239 | }) 240 | } 241 | 242 | // transform_func = { ident ~ "(" ~ lhs ~ ")" } 243 | #[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path 244 | fn parse_transform_func(pair: Pair) -> ParseResult { 245 | let span = pair.as_span(); 246 | let pairs = pair.into_inner(); 247 | let mut pairs = pairs.peekable(); 248 | let func_name = pairs.next().unwrap().as_str().to_string(); 249 | let mut lhs = parse_lhs(pairs.next().unwrap())?; 250 | lhs.transformations.push(match func_name.as_str() { 251 | "lower" => LhsTransformations::Lower, 252 | "any" => LhsTransformations::Any, 253 | unknown => { 254 | return Err(ParseError::new_from_span( 255 | ErrorVariant::CustomError { 256 | message: format!("unknown transformation function: {unknown}"), 257 | }, 258 | span, 259 | )); 260 | } 261 | }); 262 | 263 | Ok(lhs) 264 | } 265 | 266 | // binary_operator = { "==" | "!=" | "~" | "^=" | "=^" | ">=" | 267 | // ">" | "<=" | "<" | "in" | "not" ~ "in" | "contains" } 268 | fn parse_binary_operator(pair: Pair) -> BinaryOperator { 269 | let rule = pair.as_str(); 270 | use BinaryOperator as BinaryOp; 271 | match rule { 272 | "==" => BinaryOp::Equals, 273 | "!=" => BinaryOp::NotEquals, 274 | "~" => BinaryOp::Regex, 275 | "^=" => BinaryOp::Prefix, 276 | "=^" => BinaryOp::Postfix, 277 | ">=" => BinaryOp::GreaterOrEqual, 278 | ">" => BinaryOp::Greater, 279 | "<=" => BinaryOp::LessOrEqual, 280 | "<" => BinaryOp::Less, 281 | "in" => BinaryOp::In, 282 | "not in" => BinaryOp::NotIn, 283 | "contains" => BinaryOp::Contains, 284 | _ => unreachable!(), 285 
| } 286 | } 287 | 288 | // parenthesised_expression = { not_op? ~ "(" ~ expression ~ ")" } 289 | #[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path 290 | fn parse_parenthesised_expression( 291 | pair: Pair, 292 | pratt: &PrattParser, 293 | ) -> ParseResult { 294 | let mut pairs = pair.into_inner(); 295 | let pair = pairs.next().unwrap(); 296 | let rule = pair.as_rule(); 297 | match rule { 298 | Rule::expression => parse_expression(pair, pratt), 299 | Rule::not_op => Ok(Expression::Logical(Box::new(LogicalExpression::Not( 300 | parse_expression(pairs.next().unwrap(), pratt)?, 301 | )))), 302 | _ => unreachable!(), 303 | } 304 | } 305 | 306 | // term = { predicate | parenthesised_expression } 307 | #[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path 308 | fn parse_term(pair: Pair, pratt: &PrattParser) -> ParseResult { 309 | let pairs = pair.into_inner(); 310 | let inner_rule = pairs.peek().unwrap(); 311 | let rule = inner_rule.as_rule(); 312 | match rule { 313 | Rule::predicate => Ok(Expression::Predicate(parse_predicate(inner_rule)?)), 314 | Rule::parenthesised_expression => parse_parenthesised_expression(inner_rule, pratt), 315 | _ => unreachable!(), 316 | } 317 | } 318 | 319 | // expression = { term ~ ( logical_operator ~ term )* } 320 | #[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path 321 | fn parse_expression(pair: Pair, pratt: &PrattParser) -> ParseResult { 322 | let pairs = pair.into_inner(); 323 | pratt 324 | .map_primary(|operand| match operand.as_rule() { 325 | Rule::term => parse_term(operand, pratt), 326 | _ => unreachable!(), 327 | }) 328 | .map_infix(|lhs, op, rhs| { 329 | Ok(match op.as_rule() { 330 | Rule::and_op => Expression::Logical(Box::new(LogicalExpression::And(lhs?, rhs?))), 331 | Rule::or_op => Expression::Logical(Box::new(LogicalExpression::Or(lhs?, rhs?))), 332 | _ => unreachable!(), 333 | }) 334 | }) 335 | .parse(pairs) 336 | } 337 | 338 | 
#[allow(clippy::result_large_err)] // it's fine as parsing is not the hot path 339 | pub fn parse(source: &str) -> ParseResult { 340 | ATCParser::new().parse_matcher(source) 341 | } 342 | 343 | #[cfg(test)] 344 | mod tests { 345 | use super::*; 346 | 347 | #[test] 348 | fn test_bad_syntax() { 349 | assert_eq!( 350 | parse("! a == 1").unwrap_err().to_string(), 351 | " --> 1:1\n |\n1 | ! a == 1\n | ^---\n |\n = expected term" 352 | ); 353 | assert_eq!( 354 | parse("a == 1 || ! b == 2").unwrap_err().to_string(), 355 | " --> 1:11\n |\n1 | a == 1 || ! b == 2\n | ^---\n |\n = expected term" 356 | ); 357 | assert_eq!( 358 | parse("(a == 1 || b == 2) && ! c == 3") 359 | .unwrap_err() 360 | .to_string(), 361 | " --> 1:23\n |\n1 | (a == 1 || b == 2) && ! c == 3\n | ^---\n |\n = expected term" 362 | ); 363 | } 364 | } 365 | -------------------------------------------------------------------------------- /src/ast.rs: -------------------------------------------------------------------------------- 1 | use crate::schema::Schema; 2 | use cidr::IpCidr; 3 | use regex::Regex; 4 | use std::net::IpAddr; 5 | 6 | #[cfg(feature = "serde")] 7 | use serde::{Deserialize, Serialize}; 8 | 9 | #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] 10 | #[derive(Clone, Debug, PartialEq, Eq)] 11 | pub enum Expression { 12 | Logical(Box), 13 | Predicate(Predicate), 14 | } 15 | 16 | #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] 17 | #[derive(Clone, Debug, PartialEq, Eq)] 18 | pub enum LogicalExpression { 19 | And(Expression, Expression), 20 | Or(Expression, Expression), 21 | Not(Expression), 22 | } 23 | 24 | #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] 25 | #[derive(Clone, Copy, Debug, PartialEq, Eq)] 26 | pub enum LhsTransformations { 27 | Lower, 28 | Any, 29 | } 30 | 31 | #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] 32 | #[derive(Clone, Copy, Debug, PartialEq, Eq)] 33 | pub enum BinaryOperator { 34 | Equals, // == 35 | NotEquals, // != 
36 | Regex, // ~ 37 | Prefix, // ^= 38 | Postfix, // =^ 39 | Greater, // > 40 | GreaterOrEqual, // >= 41 | Less, // < 42 | LessOrEqual, // <= 43 | In, // in 44 | NotIn, // not in 45 | Contains, // contains 46 | } 47 | 48 | #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] 49 | #[derive(Debug, Clone)] 50 | pub enum Value { 51 | String(String), 52 | IpCidr(IpCidr), 53 | IpAddr(IpAddr), 54 | Int(i64), 55 | #[cfg_attr(feature = "serde", serde(with = "serde_regex"))] 56 | Regex(Regex), 57 | } 58 | 59 | impl PartialEq for Value { 60 | fn eq(&self, other: &Self) -> bool { 61 | match (self, other) { 62 | (Self::Regex(_), _) | (_, Self::Regex(_)) => { 63 | panic!("Regexes can not be compared using eq") 64 | } 65 | (Self::String(s1), Self::String(s2)) => s1 == s2, 66 | (Self::IpCidr(i1), Self::IpCidr(i2)) => i1 == i2, 67 | (Self::IpAddr(i1), Self::IpAddr(i2)) => i1 == i2, 68 | (Self::Int(i1), Self::Int(i2)) => i1 == i2, 69 | _ => false, 70 | } 71 | } 72 | } 73 | 74 | impl Eq for Value {} 75 | 76 | impl Value { 77 | pub fn my_type(&self) -> Type { 78 | match self { 79 | Value::String(_) => Type::String, 80 | Value::IpCidr(_) => Type::IpCidr, 81 | Value::IpAddr(_) => Type::IpAddr, 82 | Value::Int(_) => Type::Int, 83 | Value::Regex(_) => Type::Regex, 84 | } 85 | } 86 | } 87 | 88 | impl Value { 89 | pub fn as_str(&self) -> Option<&str> { 90 | let Value::String(s) = self else { 91 | return None; 92 | }; 93 | Some(s.as_str()) 94 | } 95 | 96 | pub fn as_regex(&self) -> Option<&Regex> { 97 | let Value::Regex(r) = self else { 98 | return None; 99 | }; 100 | Some(r) 101 | } 102 | 103 | pub fn as_int(&self) -> Option { 104 | let Value::Int(i) = self else { 105 | return None; 106 | }; 107 | Some(*i) 108 | } 109 | 110 | pub fn as_ipaddr(&self) -> Option<&IpAddr> { 111 | let Value::IpAddr(a) = self else { 112 | return None; 113 | }; 114 | Some(a) 115 | } 116 | 117 | pub fn as_ipcidr(&self) -> Option<&IpCidr> { 118 | let Value::IpCidr(c) = self else { 119 | return None; 120 | 
}; 121 | Some(c) 122 | } 123 | } 124 | 125 | impl From for Value { 126 | fn from(v: String) -> Self { 127 | Value::String(v) 128 | } 129 | } 130 | 131 | #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] 132 | #[derive(Debug, Eq, PartialEq)] 133 | #[repr(C)] 134 | pub enum Type { 135 | String, 136 | IpCidr, 137 | IpAddr, 138 | Int, 139 | Regex, 140 | } 141 | 142 | #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] 143 | #[derive(Clone, Debug, PartialEq, Eq)] 144 | pub struct Lhs { 145 | pub var_name: String, 146 | pub transformations: Vec, 147 | } 148 | 149 | impl Lhs { 150 | pub fn my_type<'a>(&self, schema: &'a Schema) -> Option<&'a Type> { 151 | schema.type_of(&self.var_name) 152 | } 153 | 154 | pub fn get_transformations(&self) -> (bool, bool) { 155 | let mut lower = false; 156 | let mut any = false; 157 | 158 | self.transformations.iter().for_each(|i| match i { 159 | LhsTransformations::Any => any = true, 160 | LhsTransformations::Lower => lower = true, 161 | }); 162 | 163 | (lower, any) 164 | } 165 | } 166 | 167 | #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] 168 | #[derive(Clone, Debug, PartialEq, Eq)] 169 | pub struct Predicate { 170 | pub lhs: Lhs, 171 | pub rhs: Value, 172 | pub op: BinaryOperator, 173 | } 174 | 175 | #[cfg(test)] 176 | mod tests { 177 | use super::*; 178 | use crate::parser::parse; 179 | use std::fmt; 180 | 181 | impl fmt::Display for Expression { 182 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 183 | write!( 184 | f, 185 | "{}", 186 | match self { 187 | Expression::Logical(logical) => logical.to_string(), 188 | Expression::Predicate(predicate) => predicate.to_string(), 189 | } 190 | ) 191 | } 192 | } 193 | 194 | impl fmt::Display for LogicalExpression { 195 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 196 | write!( 197 | f, 198 | "{}", 199 | match self { 200 | LogicalExpression::And(left, right) => { 201 | format!("({} && {})", left, right) 202 | } 203 | 
LogicalExpression::Or(left, right) => { 204 | format!("({} || {})", left, right) 205 | } 206 | LogicalExpression::Not(e) => { 207 | format!("!({})", e) 208 | } 209 | } 210 | ) 211 | } 212 | } 213 | 214 | impl fmt::Display for LhsTransformations { 215 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 216 | write!( 217 | f, 218 | "{}", 219 | match self { 220 | LhsTransformations::Lower => "lower".to_string(), 221 | LhsTransformations::Any => "any".to_string(), 222 | } 223 | ) 224 | } 225 | } 226 | 227 | impl fmt::Display for Value { 228 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 229 | match self { 230 | Value::String(s) => write!(f, "\"{}\"", s), 231 | Value::IpCidr(cidr) => write!(f, "{}", cidr), 232 | Value::IpAddr(addr) => write!(f, "{}", addr), 233 | Value::Int(i) => write!(f, "{}", i), 234 | Value::Regex(re) => write!(f, "\"{}\"", re), 235 | } 236 | } 237 | } 238 | 239 | impl fmt::Display for Lhs { 240 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 241 | let mut s = self.var_name.to_string(); 242 | for transformation in &self.transformations { 243 | s = format!("{}({})", transformation, s); 244 | } 245 | write!(f, "{}", s) 246 | } 247 | } 248 | 249 | impl fmt::Display for BinaryOperator { 250 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 251 | use BinaryOperator::*; 252 | 253 | write!( 254 | f, 255 | "{}", 256 | match self { 257 | Equals => "==", 258 | NotEquals => "!=", 259 | Regex => "~", 260 | Prefix => "^=", 261 | Postfix => "=^", 262 | Greater => ">", 263 | GreaterOrEqual => ">=", 264 | Less => "<", 265 | LessOrEqual => "<=", 266 | In => "in", 267 | NotIn => "not in", 268 | Contains => "contains", 269 | } 270 | ) 271 | } 272 | } 273 | 274 | impl fmt::Display for Predicate { 275 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 276 | write!(f, "({} {} {})", self.lhs, self.op, self.rhs) 277 | } 278 | } 279 | 280 | #[test] 281 | fn expr_op_and_prec() { 282 | let tests = vec![ 283 | ("a > 0", "(a > 0)"), 284 | 
("a in \"abc\"", "(a in \"abc\")"), 285 | ("a == 1 && b != 2", "((a == 1) && (b != 2))"), 286 | ( 287 | "a ^= \"1\" && b =^ \"2\" || c >= 3", 288 | "((a ^= \"1\") && ((b =^ \"2\") || (c >= 3)))", 289 | ), 290 | ( 291 | "a == 1 && b != 2 || c >= 3", 292 | "((a == 1) && ((b != 2) || (c >= 3)))", 293 | ), 294 | ( 295 | "a > 1 || b < 2 && c <= 3 || d not in \"foo\"", 296 | "(((a > 1) || (b < 2)) && ((c <= 3) || (d not in \"foo\")))", 297 | ), 298 | ( 299 | "a > 1 || ((b < 2) && (c <= 3)) || d not in \"foo\"", 300 | "(((a > 1) || ((b < 2) && (c <= 3))) || (d not in \"foo\"))", 301 | ), 302 | ("!(a == 1)", "!((a == 1))"), 303 | ( 304 | "!(a == 1) && b == 2 && !(c == 3) && d >= 4", 305 | "(((!((a == 1)) && (b == 2)) && !((c == 3))) && (d >= 4))", 306 | ), 307 | ( 308 | "!(a == 1 || b == 2 && c == 3) && d == 4", 309 | "(!((((a == 1) || (b == 2)) && (c == 3))) && (d == 4))", 310 | ), 311 | ]; 312 | for (input, expected) in tests { 313 | let result = parse(input).unwrap(); 314 | assert_eq!(result.to_string(), expected); 315 | } 316 | } 317 | 318 | #[test] 319 | fn expr_var_name_and_ip() { 320 | let tests = vec![ 321 | // ipv4_literal 322 | ("kong.foo in 1.1.1.1", "(kong.foo in 1.1.1.1)"), 323 | // ipv4_cidr_literal 324 | ( 325 | "kong.foo.foo2 in 10.0.0.0/24", 326 | "(kong.foo.foo2 in 10.0.0.0/24)", 327 | ), 328 | // ipv6_literal 329 | ( 330 | "kong.foo.foo3 in 2001:db8::/32", 331 | "(kong.foo.foo3 in 2001:db8::/32)", 332 | ), 333 | // ipv6_cidr_literal 334 | ( 335 | "kong.foo.foo4 in 2001:db8::/32", 336 | "(kong.foo.foo4 in 2001:db8::/32)", 337 | ), 338 | ]; 339 | for (input, expected) in tests { 340 | let result = parse(input).unwrap(); 341 | assert_eq!(result.to_string(), expected); 342 | } 343 | } 344 | 345 | #[test] 346 | fn expr_regex() { 347 | let tests = vec![ 348 | // regex_literal 349 | ( 350 | "kong.foo.foo5 ~ \"^foo.*$\"", 351 | "(kong.foo.foo5 ~ \"^foo.*$\")", 352 | ), 353 | // regex_literal 354 | ( 355 | "kong.foo.foo6 ~ \"^foo.*$\"", 356 | "(kong.foo.foo6 ~ 
\"^foo.*$\")", 357 | ), 358 | ]; 359 | for (input, expected) in tests { 360 | let result = parse(input).unwrap(); 361 | assert_eq!(result.to_string(), expected); 362 | } 363 | } 364 | 365 | #[test] 366 | fn expr_digits() { 367 | let tests = vec![ 368 | // dec literal 369 | ("kong.foo.foo7 == 123", "(kong.foo.foo7 == 123)"), 370 | // hex literal 371 | ("kong.foo.foo8 == 0x123", "(kong.foo.foo8 == 291)"), 372 | // oct literal 373 | ("kong.foo.foo9 == 0123", "(kong.foo.foo9 == 83)"), 374 | // dec negative literal 375 | ("kong.foo.foo10 == -123", "(kong.foo.foo10 == -123)"), 376 | // hex negative literal 377 | ("kong.foo.foo11 == -0x123", "(kong.foo.foo11 == -291)"), 378 | // oct negative literal 379 | ("kong.foo.foo12 == -0123", "(kong.foo.foo12 == -83)"), 380 | ]; 381 | for (input, expected) in tests { 382 | let result = parse(input).unwrap(); 383 | assert_eq!(result.to_string(), expected); 384 | } 385 | } 386 | 387 | #[test] 388 | fn expr_transformations() { 389 | let tests = vec![ 390 | // lower 391 | ( 392 | "lower(kong.foo.foo13) == \"foo\"", 393 | "(lower(kong.foo.foo13) == \"foo\")", 394 | ), 395 | // any 396 | ( 397 | "any(kong.foo.foo14) == \"foo\"", 398 | "(any(kong.foo.foo14) == \"foo\")", 399 | ), 400 | ]; 401 | for (input, expected) in tests { 402 | let result = parse(input).unwrap(); 403 | assert_eq!(result.to_string(), expected); 404 | } 405 | } 406 | 407 | #[test] 408 | fn expr_transformations_nested() { 409 | let tests = vec![ 410 | // lower + lower 411 | ( 412 | "lower(lower(kong.foo.foo15)) == \"foo\"", 413 | "(lower(lower(kong.foo.foo15)) == \"foo\")", 414 | ), 415 | // lower + any 416 | ( 417 | "lower(any(kong.foo.foo16)) == \"foo\"", 418 | "(lower(any(kong.foo.foo16)) == \"foo\")", 419 | ), 420 | // any + lower 421 | ( 422 | "any(lower(kong.foo.foo17)) == \"foo\"", 423 | "(any(lower(kong.foo.foo17)) == \"foo\")", 424 | ), 425 | // any + any 426 | ( 427 | "any(any(kong.foo.foo18)) == \"foo\"", 428 | "(any(any(kong.foo.foo18)) == \"foo\")", 429 | 
), 430 | ]; 431 | for (input, expected) in tests { 432 | let result = parse(input).unwrap(); 433 | assert_eq!(result.to_string(), expected); 434 | } 435 | } 436 | 437 | #[test] 438 | fn str_unicode_test() { 439 | let tests = vec![ 440 | // cjk chars 441 | ("t_msg in \"你好\"", "(t_msg in \"你好\")"), 442 | // 0xXXX unicode 443 | ("t_msg in \"\u{4f60}\u{597d}\"", "(t_msg in \"你好\")"), 444 | ]; 445 | for (input, expected) in tests { 446 | let result = parse(input).unwrap(); 447 | assert_eq!(result.to_string(), expected); 448 | } 449 | } 450 | 451 | #[test] 452 | fn rawstr_test() { 453 | let tests = vec![ 454 | // invalid escape sequence 455 | (r##"a == r#"/path/to/\d+"#"##, r#"(a == "/path/to/\d+")"#), 456 | // valid escape sequence 457 | (r##"a == r#"/path/to/\n+"#"##, r#"(a == "/path/to/\n+")"#), 458 | ]; 459 | for (input, expected) in tests { 460 | let result = parse(input).unwrap(); 461 | assert_eq!(result.to_string(), expected); 462 | } 463 | } 464 | } 465 | -------------------------------------------------------------------------------- /src/interpreter.rs: -------------------------------------------------------------------------------- 1 | use crate::ast::{BinaryOperator, Expression, LogicalExpression, Predicate, Value}; 2 | use crate::context::{Context, Match}; 3 | 4 | pub trait Execute { 5 | fn execute(&self, ctx: &Context, m: &mut Match) -> bool; 6 | } 7 | 8 | impl Execute for Expression { 9 | fn execute(&self, ctx: &Context, m: &mut Match) -> bool { 10 | match self { 11 | Expression::Logical(l) => match l.as_ref() { 12 | LogicalExpression::And(l, r) => l.execute(ctx, m) && r.execute(ctx, m), 13 | LogicalExpression::Or(l, r) => l.execute(ctx, m) || r.execute(ctx, m), 14 | LogicalExpression::Not(r) => !r.execute(ctx, m), 15 | }, 16 | Expression::Predicate(p) => p.execute(ctx, m), 17 | } 18 | } 19 | } 20 | 21 | impl Execute for Predicate { 22 | fn execute(&self, ctx: &Context, m: &mut Match) -> bool { 23 | let lhs_values = match ctx.value_of(&self.lhs.var_name) { 
24 | None => return false, 25 | Some(v) => v, 26 | }; 27 | 28 | let (lower, any) = self.lhs.get_transformations(); 29 | 30 | // can only be "all" or "any" mode. 31 | // - all: all values must match (default) 32 | // - any: ok if any any matched 33 | for mut lhs_value in lhs_values.iter() { 34 | let lhs_value_transformed; 35 | 36 | if lower { 37 | // SAFETY: this only panic if and only if 38 | // the semantic checking didn't catch the mismatched types, 39 | // which is a bug. 40 | let s = lhs_value.as_str().unwrap(); 41 | 42 | lhs_value_transformed = Value::String(s.to_lowercase()); 43 | lhs_value = &lhs_value_transformed; 44 | } 45 | 46 | let mut matched = false; 47 | match self.op { 48 | BinaryOperator::Equals => { 49 | if lhs_value == &self.rhs { 50 | m.matches 51 | .insert(self.lhs.var_name.clone(), self.rhs.clone()); 52 | 53 | if any { 54 | return true; 55 | } 56 | 57 | matched = true; 58 | } 59 | } 60 | BinaryOperator::NotEquals => { 61 | if lhs_value != &self.rhs { 62 | if any { 63 | return true; 64 | } 65 | 66 | matched = true; 67 | } 68 | } 69 | BinaryOperator::Regex => { 70 | // SAFETY: this only panic if and only if 71 | // the semantic checking didn't catch the mismatched types, 72 | // which is a bug. 
73 | let lhs = lhs_value.as_str().unwrap(); 74 | let rhs = self.rhs.as_regex().unwrap(); 75 | 76 | if rhs.is_match(lhs) { 77 | let reg_cap = rhs.captures(lhs).unwrap(); 78 | 79 | m.matches.insert( 80 | self.lhs.var_name.clone(), 81 | Value::String(reg_cap.get(0).unwrap().as_str().to_string()), 82 | ); 83 | 84 | for (i, c) in reg_cap.iter().enumerate() { 85 | if let Some(c) = c { 86 | m.captures.insert(i.to_string(), c.as_str().to_string()); 87 | } 88 | } 89 | 90 | // named captures 91 | for n in rhs.capture_names().flatten() { 92 | if let Some(value) = reg_cap.name(n) { 93 | m.captures.insert(n.to_string(), value.as_str().to_string()); 94 | } 95 | } 96 | 97 | if any { 98 | return true; 99 | } 100 | 101 | matched = true; 102 | } 103 | } 104 | BinaryOperator::Prefix => { 105 | // SAFETY: this only panic if and only if 106 | // the semantic checking didn't catch the mismatched types, 107 | // which is a bug. 108 | let lhs = lhs_value.as_str().unwrap(); 109 | let rhs = self.rhs.as_str().unwrap(); 110 | 111 | if lhs.starts_with(rhs) { 112 | m.matches 113 | .insert(self.lhs.var_name.clone(), self.rhs.clone()); 114 | if any { 115 | return true; 116 | } 117 | 118 | matched = true; 119 | } 120 | } 121 | BinaryOperator::Postfix => { 122 | // SAFETY: this only panic if and only if 123 | // the semantic checking didn't catch the mismatched types, 124 | // which is a bug. 125 | let lhs = lhs_value.as_str().unwrap(); 126 | let rhs = self.rhs.as_str().unwrap(); 127 | 128 | if lhs.ends_with(rhs) { 129 | m.matches 130 | .insert(self.lhs.var_name.clone(), self.rhs.clone()); 131 | if any { 132 | return true; 133 | } 134 | 135 | matched = true; 136 | } 137 | } 138 | BinaryOperator::Greater => { 139 | // SAFETY: this only panic if and only if 140 | // the semantic checking didn't catch the mismatched types, 141 | // which is a bug. 
142 | let lhs = lhs_value.as_int().unwrap(); 143 | let rhs = self.rhs.as_int().unwrap(); 144 | 145 | if lhs > rhs { 146 | if any { 147 | return true; 148 | } 149 | 150 | matched = true; 151 | } 152 | } 153 | BinaryOperator::GreaterOrEqual => { 154 | // SAFETY: this only panic if and only if 155 | // the semantic checking didn't catch the mismatched types, 156 | // which is a bug. 157 | let lhs = lhs_value.as_int().unwrap(); 158 | let rhs = self.rhs.as_int().unwrap(); 159 | 160 | if lhs >= rhs { 161 | if any { 162 | return true; 163 | } 164 | 165 | matched = true; 166 | } 167 | } 168 | BinaryOperator::Less => { 169 | // SAFETY: this only panic if and only if 170 | // the semantic checking didn't catch the mismatched types, 171 | // which is a bug. 172 | let lhs = lhs_value.as_int().unwrap(); 173 | let rhs = self.rhs.as_int().unwrap(); 174 | 175 | if lhs < rhs { 176 | if any { 177 | return true; 178 | } 179 | 180 | matched = true; 181 | } 182 | } 183 | BinaryOperator::LessOrEqual => { 184 | // SAFETY: this only panic if and only if 185 | // the semantic checking didn't catch the mismatched types, 186 | // which is a bug. 187 | let lhs = lhs_value.as_int().unwrap(); 188 | let rhs = self.rhs.as_int().unwrap(); 189 | 190 | if lhs <= rhs { 191 | if any { 192 | return true; 193 | } 194 | 195 | matched = true; 196 | } 197 | } 198 | BinaryOperator::In => { 199 | // SAFETY: this only panic if and only if 200 | // the semantic checking didn't catch the mismatched types, 201 | // which is a bug. 202 | let lhs = lhs_value.as_ipaddr().unwrap(); 203 | let rhs = self.rhs.as_ipcidr().unwrap(); 204 | 205 | if rhs.contains(lhs) { 206 | matched = true; 207 | if any { 208 | return true; 209 | } 210 | } 211 | } 212 | BinaryOperator::NotIn => { 213 | // SAFETY: this only panic if and only if 214 | // the semantic checking didn't catch the mismatched types, 215 | // which is a bug. 
/// Exercises `Predicate::execute` in both "all" mode (no LHS transformation)
/// and "any" mode (`LhsTransformations::Any`) for the string operators
/// `Prefix`, `Postfix` and `Contains`, first against a context that holds no
/// value for the field and then against a context holding four values.
#[test]
fn test_predicate() {
    use crate::ast;
    use crate::schema;

    let mut mat = Match::new();
    let mut schema = schema::Schema::default();
    schema.add_field("my_key", ast::Type::String);
    let mut ctx = Context::new(&schema);

    // Builds a predicate comparing `my_key` against the string `rhs`.
    // `any == true` attaches the `Any` transformation ("any" mode); otherwise
    // the transformation list is empty, which is the default "all" mode.
    let string_predicate = |op: BinaryOperator, rhs: &str, any: bool| Predicate {
        lhs: ast::Lhs {
            var_name: "my_key".to_string(),
            transformations: if any {
                vec![ast::LhsTransformations::Any]
            } else {
                vec![]
            },
        },
        rhs: Value::String(rhs.to_string()),
        op,
    };

    // --- no value has been added for `my_key` yet ---

    // all mode: starts_with "foo" -- should be false
    let p = string_predicate(BinaryOperator::Prefix, "foo", false);
    assert!(!p.execute(&mut ctx, &mut mat));

    // any mode: starts_with "foo" -- should be false
    // (previously this case repeated the "all" predicate above, so "any" mode
    // was never exercised without values)
    let p = string_predicate(BinaryOperator::Prefix, "foo", true);
    assert!(!p.execute(&mut ctx, &mut mat));

    // --- populate the context with several values for `my_key` ---
    let lhs_values = vec![
        Value::String("foofoo".to_string()),
        Value::String("foobar".to_string()),
        Value::String("foocar".to_string()),
        Value::String("fooban".to_string()),
    ];

    for v in lhs_values {
        ctx.add_value("my_key", v);
    }

    // all mode: every value starts_with "foo" -- should be true
    let p = string_predicate(BinaryOperator::Prefix, "foo", false);
    assert!(p.execute(&mut ctx, &mut mat));

    // all mode: every value ends_with "foo" -- should be false
    let p = string_predicate(BinaryOperator::Postfix, "foo", false);
    assert!(!p.execute(&mut ctx, &mut mat));

    // any mode: some value ends_with "foo" -- should be true
    let p = string_predicate(BinaryOperator::Postfix, "foo", true);
    assert!(p.execute(&mut ctx, &mut mat));

    // any mode: some value starts_with "foo" -- should be true
    let p = string_predicate(BinaryOperator::Prefix, "foo", true);
    assert!(p.execute(&mut ctx, &mut mat));

    // any mode: some value ends_with "nar" -- should be false
    let p = string_predicate(BinaryOperator::Postfix, "nar", true);
    assert!(!p.execute(&mut ctx, &mut mat));

    // any mode: some value ends_with the empty string -- should be true
    let p = string_predicate(BinaryOperator::Postfix, "", true);
    assert!(p.execute(&mut ctx, &mut mat));

    // any mode: some value starts_with the empty string -- should be true
    let p = string_predicate(BinaryOperator::Prefix, "", true);
    assert!(p.execute(&mut ctx, &mut mat));

    // any mode: some value contains `ob` -- should be true
    let p = string_predicate(BinaryOperator::Contains, "ob", true);
    assert!(p.execute(&mut ctx, &mut mat));

    // any mode: some value contains `ok` -- should be false
    let p = string_predicate(BinaryOperator::Contains, "ok", true);
    assert!(!p.execute(&mut ctx, &mut mat));
}
| use std::cmp::min; 6 | use std::ffi; 7 | use std::os::raw::c_char; 8 | use std::slice::from_raw_parts_mut; 9 | 10 | use std::iter::Iterator; 11 | 12 | struct PredicateIterator<'a> { 13 | stack: Vec<&'a Expression>, 14 | } 15 | 16 | impl<'a> PredicateIterator<'a> { 17 | fn new(expr: &'a Expression) -> Self { 18 | Self { stack: vec![expr] } 19 | } 20 | } 21 | 22 | impl<'a> Iterator for PredicateIterator<'a> { 23 | type Item = &'a Predicate; 24 | 25 | fn next(&mut self) -> Option { 26 | while let Some(expr) = self.stack.pop() { 27 | match expr { 28 | Expression::Logical(l) => match l.as_ref() { 29 | LogicalExpression::And(l, r) | LogicalExpression::Or(l, r) => { 30 | self.stack.push(l); 31 | self.stack.push(r); 32 | } 33 | LogicalExpression::Not(r) => { 34 | self.stack.push(r); 35 | } 36 | }, 37 | Expression::Predicate(p) => return Some(p), 38 | } 39 | } 40 | None 41 | } 42 | } 43 | 44 | impl Expression { 45 | fn iter_predicates(&self) -> PredicateIterator<'_> { 46 | PredicateIterator::new(self) 47 | } 48 | } 49 | 50 | bitflags! 
{ 51 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] 52 | #[repr(C)] 53 | pub struct BinaryOperatorFlags: u64 /* We can only have no more than 64 BinaryOperators */ { 54 | const EQUALS = 1 << 0; 55 | const NOT_EQUALS = 1 << 1; 56 | const REGEX = 1 << 2; 57 | const PREFIX = 1 << 3; 58 | const POSTFIX = 1 << 4; 59 | const GREATER = 1 << 5; 60 | const GREATER_OR_EQUAL = 1 << 6; 61 | const LESS = 1 << 7; 62 | const LESS_OR_EQUAL = 1 << 8; 63 | const IN = 1 << 9; 64 | const NOT_IN = 1 << 10; 65 | const CONTAINS = 1 << 11; 66 | 67 | const UNUSED = !(Self::EQUALS.bits() 68 | | Self::NOT_EQUALS.bits() 69 | | Self::REGEX.bits() 70 | | Self::PREFIX.bits() 71 | | Self::POSTFIX.bits() 72 | | Self::GREATER.bits() 73 | | Self::GREATER_OR_EQUAL.bits() 74 | | Self::LESS.bits() 75 | | Self::LESS_OR_EQUAL.bits() 76 | | Self::IN.bits() 77 | | Self::NOT_IN.bits() 78 | | Self::CONTAINS.bits()); 79 | } 80 | } 81 | 82 | impl From<&BinaryOperator> for BinaryOperatorFlags { 83 | fn from(op: &BinaryOperator) -> Self { 84 | match op { 85 | BinaryOperator::Equals => Self::EQUALS, 86 | BinaryOperator::NotEquals => Self::NOT_EQUALS, 87 | BinaryOperator::Regex => Self::REGEX, 88 | BinaryOperator::Prefix => Self::PREFIX, 89 | BinaryOperator::Postfix => Self::POSTFIX, 90 | BinaryOperator::Greater => Self::GREATER, 91 | BinaryOperator::GreaterOrEqual => Self::GREATER_OR_EQUAL, 92 | BinaryOperator::Less => Self::LESS, 93 | BinaryOperator::LessOrEqual => Self::LESS_OR_EQUAL, 94 | BinaryOperator::In => Self::IN, 95 | BinaryOperator::NotIn => Self::NOT_IN, 96 | BinaryOperator::Contains => Self::CONTAINS, 97 | } 98 | } 99 | } 100 | 101 | pub const ATC_ROUTER_EXPRESSION_VALIDATE_OK: i64 = 0; 102 | pub const ATC_ROUTER_EXPRESSION_VALIDATE_FAILED: i64 = 1; 103 | pub const ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL: i64 = 2; 104 | 105 | /// Validates an ATC expression against a schema and get its elements. 
106 | /// 107 | /// # Arguments 108 | /// 109 | /// - `atc`: a C-style string representing the ATC expression. 110 | /// - `schema`: a valid pointer to a [`Schema`] object, as returned by [`schema_new`]. 111 | /// - `fields_buf`: a buffer for storing the fields used in the expression. 112 | /// - `fields_buf_len`: a pointer to the length of `fields_buf`. 113 | /// - `fields_total`: a pointer for storing the total number of unique fields used in the expression. 114 | /// - `operators`: a pointer for storing the bitflags representing used operators. 115 | /// - `errbuf`: a buffer to store any error messages. 116 | /// - `errbuf_len`: a pointer to the length of the error message buffer. 117 | /// 118 | /// # Returns 119 | /// 120 | /// An integer indicating the validation result: 121 | /// - `ATC_ROUTER_EXPRESSION_VALIDATE_OK` (0): Validation succeeded. 122 | /// - `ATC_ROUTER_EXPRESSION_VALIDATE_FAILED` (1): Validation failed; `errbuf` and `errbuf_len` will be updated with an error message. 123 | /// - `ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL` (2): The provided `fields_buf` is too small. 124 | /// 125 | /// If `fields_buf_len` indicates that `fields_buf` is sufficient, this function writes the used fields to `fields_buf`, each field terminated by `\0`. 126 | /// It stores the total number of fields in `fields_total`. 127 | /// 128 | /// If `fields_buf_len` indicates that `fields_buf` is insufficient, it returns `ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL`. 129 | /// 130 | /// It writes the used operators as bitflags to `operators`. 131 | /// Bitflags are defined by `BinaryOperatorFlags` and must exclude bits from `BinaryOperatorFlags::UNUSED`. 132 | /// 133 | /// 134 | /// # Safety 135 | /// 136 | /// Violating any of the following constraints results in undefined behavior: 137 | /// 138 | /// - `atc` must be a valid pointer to a C-style string, properly aligned, and must not contain an internal `\0`. 
139 | /// - `schema` must be a valid pointer returned by [`schema_new`]. 140 | /// - `fields_buf`, must be valid for writing `fields_buf_len * size_of::()` bytes and properly aligned. 141 | /// - `fields_buf_len` must be a valid pointer to write `size_of::()` bytes and properly aligned. 142 | /// - `fields_total` must be a valid pointer to write `size_of::()` bytes and properly aligned. 143 | /// - `operators` must be a valid pointer to write `size_of::()` bytes and properly aligned. 144 | /// - `errbuf` must be valid for reading and writing `errbuf_len * size_of::()` bytes and properly aligned. 145 | /// - `errbuf_len` must be a valid pointer for reading and writing `size_of::()` bytes and properly aligned. 146 | /// 147 | /// [`schema_new`]: crate::ffi::schema::schema_new 148 | #[no_mangle] 149 | pub unsafe extern "C" fn expression_validate( 150 | atc: *const u8, 151 | schema: &Schema, 152 | fields_buf: *mut u8, 153 | fields_buf_len: *mut usize, 154 | fields_total: *mut usize, 155 | operators: *mut u64, 156 | errbuf: *mut u8, 157 | errbuf_len: *mut usize, 158 | ) -> i64 { 159 | use std::collections::HashSet; 160 | 161 | use crate::parser::parse; 162 | use crate::semantics::Validate; 163 | 164 | let atc = ffi::CStr::from_ptr(atc as *const c_char).to_str().unwrap(); 165 | let errbuf = from_raw_parts_mut(errbuf, ERR_BUF_MAX_LEN); 166 | 167 | // Parse the expression 168 | let result = parse(atc).map_err(|e| e.to_string()); 169 | if let Err(e) = result { 170 | let errlen = min(e.len(), *errbuf_len); 171 | errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]); 172 | *errbuf_len = errlen; 173 | return ATC_ROUTER_EXPRESSION_VALIDATE_FAILED; 174 | } 175 | // Unwrap is safe since we've already checked for error 176 | let ast = result.unwrap(); 177 | 178 | // Validate expression with schema 179 | if let Err(e) = ast.validate(schema).map_err(|e| e.to_string()) { 180 | let errlen = min(e.len(), *errbuf_len); 181 | errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]); 
182 | *errbuf_len = errlen; 183 | return ATC_ROUTER_EXPRESSION_VALIDATE_FAILED; 184 | } 185 | 186 | // Iterate over predicates to get fields and operators 187 | let mut ops = BinaryOperatorFlags::empty(); 188 | let mut existed_fields = HashSet::new(); 189 | let mut total_fields_length = 0; 190 | let mut fields_buf_ptr = fields_buf; 191 | *fields_total = 0; 192 | 193 | for pred in ast.iter_predicates() { 194 | ops |= BinaryOperatorFlags::from(&pred.op); 195 | 196 | let field = pred.lhs.var_name.as_str(); 197 | 198 | if existed_fields.insert(field) { 199 | // Fields is not existed yet. 200 | // Unwrap is safe since `field` cannot contain '\0' as `atc` must not contain any internal `\0`. 201 | let field = ffi::CString::new(field).unwrap(); 202 | let field_slice = field.as_bytes_with_nul(); 203 | let field_len = field_slice.len(); 204 | 205 | *fields_total += 1; 206 | total_fields_length += field_len; 207 | 208 | if *fields_buf_len < total_fields_length { 209 | return ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL; 210 | } 211 | 212 | let fields_buf = from_raw_parts_mut(fields_buf_ptr, field_len); 213 | fields_buf.copy_from_slice(field_slice); 214 | fields_buf_ptr = fields_buf_ptr.add(field_len); 215 | } 216 | } 217 | 218 | *operators = ops.bits(); 219 | 220 | ATC_ROUTER_EXPRESSION_VALIDATE_OK 221 | } 222 | 223 | #[cfg(test)] 224 | mod tests { 225 | use super::*; 226 | use crate::ast::Type; 227 | 228 | fn expr_validate_on( 229 | schema: &Schema, 230 | atc: &str, 231 | fields_buf_size: usize, 232 | ) -> Result<(Vec, usize, u64), (i64, String)> { 233 | let atc = ffi::CString::new(atc).unwrap(); 234 | let mut errbuf = vec![b'X'; ERR_BUF_MAX_LEN]; 235 | let mut errbuf_len = ERR_BUF_MAX_LEN; 236 | 237 | let mut fields_buf = vec![0u8; fields_buf_size]; 238 | let mut fields_buf_len = fields_buf.len(); 239 | let mut fields_total = 0; 240 | let mut operators = 0u64; 241 | 242 | let result = unsafe { 243 | expression_validate( 244 | atc.as_bytes().as_ptr(), 245 | schema, 246 | 
fields_buf.as_mut_ptr(), 247 | &mut fields_buf_len, 248 | &mut fields_total, 249 | &mut operators, 250 | errbuf.as_mut_ptr(), 251 | &mut errbuf_len, 252 | ) 253 | }; 254 | 255 | match result { 256 | ATC_ROUTER_EXPRESSION_VALIDATE_OK => { 257 | let mut fields = Vec::::with_capacity(fields_total); 258 | let mut p = 0; 259 | for _ in 0..fields_total { 260 | let field = unsafe { ffi::CStr::from_ptr(fields_buf[p..].as_ptr().cast()) }; 261 | let len = field.to_bytes().len() + 1; 262 | fields.push(field.to_string_lossy().to_string()); 263 | p += len; 264 | } 265 | assert_eq!(fields_buf_len, p, "Fields buffer length mismatch"); 266 | fields.sort(); 267 | Ok((fields, fields_buf_len, operators)) 268 | } 269 | ATC_ROUTER_EXPRESSION_VALIDATE_FAILED => { 270 | let err = String::from_utf8(errbuf[..errbuf_len].to_vec()).unwrap(); 271 | Err((result, err)) 272 | } 273 | ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL => Err((result, String::new())), 274 | _ => panic!("Unknown error code"), 275 | } 276 | } 277 | 278 | #[test] 279 | fn test_expression_validate_success() { 280 | let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0/24) && http.path contains "hello""##; 281 | 282 | let mut schema = Schema::default(); 283 | schema.add_field("net.protocol", Type::String); 284 | schema.add_field("net.dst.port", Type::Int); 285 | schema.add_field("net.src.ip", Type::IpAddr); 286 | schema.add_field("http.path", Type::String); 287 | 288 | let result = expr_validate_on(&schema, atc, 47); 289 | 290 | assert!(result.is_ok(), "Validation failed"); 291 | let (fields, fields_buf_len, ops) = result.unwrap(); // Unwrap is safe since we've already asserted it 292 | assert_eq!( 293 | ops, 294 | (BinaryOperatorFlags::EQUALS 295 | | BinaryOperatorFlags::REGEX 296 | | BinaryOperatorFlags::IN 297 | | BinaryOperatorFlags::NOT_IN 298 | | BinaryOperatorFlags::CONTAINS) 299 | .bits(), 300 | "Operators mismatch" 301 | ); 302 | assert_eq!( 303 | 
fields, 304 | vec![ 305 | "http.path".to_string(), 306 | "net.dst.port".to_string(), 307 | "net.protocol".to_string(), 308 | "net.src.ip".to_string() 309 | ], 310 | "Fields mismatch" 311 | ); 312 | assert_eq!(fields_buf_len, 47, "Fields buffer length mismatch"); 313 | } 314 | 315 | #[test] 316 | fn test_expression_validate_failed_parse() { 317 | let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0) && http.path contains "hello""##; 318 | 319 | let mut schema = Schema::default(); 320 | schema.add_field("net.protocol", Type::String); 321 | schema.add_field("net.dst.port", Type::Int); 322 | schema.add_field("net.src.ip", Type::IpAddr); 323 | schema.add_field("http.path", Type::String); 324 | 325 | let result = expr_validate_on(&schema, atc, 1024); 326 | 327 | assert!(result.is_err(), "Validation unexcepted success"); 328 | let (err_code, err_message) = result.unwrap_err(); // Unwrap is safe since we've already asserted it 329 | assert_eq!( 330 | err_code, ATC_ROUTER_EXPRESSION_VALIDATE_FAILED, 331 | "Error code mismatch" 332 | ); 333 | assert_eq!( 334 | err_message, 335 | "In/NotIn operators only supports IP in CIDR".to_string(), 336 | "Error message mismatch" 337 | ); 338 | } 339 | 340 | #[test] 341 | fn test_expression_validate_failed_validate() { 342 | let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0/24) && http.path contains "hello""##; 343 | 344 | let mut schema = Schema::default(); 345 | schema.add_field("net.protocol", Type::String); 346 | schema.add_field("net.dst.port", Type::Int); 347 | schema.add_field("net.src.ip", Type::IpAddr); 348 | 349 | let result = expr_validate_on(&schema, atc, 1024); 350 | 351 | assert!(result.is_err(), "Validation unexcepted success"); 352 | let (err_code, err_message) = result.unwrap_err(); // Unwrap is safe since we've already asserted it 353 | assert_eq!( 354 | err_code, 
ATC_ROUTER_EXPRESSION_VALIDATE_FAILED, 355 | "Error code mismatch" 356 | ); 357 | assert_eq!( 358 | err_message, 359 | "Unknown LHS field".to_string(), 360 | "Error message mismatch" 361 | ); 362 | } 363 | 364 | #[test] 365 | fn test_expression_validate_buf_too_small() { 366 | let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0/24) && http.path contains "hello""##; 367 | 368 | let mut schema = Schema::default(); 369 | schema.add_field("net.protocol", Type::String); 370 | schema.add_field("net.dst.port", Type::Int); 371 | schema.add_field("net.src.ip", Type::IpAddr); 372 | schema.add_field("http.path", Type::String); 373 | 374 | let result = expr_validate_on(&schema, atc, 46); 375 | 376 | assert!(result.is_err(), "Validation failed"); 377 | let (err_code, _) = result.unwrap_err(); // Unwrap is safe since we've already asserted it 378 | assert_eq!( 379 | err_code, ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL, 380 | "Error code mismatch" 381 | ); 382 | } 383 | } 384 | --------------------------------------------------------------------------------