├── .github └── workflows │ └── rust.yml ├── .gitignore ├── CHANGES.md ├── Cargo.toml ├── LICENSE ├── Makefile ├── README.md └── src ├── backend ├── input_builder.rs ├── memory.rs ├── mod.rs └── redis.rs ├── lib.rs └── middleware ├── builder.rs ├── mod.rs └── tests.rs /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Rust Build 2 | 3 | on: [push, pull_request] 4 | 5 | env: 6 | CARGO_TERM_COLOR: always 7 | 8 | jobs: 9 | 10 | build: 11 | runs-on: ubuntu-latest 12 | 13 | services: 14 | redis: 15 | image: redis 16 | options: >- 17 | --health-cmd "redis-cli ping" 18 | --health-interval 10s 19 | --health-timeout 5s 20 | --health-retries 5 21 | ports: 22 | - 6379:6379 23 | 24 | steps: 25 | - uses: actions/checkout@v2 26 | 27 | - name: Cargo Test 28 | run: cargo test --all-features --workspace -- --nocapture 29 | env: 30 | REDIS_HOST: localhost 31 | REDIS_PORT: 6379 32 | 33 | - name: Cargo Format Check 34 | run: cargo fmt -- --check 35 | 36 | - name: Cargo Clippy Check 37 | run: cargo clippy --all-features --workspace -- -D warnings 38 | 39 | - name: Cargo Sort Check 40 | run: cargo install cargo-sort --debug && cargo-sort --check --workspace 41 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | Cargo.lock 3 | .idea/ 4 | -------------------------------------------------------------------------------- /CHANGES.md: -------------------------------------------------------------------------------- 1 | # Changes 2 | 3 | ## 0.4.0 2024-08-07 4 | 5 | - Major: Update Dashmap and Redis dependencies. 6 | 7 | ## 0.3.1 2024-01-21 8 | 9 | - Patch: Fix Redis key expiry bug. 10 | 11 | ## 0.3.0 2024-01-21 12 | 13 | - Major: Removes async-trait dependency. 14 | - Major: Redis backend now uses BITFIELD to store counts. 
15 | - Major: Backend return type is now a `Decision` enum instead of a `bool`. 16 | 17 | ## 0.2.2 2022-04-19 18 | 19 | - Patch: Improve documentation. 20 | 21 | ## 0.2.1 2022-04-18 22 | 23 | - Minor: Added `SimpleBuilderFuture` type alias. 24 | 25 | ## 0.2.0 2022-04-17 26 | 27 | - Major: Middleware will now always render a response on error cases, except for an error from the wrapped service call 28 | which is now returned immediately. 29 | 30 | ## 0.1.1 2022-04-16 31 | 32 | - Patch: Fixed docs.rs configuration 33 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "actix-extensible-rate-limit" 3 | version = "0.4.0" 4 | edition = "2021" 5 | license = "MIT OR Apache-2.0" 6 | description = "Rate limiting middleware for actix-web" 7 | repository = "https://github.com/jacob-pro/actix-extensible-rate-limit" 8 | homepage = "https://github.com/jacob-pro/actix-extensible-rate-limit" 9 | 10 | [dependencies] 11 | actix-web = { version = "4", default-features = false, features = ["macros"] } 12 | dashmap = { version = "6.0", optional = true } 13 | futures = "0.3.28" 14 | log = "0.4.19" 15 | redis = { version = "0.29.1", default-features = false, features = [ 16 | "tokio-comp", 17 | "aio", 18 | "connection-manager", 19 | ], optional = true } 20 | thiserror = "1.0.40" 21 | 22 | [features] 23 | default = ["dashmap"] 24 | 25 | [dev-dependencies] 26 | tokio = { version = "1", features = ["time", "test-util"] } 27 | 28 | [package.metadata.docs.rs] 29 | all-features = true 30 | rustdoc-args = ["--cfg", "docsrs"] 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND 
DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: test format 2 | 3 | test: 4 | cargo fmt -- --check 5 | cargo-sort --check --workspace 6 | cargo clippy --all-features --workspace -- -D warnings 7 | cargo test --all-features --workspace 8 | 9 | format: 10 | cargo fmt 11 | cargo-sort --workspace 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Actix Extensible Rate Limit 2 | 3 | [![Build status](https://github.com/jacob-pro/actix-extensible-rate-limit/actions/workflows/rust.yml/badge.svg)](https://github.com/jacob-pro/actix-extensible-rate-limit/actions) 4 | [![crates.io](https://img.shields.io/crates/v/actix-extensible-rate-limit.svg)](https://crates.io/crates/actix-extensible-rate-limit) 5 | [![docs.rs](https://docs.rs/actix-extensible-rate-limit/badge.svg)](https://docs.rs/actix-extensible-rate-limit/latest/actix_extensible_rate_limit/) 6 | 7 | An attempt 
at a more flexible rate limiting middleware for actix-web 8 | 9 | Allows for: 10 | 11 | - Deriving a custom rate limit key from the request context. 12 | - Using dynamic rate limits and intervals determined by the request context. 13 | - Using custom backends (store & algorithm) 14 | - Setting a custom 429 response. 15 | - Transforming the response headers based on rate limit results (e.g `x-ratelimit-remaining`). 16 | - Rolling back rate limit counts based on response codes. 17 | 18 | ## Provided Backends 19 | 20 | | Backend | Algorithm | Store | 21 | |-----------------|--------------|------------------------------------------------| 22 | | InMemoryBackend | Fixed Window | [Dashmap](https://github.com/xacrimon/dashmap) | 23 | | RedisBackend | Fixed Window | [Redis](https://github.com/mitsuhiko/redis-rs) | 24 | 25 | ## Getting Started 26 | 27 | ```rust 28 | use actix_web::{App, HttpServer}; 29 | use actix_extensible_rate_limit::{ 30 | backend::{memory::InMemoryBackend, SimpleInputFunctionBuilder}, 31 | RateLimiter, 32 | }; 33 | use std::time::Duration; 34 | 35 | #[actix_web::main] 36 | async fn main() -> std::io::Result<()> { 37 | // A backend is responsible for storing rate limit data, and choosing whether to allow/deny requests 38 | let backend = InMemoryBackend::builder().build(); 39 | 40 | HttpServer::new(move || { 41 | // Assign a limit of 5 requests per minute per client ip address 42 | let input = SimpleInputFunctionBuilder::new(Duration::from_secs(60), 5) 43 | .real_ip_key() 44 | .build(); 45 | let middleware = RateLimiter::builder(backend.clone(), input) 46 | .add_headers() 47 | .build(); 48 | App::new().wrap(middleware) 49 | }) 50 | .bind("127.0.0.1:8080")? 51 | .run() 52 | .await 53 | } 54 | ``` 55 | 56 | Try it out: 57 | 58 | ``` 59 | $ curl -v http://127.0.0.1:8080 60 | * Trying 127.0.0.1:8080... 
61 | * Connected to 127.0.0.1 (127.0.0.1) port 8080 (#0) 62 | > GET / HTTP/1.1 63 | > Host: 127.0.0.1:8080 64 | > User-Agent: curl/7.83.1 65 | > Accept: */* 66 | > 67 | * Mark bundle as not supporting multiuse 68 | < HTTP/1.1 404 Not Found 69 | < content-length: 0 70 | < x-ratelimit-limit: 5 71 | < x-ratelimit-reset: 60 72 | < x-ratelimit-remaining: 4 73 | < date: Sun, 21 Jan 2024 16:52:27 GMT 74 | < 75 | * Connection #0 to host 127.0.0.1 left intact 76 | ``` -------------------------------------------------------------------------------- /src/backend/input_builder.rs: -------------------------------------------------------------------------------- 1 | use crate::backend::SimpleInput; 2 | use actix_web::dev::ServiceRequest; 3 | use actix_web::ResponseError; 4 | use std::future::{ready, Ready}; 5 | use std::net::{AddrParseError, IpAddr, Ipv6Addr}; 6 | use std::time::Duration; 7 | use thiserror::Error; 8 | 9 | type CustomFn = Box Result>; 10 | 11 | pub type SimpleInputFuture = Ready>; 12 | 13 | /// Utility to create an input function that produces a [SimpleInput]. 14 | /// 15 | /// You should take care to ensure that you are producing unique keys per backend. 16 | /// 17 | /// This will not be of any use if you want to use dynamic interval/request policies 18 | /// or perform an asynchronous operation; you should instead write your own input function. 19 | pub struct SimpleInputFunctionBuilder { 20 | interval: Duration, 21 | max_requests: u64, 22 | real_ip_key: bool, 23 | peer_ip_key: bool, 24 | path_key: bool, 25 | custom_key: Option, 26 | custom_fn: Option, 27 | } 28 | 29 | impl SimpleInputFunctionBuilder { 30 | pub fn new(interval: Duration, max_requests: u64) -> Self { 31 | Self { 32 | interval, 33 | max_requests, 34 | real_ip_key: false, 35 | peer_ip_key: false, 36 | path_key: false, 37 | custom_key: None, 38 | custom_fn: None, 39 | } 40 | } 41 | 42 | /// Adds the client's real IP to the rate limiting key.
43 | /// 44 | /// # Security 45 | /// 46 | /// This calls 47 | /// [ConnectionInfo::realip_remote_addr()](actix_web::dev::ConnectionInfo::realip_remote_addr) 48 | /// internally which is only suitable for Actix applications deployed behind a proxy that you 49 | /// control. 50 | /// 51 | /// # IPv6 52 | /// 53 | /// IPv6 addresses will be grouped into a single key per /64 54 | pub fn real_ip_key(mut self) -> Self { 55 | self.real_ip_key = true; 56 | self 57 | } 58 | 59 | /// Adds the connection peer IP to the rate limiting key. 60 | /// 61 | /// This is suitable when clients connect directly to the Actix application. 62 | /// 63 | /// # IPv6 64 | /// 65 | /// IPv6 addresses will be grouped into a single key per /64 66 | pub fn peer_ip_key(mut self) -> Self { 67 | self.peer_ip_key = true; 68 | self 69 | } 70 | 71 | /// Add the request path to the rate limiting key 72 | pub fn path_key(mut self) -> Self { 73 | self.path_key = true; 74 | self 75 | } 76 | 77 | /// Add a custom component to the rate limiting key 78 | pub fn custom_key(mut self, key: &str) -> Self { 79 | self.custom_key = Some(key.to_owned()); 80 | self 81 | } 82 | 83 | /// Dynamically add a custom component to the rate limiting key 84 | pub fn custom_fn(mut self, f: F) -> Self 85 | where 86 | F: Fn(&ServiceRequest) -> Result + 'static, 87 | { 88 | self.custom_fn = Some(Box::new(f)); 89 | self 90 | } 91 | 92 | pub fn build(self) -> impl Fn(&ServiceRequest) -> SimpleInputFuture + 'static { 93 | move |req| { 94 | ready((|| { 95 | let mut components = Vec::new(); 96 | let info = req.connection_info(); 97 | if let Some(custom) = &self.custom_key { 98 | components.push(custom.clone()); 99 | } 100 | if self.real_ip_key { 101 | components.push(ip_key(info.realip_remote_addr().unwrap())?) 102 | } 103 | if self.peer_ip_key { 104 | components.push(ip_key(info.peer_addr().unwrap())?) 
105 | } 106 | if self.path_key { 107 | components.push(req.path().to_owned()); 108 | } 109 | if let Some(f) = &self.custom_fn { 110 | components.push(f(req)?) 111 | } 112 | let key = components.join("-"); 113 | 114 | Ok(SimpleInput { 115 | interval: self.interval, 116 | max_requests: self.max_requests, 117 | key, 118 | }) 119 | })()) 120 | } 121 | } 122 | } 123 | 124 | #[derive(Debug, Error)] 125 | enum Error { 126 | #[error("Unable to parse remote IP address: {0}")] 127 | InvalidIpError( 128 | #[source] 129 | #[from] 130 | AddrParseError, 131 | ), 132 | } 133 | 134 | impl ResponseError for Error {} 135 | 136 | // Groups IPv6 addresses together, see: 137 | // https://adam-p.ca/blog/2022/02/ipv6-rate-limiting/ 138 | // https://support.cloudflare.com/hc/en-us/articles/115001635128-Configuring-Cloudflare-Rate-Limiting 139 | fn ip_key(ip_str: &str) -> Result { 140 | let ip = ip_str.parse::()?; 141 | Ok(match ip { 142 | IpAddr::V4(v4) => v4.to_string(), 143 | IpAddr::V6(v6) => { 144 | if let Some(v4) = v6.to_ipv4() { 145 | return Ok(v4.to_string()); 146 | } 147 | let zeroes = [0u16; 4]; 148 | let concat = [&v6.segments()[0..4], &zeroes].concat(); 149 | let concat: [u16; 8] = concat.try_into().unwrap(); 150 | let subnet = Ipv6Addr::from(concat); 151 | format!("{}/64", subnet) 152 | } 153 | }) 154 | } 155 | 156 | #[cfg(test)] 157 | mod tests { 158 | use super::*; 159 | 160 | #[test] 161 | fn test_ip_key() { 162 | // Check that IPv4 addresses are preserved 163 | assert_eq!(ip_key("142.250.187.206").unwrap(), "142.250.187.206"); 164 | // Check that IPv4 mapped addresses are preserved 165 | assert_eq!(ip_key("::FFFF:142.250.187.206").unwrap(), "142.250.187.206"); 166 | // Check that IPv6 addresses are grouped into /64 subnets 167 | assert_eq!( 168 | ip_key("2a00:1450:4009:81f::200e").unwrap(), 169 | "2a00:1450:4009:81f::/64" 170 | ); 171 | } 172 | } 173 | -------------------------------------------------------------------------------- /src/backend/memory.rs: 
-------------------------------------------------------------------------------- 1 | use crate::backend::{Backend, Decision, SimpleBackend, SimpleInput, SimpleOutput}; 2 | use actix_web::rt::task::JoinHandle; 3 | use actix_web::rt::time::Instant; 4 | use dashmap::DashMap; 5 | use std::convert::Infallible; 6 | use std::sync::Arc; 7 | use std::time::Duration; 8 | 9 | pub const DEFAULT_GC_INTERVAL_SECONDS: u64 = 60 * 10; 10 | 11 | /// A Fixed Window rate limiter [Backend] that uses [Dashmap](dashmap::DashMap) to store keys 12 | /// in memory. 13 | #[derive(Clone)] 14 | pub struct InMemoryBackend { 15 | map: Arc>, 16 | gc_handle: Option>>, 17 | } 18 | 19 | struct Value { 20 | ttl: Instant, 21 | count: u64, 22 | } 23 | 24 | impl InMemoryBackend { 25 | pub fn builder() -> Builder { 26 | Builder { 27 | gc_interval: Some(Duration::from_secs(DEFAULT_GC_INTERVAL_SECONDS)), 28 | } 29 | } 30 | 31 | fn garbage_collector(map: Arc>, interval: Duration) -> JoinHandle<()> { 32 | assert!( 33 | interval.as_secs_f64() > 0f64, 34 | "GC interval must be non-zero" 35 | ); 36 | actix_web::rt::spawn(async move { 37 | loop { 38 | let now = Instant::now(); 39 | map.retain(|_k, v| v.ttl > now); 40 | actix_web::rt::time::sleep_until(now + interval).await; 41 | } 42 | }) 43 | } 44 | } 45 | 46 | pub struct Builder { 47 | gc_interval: Option, 48 | } 49 | 50 | impl Builder { 51 | /// Override the default garbage collector interval. 52 | /// 53 | /// Set to None to disable garbage collection. 54 | /// 55 | /// The garbage collector periodically scans the internal map, removing expired buckets. 
56 | pub fn with_gc_interval(mut self, interval: Option) -> Self { 57 | self.gc_interval = interval; 58 | self 59 | } 60 | 61 | pub fn build(self) -> InMemoryBackend { 62 | let map = Arc::new(DashMap::::new()); 63 | let gc_handle = self.gc_interval.map(|gc_interval| { 64 | Arc::new(InMemoryBackend::garbage_collector(map.clone(), gc_interval)) 65 | }); 66 | InMemoryBackend { map, gc_handle } 67 | } 68 | } 69 | 70 | impl Backend for InMemoryBackend { 71 | type Output = SimpleOutput; 72 | type RollbackToken = String; 73 | type Error = Infallible; 74 | 75 | async fn request( 76 | &self, 77 | input: SimpleInput, 78 | ) -> Result<(Decision, Self::Output, Self::RollbackToken), Self::Error> { 79 | let now = Instant::now(); 80 | let mut count = 1; 81 | let mut expiry = now 82 | .checked_add(input.interval) 83 | .expect("Interval unexpectedly large"); 84 | self.map 85 | .entry(input.key.clone()) 86 | .and_modify(|v| { 87 | // If this bucket hasn't yet expired, increment and extract the count/expiry 88 | if v.ttl > now { 89 | v.count += 1; 90 | count = v.count; 91 | expiry = v.ttl; 92 | } else { 93 | // If this bucket has expired we will reset the count to 1 and set a new TTL. 94 | v.ttl = expiry; 95 | v.count = count; 96 | } 97 | }) 98 | .or_insert_with(|| Value { 99 | // If the bucket doesn't exist, create it with a count of 1, and set the TTL. 
100 | ttl: expiry, 101 | count, 102 | }); 103 | let allow = count <= input.max_requests; 104 | let output = SimpleOutput { 105 | limit: input.max_requests, 106 | remaining: input.max_requests.saturating_sub(count), 107 | reset: expiry, 108 | }; 109 | Ok((Decision::from_allowed(allow), output, input.key)) 110 | } 111 | 112 | async fn rollback(&self, token: Self::RollbackToken) -> Result<(), Self::Error> { 113 | self.map.entry(token).and_modify(|v| { 114 | v.count = v.count.saturating_sub(1); 115 | }); 116 | Ok(()) 117 | } 118 | } 119 | 120 | impl SimpleBackend for InMemoryBackend { 121 | async fn remove_key(&self, key: &str) -> Result<(), Self::Error> { 122 | self.map.remove(key); 123 | Ok(()) 124 | } 125 | } 126 | 127 | impl Drop for InMemoryBackend { 128 | fn drop(&mut self) { 129 | if let Some(handle) = &self.gc_handle { 130 | handle.abort(); 131 | } 132 | } 133 | } 134 | 135 | #[cfg(test)] 136 | mod tests { 137 | use super::*; 138 | 139 | const MINUTE: Duration = Duration::from_secs(60); 140 | 141 | #[actix_web::test] 142 | async fn test_allow_deny() { 143 | tokio::time::pause(); 144 | let backend = InMemoryBackend::builder().build(); 145 | let input = SimpleInput { 146 | interval: MINUTE, 147 | max_requests: 5, 148 | key: "KEY1".to_string(), 149 | }; 150 | for _ in 0..5 { 151 | // First 5 should be allowed 152 | let (allow, _, _) = backend.request(input.clone()).await.unwrap(); 153 | assert!(allow.is_allowed()); 154 | } 155 | // Sixth should be denied 156 | let (allow, _, _) = backend.request(input.clone()).await.unwrap(); 157 | assert!(!allow.is_allowed()); 158 | } 159 | 160 | #[actix_web::test] 161 | async fn test_reset() { 162 | tokio::time::pause(); 163 | let backend = InMemoryBackend::builder().with_gc_interval(None).build(); 164 | let input = SimpleInput { 165 | interval: MINUTE, 166 | max_requests: 1, 167 | key: "KEY1".to_string(), 168 | }; 169 | // Make first request, should be allowed 170 | let (decision, _, _) = 
backend.request(input.clone()).await.unwrap(); 171 | assert!(decision.is_allowed()); 172 | // Request again, should be denied 173 | let (decision, _, _) = backend.request(input.clone()).await.unwrap(); 174 | assert!(decision.is_denied()); 175 | // Advance time and try again, should now be allowed 176 | tokio::time::advance(MINUTE).await; 177 | // We want to be sure the key hasn't been garbage collected, and we are testing the expiry logic 178 | assert!(backend.map.contains_key("KEY1")); 179 | let (decision, _, _) = backend.request(input).await.unwrap(); 180 | assert!(decision.is_allowed()); 181 | } 182 | 183 | #[actix_web::test] 184 | async fn test_garbage_collection() { 185 | tokio::time::pause(); 186 | let backend = InMemoryBackend::builder() 187 | .with_gc_interval(Some(MINUTE)) 188 | .build(); 189 | backend 190 | .request(SimpleInput { 191 | interval: MINUTE, 192 | max_requests: 1, 193 | key: "KEY1".to_string(), 194 | }) 195 | .await 196 | .unwrap(); 197 | backend 198 | .request(SimpleInput { 199 | interval: MINUTE * 2, 200 | max_requests: 1, 201 | key: "KEY2".to_string(), 202 | }) 203 | .await 204 | .unwrap(); 205 | assert!(backend.map.contains_key("KEY1")); 206 | assert!(backend.map.contains_key("KEY2")); 207 | // Advance time such that the garbage collector runs, 208 | // expired KEY1 should be cleaned, but KEY2 should remain. 209 | tokio::time::advance(MINUTE).await; 210 | assert!(!backend.map.contains_key("KEY1")); 211 | assert!(backend.map.contains_key("KEY2")); 212 | } 213 | 214 | #[actix_web::test] 215 | async fn test_output() { 216 | tokio::time::pause(); 217 | let backend = InMemoryBackend::builder().build(); 218 | let input = SimpleInput { 219 | interval: MINUTE, 220 | max_requests: 2, 221 | key: "KEY1".to_string(), 222 | }; 223 | // First of 2 should be allowed. 
224 | let (decision, output, _) = backend.request(input.clone()).await.unwrap(); 225 | assert!(decision.is_allowed()); 226 | assert_eq!(output.remaining, 1); 227 | assert_eq!(output.limit, 2); 228 | assert_eq!(output.reset, Instant::now() + MINUTE); 229 | // Second of 2 should be allowed. 230 | let (decision, output, _) = backend.request(input.clone()).await.unwrap(); 231 | assert!(decision.is_allowed()); 232 | assert_eq!(output.remaining, 0); 233 | assert_eq!(output.limit, 2); 234 | assert_eq!(output.reset, Instant::now() + MINUTE); 235 | // Should be denied 236 | let (decision, output, _) = backend.request(input).await.unwrap(); 237 | assert!(decision.is_denied()); 238 | assert_eq!(output.remaining, 0); 239 | assert_eq!(output.limit, 2); 240 | assert_eq!(output.reset, Instant::now() + MINUTE); 241 | } 242 | 243 | #[actix_web::test] 244 | async fn test_rollback() { 245 | tokio::time::pause(); 246 | let backend = InMemoryBackend::builder().build(); 247 | let input = SimpleInput { 248 | interval: MINUTE, 249 | max_requests: 5, 250 | key: "KEY1".to_string(), 251 | }; 252 | let (_, output, rollback) = backend.request(input.clone()).await.unwrap(); 253 | assert_eq!(output.remaining, 4); 254 | backend.rollback(rollback).await.unwrap(); 255 | // Remaining requests should still be the same, since the previous call was excluded 256 | let (_, output, _) = backend.request(input).await.unwrap(); 257 | assert_eq!(output.remaining, 4); 258 | } 259 | 260 | #[actix_web::test] 261 | async fn test_remove_key() { 262 | tokio::time::pause(); 263 | let backend = InMemoryBackend::builder().with_gc_interval(None).build(); 264 | let input = SimpleInput { 265 | interval: MINUTE, 266 | max_requests: 1, 267 | key: "KEY1".to_string(), 268 | }; 269 | let (decision, _, _) = backend.request(input.clone()).await.unwrap(); 270 | assert!(decision.is_allowed()); 271 | let (decision, _, _) = backend.request(input.clone()).await.unwrap(); 272 | assert!(decision.is_denied()); 273 | 
backend.remove_key("KEY1").await.unwrap(); 274 | // Counter should have been reset 275 | let (decision, _, _) = backend.request(input).await.unwrap(); 276 | assert!(decision.is_allowed()); 277 | } 278 | } 279 | -------------------------------------------------------------------------------- /src/backend/mod.rs: -------------------------------------------------------------------------------- 1 | mod input_builder; 2 | 3 | #[cfg(feature = "dashmap")] 4 | #[cfg_attr(docsrs, doc(cfg(feature = "dashmap")))] 5 | pub mod memory; 6 | 7 | #[cfg(feature = "redis")] 8 | #[cfg_attr(docsrs, doc(cfg(feature = "redis")))] 9 | pub mod redis; 10 | 11 | pub use input_builder::{SimpleInputFunctionBuilder, SimpleInputFuture}; 12 | use std::future::Future; 13 | 14 | use crate::HeaderCompatibleOutput; 15 | use actix_web::rt::time::Instant; 16 | use std::time::Duration; 17 | 18 | #[derive(Copy, Clone, Debug, Eq, PartialEq)] 19 | pub enum Decision { 20 | Allowed, 21 | Denied, 22 | } 23 | 24 | impl Decision { 25 | pub fn from_allowed(allowed: bool) -> Self { 26 | if allowed { 27 | Self::Allowed 28 | } else { 29 | Self::Denied 30 | } 31 | } 32 | 33 | pub fn is_allowed(self) -> bool { 34 | matches!(self, Self::Allowed) 35 | } 36 | 37 | pub fn is_denied(self) -> bool { 38 | matches!(self, Self::Denied) 39 | } 40 | } 41 | 42 | /// Describes an implementation of a rate limiting store and algorithm. 43 | /// 44 | /// A Backend is required to implement [Clone], usually this means wrapping your data store within 45 | /// an [Arc](std::sync::Arc), although many connection pools already do so internally; there is no 46 | /// need to wrap it twice. 47 | pub trait Backend: Clone { 48 | type Output; 49 | type RollbackToken; 50 | type Error; 51 | 52 | /// Process an incoming request. 53 | /// 54 | /// The input could include such things as a rate limit key, and the rate limit policy to be 55 | /// applied. 
56 | /// 57 | /// Returns a [Decision] of whether to allow or deny the request, arbitrary output that can be used 58 | /// to transform the allowed and denied responses, and a token to allow the rate limit counter 59 | /// to be rolled back in certain conditions. 60 | fn request( 61 | &self, 62 | input: I, 63 | ) -> impl Future>; 64 | 65 | /// Under certain conditions we may want to rollback the request operation. 66 | /// 67 | /// E.g. We may want to exclude 5xx errors from counting against a user's rate limit, 68 | /// we can only exclude them after having already allowed the request through the rate limiter 69 | /// in the first place, so we must therefore deduct from the rate limit counter afterwards. 70 | /// 71 | /// Note that if this function fails there is not much the [RateLimiter](crate::RateLimiter) 72 | /// can do about it, given that the request has already been allowed. 73 | /// 74 | /// # Arguments 75 | /// 76 | /// * `token`: The token returned from the initial call to [Backend::request()]. 77 | fn rollback(&self, token: Self::RollbackToken) 78 | -> impl Future>; 79 | } 80 | 81 | /// A default [Backend] Input structure. 82 | /// 83 | /// This may not be suitable for all use-cases. 84 | #[derive(Debug, Clone)] 85 | pub struct SimpleInput { 86 | /// The rate limiting interval. 87 | pub interval: Duration, 88 | /// The total requests to be allowed within the interval. 89 | pub max_requests: u64, 90 | /// The rate limit key to be used for this request. 91 | pub key: String, 92 | } 93 | 94 | /// A default [Backend::Output] structure. 95 | /// 96 | /// This may not be suitable for all use-cases. 97 | #[derive(Debug, Clone)] 98 | pub struct SimpleOutput { 99 | /// Total number of requests that are permitted within the rate limit interval. 100 | pub limit: u64, 101 | /// Number of requests that will be permitted until the limit resets. 102 | pub remaining: u64, 103 | /// Time at which the rate limit resets.
104 | pub reset: Instant, 105 | } 106 | 107 | /// Additional functions for a [Backend] that uses [SimpleInput] and [SimpleOutput]. 108 | pub trait SimpleBackend: Backend { 109 | /// Removes the bucket for a given rate limit key. 110 | /// 111 | /// Intended to be used to reset a key before changing the interval. 112 | fn remove_key(&self, key: &str) -> impl Future>; 113 | } 114 | 115 | impl HeaderCompatibleOutput for SimpleOutput { 116 | fn limit(&self) -> u64 { 117 | self.limit 118 | } 119 | 120 | fn remaining(&self) -> u64 { 121 | self.remaining 122 | } 123 | 124 | /// Seconds until the rate limit resets (rounded upwards, so that it is guaranteed to be reset 125 | /// after waiting for the duration). 126 | fn seconds_until_reset(&self) -> u64 { 127 | let millis = self 128 | .reset 129 | .saturating_duration_since(Instant::now()) 130 | .as_millis() as f64; 131 | (millis / 1000f64).ceil() as u64 132 | } 133 | } 134 | 135 | #[cfg(test)] 136 | mod tests { 137 | use super::*; 138 | 139 | #[actix_web::test] 140 | async fn test_seconds_until_reset() { 141 | tokio::time::pause(); 142 | let output = SimpleOutput { 143 | limit: 0, 144 | remaining: 0, 145 | reset: Instant::now() + Duration::from_secs(60), 146 | }; 147 | tokio::time::advance(Duration::from_secs_f64(29.9)).await; 148 | // Verify rounded upwards from 30.1 149 | assert_eq!(output.seconds_until_reset(), 31); 150 | } 151 | } 152 | -------------------------------------------------------------------------------- /src/backend/redis.rs: -------------------------------------------------------------------------------- 1 | use crate::backend::{Backend, Decision, SimpleBackend, SimpleInput, SimpleOutput}; 2 | use actix_web::rt::time::Instant; 3 | use actix_web::{HttpResponse, ResponseError}; 4 | use redis::aio::ConnectionManager; 5 | use redis::AsyncCommands; 6 | use std::borrow::Cow; 7 | use std::time::Duration; 8 | use thiserror::Error; 9 | 10 | const BITFIELD_ENCODING: &str = "u63"; 11 | const BITFIELD_OFFSET: u8 = 0; 
12 | 13 | #[derive(Debug, Error)] 14 | pub enum Error { 15 | #[error("Redis error: {0}")] 16 | Redis( 17 | #[source] 18 | #[from] 19 | redis::RedisError, 20 | ), 21 | #[error("Unexpected negative TTL response for the rate limit key")] 22 | NegativeTtl, 23 | } 24 | 25 | impl ResponseError for Error { 26 | fn error_response(&self) -> HttpResponse { 27 | HttpResponse::InternalServerError().finish() 28 | } 29 | } 30 | 31 | /// A Fixed Window rate limiter [Backend] that stores data in Redis. 32 | #[derive(Clone)] 33 | pub struct RedisBackend { 34 | connection: ConnectionManager, 35 | key_prefix: Option, 36 | } 37 | 38 | impl RedisBackend { 39 | /// Create a RedisBackendBuilder. 40 | /// 41 | /// # Arguments 42 | /// 43 | /// * `connection`: A Redis [ConnectionManager] 44 | /// 45 | /// # Examples 46 | /// 47 | /// ```no_run 48 | /// # use actix_extensible_rate_limit::backend::redis::RedisBackend; 49 | /// # use redis::aio::ConnectionManager; 50 | /// # async fn example() { 51 | /// let client = redis::Client::open("redis://127.0.0.1/").unwrap(); 52 | /// let manager = ConnectionManager::new(client).await.unwrap(); 53 | /// let backend = RedisBackend::builder(manager).build(); 54 | /// # }; 55 | /// ``` 56 | pub fn builder(connection: ConnectionManager) -> Builder { 57 | Builder { 58 | connection, 59 | key_prefix: None, 60 | } 61 | } 62 | 63 | fn make_key<'t>(&self, key: &'t str) -> Cow<'t, str> { 64 | match &self.key_prefix { 65 | None => Cow::Borrowed(key), 66 | Some(prefix) => Cow::Owned(format!("{prefix}{key}")), 67 | } 68 | } 69 | } 70 | 71 | pub struct Builder { 72 | connection: ConnectionManager, 73 | key_prefix: Option, 74 | } 75 | 76 | impl Builder { 77 | /// Apply an optional prefix to all rate limit keys given to this backend. 78 | /// 79 | /// This may be useful when the Redis instance is being used for other purposes; the prefix is 80 | /// used as a 'namespace' to avoid collision with other caches or keys inside Redis.
81 | pub fn key_prefix(mut self, key_prefix: Option<&str>) -> Self { 82 | self.key_prefix = key_prefix.map(ToOwned::to_owned); 83 | self 84 | } 85 | 86 | pub fn build(self) -> RedisBackend { 87 | RedisBackend { 88 | connection: self.connection, 89 | key_prefix: self.key_prefix, 90 | } 91 | } 92 | } 93 | 94 | impl Backend for RedisBackend { 95 | type Output = SimpleOutput; 96 | type RollbackToken = String; 97 | type Error = Error; 98 | 99 | async fn request( 100 | &self, 101 | input: SimpleInput, 102 | ) -> Result<(Decision, Self::Output, Self::RollbackToken), Self::Error> { 103 | let key = self.make_key(&input.key); 104 | 105 | let mut pipe = redis::pipe(); 106 | pipe.atomic() 107 | // Increment the rate limit count 108 | .cmd("BITFIELD") 109 | .arg(key.as_ref()) 110 | .arg("OVERFLOW") 111 | .arg("SAT") 112 | .arg("INCRBY") 113 | .arg(BITFIELD_ENCODING) 114 | .arg(BITFIELD_OFFSET) 115 | .arg(1) 116 | .arg("GET") 117 | .arg(BITFIELD_ENCODING) 118 | .arg(BITFIELD_OFFSET) 119 | // Set the key to expire (only if it doesn't already have an expiry) 120 | .cmd("EXPIRE") 121 | .arg(key.as_ref()) 122 | .arg(input.interval.as_secs()) 123 | .arg("NX") 124 | .ignore() 125 | // Return time-to-live of key 126 | .cmd("TTL") 127 | .arg(key.as_ref()); 128 | 129 | let mut con = self.connection.clone(); 130 | let (counts, ttl): (Vec, i64) = pipe.query_async(&mut con).await?; 131 | if ttl < 0 { 132 | return Err(Error::NegativeTtl); 133 | } 134 | let count = *counts.first().expect("BITFIELD should return one value"); 135 | 136 | let allow = count <= input.max_requests; 137 | let output = SimpleOutput { 138 | limit: input.max_requests, 139 | remaining: input.max_requests.saturating_sub(count), 140 | reset: Instant::now() + Duration::from_secs(ttl as u64), 141 | }; 142 | Ok((Decision::from_allowed(allow), output, input.key)) 143 | } 144 | 145 | async fn rollback(&self, token: Self::RollbackToken) -> Result<(), Self::Error> { 146 | let key = self.make_key(&token); 147 | 148 | let mut con 
= self.connection.clone(); 149 | 150 | let mut pipe = redis::pipe(); 151 | pipe.atomic() 152 | // Decrement the rate limit count 153 | .cmd("BITFIELD") 154 | .arg(key.as_ref()) 155 | .arg("OVERFLOW") 156 | .arg("SAT") 157 | .arg("INCRBY") 158 | .arg(BITFIELD_ENCODING) 159 | .arg(BITFIELD_OFFSET) 160 | .arg(-1) 161 | // Set the key to expire immediately, if it doesn't already have an expiry 162 | .cmd("EXPIRE") 163 | .arg(key.as_ref()) 164 | .arg(0) 165 | .arg("NX") 166 | .ignore(); 167 | 168 | let () = pipe.query_async(&mut con).await?; 169 | 170 | Ok(()) 171 | } 172 | } 173 | 174 | impl SimpleBackend for RedisBackend { 175 | /// Note that the key prefix (if set) is automatically included, you do not need to prepend 176 | /// it yourself. 177 | async fn remove_key(&self, key: &str) -> Result<(), Self::Error> { 178 | let key = self.make_key(key); 179 | let mut con = self.connection.clone(); 180 | let () = con.del(key.as_ref()).await?; 181 | Ok(()) 182 | } 183 | } 184 | 185 | #[cfg(test)] 186 | mod tests { 187 | use super::*; 188 | use crate::HeaderCompatibleOutput; 189 | use redis::Cmd; 190 | 191 | const MINUTE: Duration = Duration::from_secs(60); 192 | 193 | // Each test must use non-overlapping keys (because the tests may be run concurrently) 194 | // Each test should also reset its key on each run, so that it is in a clean state. 
195 | async fn make_backend(clear_test_key: &str) -> Builder { 196 | let host = option_env!("REDIS_HOST").unwrap_or("127.0.0.1"); 197 | let port = option_env!("REDIS_PORT").unwrap_or("6379"); 198 | let client = redis::Client::open(format!("redis://{host}:{port}")).unwrap(); 199 | let mut manager = ConnectionManager::new(client).await.unwrap(); 200 | manager.del::<_, ()>(clear_test_key).await.unwrap(); 201 | RedisBackend::builder(manager) 202 | } 203 | 204 | #[actix_web::test] 205 | async fn test_allow_deny() { 206 | let backend = make_backend("test_allow_deny").await.build(); 207 | let input = SimpleInput { 208 | interval: MINUTE, 209 | max_requests: 5, 210 | key: "test_allow_deny".to_string(), 211 | }; 212 | let mut prev_seconds_until_reset = u64::MAX; 213 | for i in (0..5).rev() { 214 | // First 5 should be allowed 215 | let (decision, output, _) = backend.request(input.clone()).await.unwrap(); 216 | // Remaining counts should be decreasing 217 | assert_eq!(output.remaining, i); 218 | // Limit should be the same 219 | assert_eq!(output.limit, 5); 220 | // Request should be allowed 221 | assert!(decision.is_allowed()); 222 | // Check expiry time is going down each time (instead of being reset) 223 | assert!(output.seconds_until_reset() < prev_seconds_until_reset); 224 | // Sleep for a second 225 | prev_seconds_until_reset = output.seconds_until_reset(); 226 | tokio::time::sleep(Duration::from_secs(1)).await; 227 | } 228 | // Sixth should be denied 229 | let (decision, output, _) = backend.request(input.clone()).await.unwrap(); 230 | assert_eq!(output.remaining, 0); 231 | assert_eq!(output.limit, 5); 232 | assert!(decision.is_denied()); 233 | } 234 | 235 | #[actix_web::test] 236 | async fn test_reset() { 237 | let backend = make_backend("test_reset").await.build(); 238 | let input = SimpleInput { 239 | interval: Duration::from_secs(3), 240 | max_requests: 1, 241 | key: "test_reset".to_string(), 242 | }; 243 | // Make first request, should be allowed 244 | let 
(decision, _, _) = backend.request(input.clone()).await.unwrap(); 245 | assert!(decision.is_allowed()); 246 | 247 | // Request again immediately afterwards, should now be denied 248 | let (decision, out, _) = backend.request(input.clone()).await.unwrap(); 249 | assert!(decision.is_denied()); 250 | 251 | // Sleep until reset, should now be allowed 252 | tokio::time::sleep(Duration::from_secs(out.seconds_until_reset())).await; 253 | let (decision, _, _) = backend.request(input).await.unwrap(); 254 | assert!(decision.is_allowed()); 255 | } 256 | 257 | #[actix_web::test] 258 | async fn test_output() { 259 | let backend = make_backend("test_output").await.build(); 260 | let input = SimpleInput { 261 | interval: MINUTE, 262 | max_requests: 2, 263 | key: "test_output".to_string(), 264 | }; 265 | // First of 2 should be allowed. 266 | let (decision, output, _) = backend.request(input.clone()).await.unwrap(); 267 | assert!(decision.is_allowed()); 268 | assert_eq!(output.remaining, 1); 269 | assert_eq!(output.limit, 2); 270 | assert!(output.seconds_until_reset() > 0 && output.seconds_until_reset() <= 60); 271 | 272 | // Second of 2 should be allowed. 
273 | let (decision, output, _) = backend.request(input.clone()).await.unwrap(); 274 | assert!(decision.is_allowed()); 275 | assert_eq!(output.remaining, 0); 276 | assert_eq!(output.limit, 2); 277 | assert!(output.seconds_until_reset() > 0 && output.seconds_until_reset() <= 60); 278 | 279 | // Should be denied 280 | let (decision, output, _) = backend.request(input).await.unwrap(); 281 | assert!(decision.is_denied()); 282 | assert_eq!(output.remaining, 0); 283 | assert_eq!(output.limit, 2); 284 | assert!(output.seconds_until_reset() > 0 && output.seconds_until_reset() <= 60); 285 | } 286 | 287 | #[actix_web::test] 288 | async fn test_rollback() { 289 | let backend = make_backend("test_rollback").await.build(); 290 | let input = SimpleInput { 291 | interval: MINUTE, 292 | max_requests: 5, 293 | key: "test_rollback".to_string(), 294 | }; 295 | let (_, output, rollback) = backend.request(input.clone()).await.unwrap(); 296 | assert_eq!(output.remaining, 4); 297 | backend.rollback(rollback).await.unwrap(); 298 | // Remaining requests should still be the same, since the previous call was excluded 299 | let (_, output, _) = backend.request(input).await.unwrap(); 300 | assert_eq!(output.remaining, 4); 301 | // Check ttl is not corrupted 302 | assert!(output.seconds_until_reset() > 0 && output.seconds_until_reset() <= 60); 303 | } 304 | 305 | #[actix_web::test] 306 | async fn test_rollback_key_gone() { 307 | let key = "test_rollback_key_gone"; 308 | let backend = make_backend(key).await.build(); 309 | let mut con = backend.connection.clone(); 310 | // The rollback could happen after the key has already expired / gone 311 | backend.rollback(key.to_string()).await.unwrap(); 312 | // In which case the count should remain at 0 (it must not become negative) 313 | let mut cmd = Cmd::new(); 314 | cmd.arg("BITFIELD") 315 | .arg(key) 316 | .arg("GET") 317 | .arg(BITFIELD_ENCODING) 318 | .arg(BITFIELD_OFFSET); 319 | let value: Vec = cmd.query_async(&mut con).await.unwrap(); 320 | 
assert_eq!(value[0], 0u64); 321 | } 322 | 323 | #[actix_web::test] 324 | async fn test_remove_key() { 325 | let backend = make_backend("test_remove_key").await.build(); 326 | let input = SimpleInput { 327 | interval: MINUTE, 328 | max_requests: 1, 329 | key: "test_remove_key".to_string(), 330 | }; 331 | let (decision, _, _) = backend.request(input.clone()).await.unwrap(); 332 | assert!(decision.is_allowed()); 333 | let (decision, _, _) = backend.request(input.clone()).await.unwrap(); 334 | assert!(decision.is_denied()); 335 | backend.remove_key("test_remove_key").await.unwrap(); 336 | // Counter should have been reset 337 | let (decision, _, _) = backend.request(input).await.unwrap(); 338 | assert!(decision.is_allowed()); 339 | } 340 | 341 | #[actix_web::test] 342 | async fn test_key_prefix() { 343 | let backend = make_backend("prefix:test_key_prefix") 344 | .await 345 | .key_prefix(Some("prefix:")) 346 | .build(); 347 | let mut con = backend.connection.clone(); 348 | let input = SimpleInput { 349 | interval: MINUTE, 350 | max_requests: 5, 351 | key: "test_key_prefix".to_string(), 352 | }; 353 | backend.request(input.clone()).await.unwrap(); 354 | assert!(con 355 | .exists::<_, bool>("prefix:test_key_prefix") 356 | .await 357 | .unwrap()); 358 | 359 | backend.remove_key("test_key_prefix").await.unwrap(); 360 | assert!(!con 361 | .exists::<_, bool>("prefix:test_key_prefix") 362 | .await 363 | .unwrap()); 364 | } 365 | } 366 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! Rate limiting middleware for actix-web 2 | //! 3 | //! # Getting Started: 4 | //! ```no_run 5 | //! # use actix_extensible_rate_limit::backend::{memory::InMemoryBackend, SimpleInputFunctionBuilder}; 6 | //! # use actix_extensible_rate_limit::RateLimiter; 7 | //! # use actix_web::{App, HttpServer}; 8 | //! # use std::time::Duration; 9 | //! 
#[actix_web::main] 10 | //! async fn main() -> std::io::Result<()> { 11 | //! // A backend is responsible for storing rate limit data, and choosing whether to allow/deny requests 12 | //! let backend = InMemoryBackend::builder().build(); 13 | //! HttpServer::new(move || { 14 | //! // Assign a limit of 5 requests per minute per client ip address 15 | //! let input = SimpleInputFunctionBuilder::new(Duration::from_secs(60), 5) 16 | //! .real_ip_key() 17 | //! .build(); 18 | //! let middleware = RateLimiter::builder(backend.clone(), input) 19 | //! .add_headers() 20 | //! .build(); 21 | //! App::new().wrap(middleware) 22 | //! }) 23 | //! .bind("127.0.0.1:8080")? 24 | //! .run() 25 | //! .await 26 | //! } 27 | //! ``` 28 | 29 | #![cfg_attr(docsrs, feature(doc_cfg))] 30 | 31 | pub mod backend; 32 | mod middleware; 33 | 34 | pub use middleware::builder::{HeaderCompatibleOutput, RateLimiterBuilder}; 35 | pub use middleware::RateLimiter; 36 | -------------------------------------------------------------------------------- /src/middleware/builder.rs: -------------------------------------------------------------------------------- 1 | use crate::backend::Backend; 2 | use crate::middleware::{AllowedTransformation, DeniedResponse, RateLimiter, RollbackCondition}; 3 | use actix_web::dev::ServiceRequest; 4 | use actix_web::http::header::{HeaderMap, HeaderName, HeaderValue, RETRY_AFTER}; 5 | use actix_web::http::StatusCode; 6 | use actix_web::HttpResponse; 7 | use std::future::Future; 8 | use std::rc::Rc; 9 | 10 | #[allow(clippy::declare_interior_mutable_const)] 11 | pub const X_RATELIMIT_LIMIT: HeaderName = HeaderName::from_static("x-ratelimit-limit"); 12 | #[allow(clippy::declare_interior_mutable_const)] 13 | pub const X_RATELIMIT_REMAINING: HeaderName = HeaderName::from_static("x-ratelimit-remaining"); 14 | #[allow(clippy::declare_interior_mutable_const)] 15 | pub const X_RATELIMIT_RESET: HeaderName = HeaderName::from_static("x-ratelimit-reset"); 16 | 17 | pub struct 
RateLimiterBuilder { 18 | backend: BE, 19 | input_fn: F, 20 | fail_open: bool, 21 | allowed_transformation: Option>>, 22 | denied_response: Rc>, 23 | rollback_condition: Option>, 24 | } 25 | 26 | impl RateLimiterBuilder 27 | where 28 | BE: Backend + 'static, 29 | BI: 'static, 30 | F: Fn(&ServiceRequest) -> O, 31 | O: Future>, 32 | { 33 | pub(super) fn new(backend: BE, input_fn: F) -> Self { 34 | Self { 35 | backend, 36 | input_fn, 37 | fail_open: false, 38 | allowed_transformation: None, 39 | denied_response: Rc::new(|_| HttpResponse::TooManyRequests().finish()), 40 | rollback_condition: None, 41 | } 42 | } 43 | 44 | /// Choose whether to allow a request if the backend returns a failure. 45 | /// 46 | /// Default is false. 47 | pub fn fail_open(mut self, fail_open: bool) -> Self { 48 | self.fail_open = fail_open; 49 | self 50 | } 51 | 52 | /// Sets the [RateLimiterBuilder::request_allowed_transformation] and 53 | /// [RateLimiterBuilder::request_denied_response] functions, such that the following headers 54 | /// are set in both the allowed and denied responses: 55 | /// 56 | /// - `x-ratelimit-limit`\ 57 | /// - `x-ratelimit-remaining`\ 58 | /// - `x-ratelimit-reset` (seconds until the reset) 59 | /// - `retry-after` (denied only, seconds until the reset) 60 | /// 61 | /// This function requires the Backend Output to implement [HeaderCompatibleOutput] 62 | pub fn add_headers(mut self) -> Self 63 | where 64 | BO: HeaderCompatibleOutput, 65 | { 66 | self.allowed_transformation = Some(Rc::new(|map, output, rolled_back| { 67 | if let Some(status) = output { 68 | map.insert(X_RATELIMIT_LIMIT, HeaderValue::from(status.limit())); 69 | let remaining = if rolled_back { 70 | status.remaining() + 1 71 | } else { 72 | status.remaining() 73 | }; 74 | map.insert(X_RATELIMIT_REMAINING, HeaderValue::from(remaining)); 75 | map.insert( 76 | X_RATELIMIT_RESET, 77 | HeaderValue::from(status.seconds_until_reset()), 78 | ); 79 | } 80 | })); 81 | self.denied_response = Rc::new(|status| 
{ 82 | let mut response = HttpResponse::TooManyRequests().finish(); 83 | let map = response.headers_mut(); 84 | map.insert(X_RATELIMIT_LIMIT, HeaderValue::from(status.limit())); 85 | map.insert(X_RATELIMIT_REMAINING, HeaderValue::from(status.remaining())); 86 | let seconds = status.seconds_until_reset(); 87 | map.insert(X_RATELIMIT_RESET, HeaderValue::from(seconds)); 88 | map.insert(RETRY_AFTER, HeaderValue::from(seconds)); 89 | response 90 | }); 91 | self 92 | } 93 | 94 | /// In the event that the request is allowed: 95 | /// 96 | /// You can optionally mutate the response headers to include the rate limit status. 97 | /// 98 | /// By default no changes are made to the response. 99 | /// 100 | /// Note the [Backend::Output] will be [None] if the backend failed and 101 | /// [RateLimiterBuilder::fail_open] is enabled. 102 | /// 103 | /// The boolean parameter indicates if the rate limit was rolled back (so the remaining 104 | /// request count can be adjusted). 105 | pub fn request_allowed_transformation(mut self, mutation: Option) -> Self 106 | where 107 | M: Fn(&mut HeaderMap, Option<&BO>, bool) + 'static, 108 | { 109 | self.allowed_transformation = mutation.map(|m| Rc::new(m) as Rc>); 110 | self 111 | } 112 | 113 | /// In the event that the request is denied, configure the [HttpResponse] returned. 114 | /// 115 | /// Defaults to an empty body with status 429. 116 | pub fn request_denied_response(mut self, denied_response: R) -> Self 117 | where 118 | R: Fn(&BO) -> HttpResponse + 'static, 119 | { 120 | self.denied_response = Rc::new(denied_response); 121 | self 122 | } 123 | 124 | /// After processing a request, attempt to rollback the request count based on the status 125 | /// of the service response. 126 | /// 127 | /// By default the rate limit is never rolled back. 
128 | pub fn rollback_condition(mut self, condition: Option) -> Self 129 | where 130 | C: Fn(StatusCode) -> bool + 'static, 131 | { 132 | self.rollback_condition = condition.map(|m| Rc::new(m) as Rc); 133 | self 134 | } 135 | 136 | /// Configures the [RateLimiterBuilder::rollback_condition] to rollback if the status code 137 | /// is a server error (5xx). 138 | pub fn rollback_server_errors(self) -> Self { 139 | self.rollback_condition(Some(|status: StatusCode| status.is_server_error())) 140 | } 141 | 142 | pub fn build(self) -> RateLimiter { 143 | RateLimiter { 144 | backend: self.backend, 145 | input_fn: Rc::new(self.input_fn), 146 | fail_open: self.fail_open, 147 | allowed_mutation: self.allowed_transformation, 148 | denied_response: self.denied_response, 149 | rollback_condition: self.rollback_condition, 150 | } 151 | } 152 | } 153 | 154 | /// A trait that a [Backend::Output] should implement in order to use the 155 | /// [RateLimiterBuilder::add_headers] function. 156 | pub trait HeaderCompatibleOutput { 157 | /// Value for the `x-ratelimit-limit` header. 158 | fn limit(&self) -> u64; 159 | 160 | /// Value for the `x-ratelimit-remaining` header. 161 | fn remaining(&self) -> u64; 162 | 163 | /// Value for the `x-ratelimit-reset` and `retry-after` headers. 164 | /// 165 | /// This should be the number of seconds from now until the limit resets.\ 166 | /// If the limit has already reset this should return 0.
167 | fn seconds_until_reset(&self) -> u64; 168 | } 169 | -------------------------------------------------------------------------------- /src/middleware/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod builder; 2 | #[cfg(test)] 3 | mod tests; 4 | 5 | use crate::backend::Backend; 6 | use actix_web::body::EitherBody; 7 | use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform}; 8 | use actix_web::http::header::HeaderMap; 9 | use actix_web::http::StatusCode; 10 | use actix_web::HttpResponse; 11 | use builder::RateLimiterBuilder; 12 | use futures::future::{ok, LocalBoxFuture, Ready}; 13 | use std::cell::RefCell; 14 | use std::{future::Future, rc::Rc}; 15 | 16 | type AllowedTransformation = dyn Fn(&mut HeaderMap, Option<&BO>, bool); 17 | type DeniedResponse = dyn Fn(&BO) -> HttpResponse; 18 | type RollbackCondition = dyn Fn(StatusCode) -> bool; 19 | 20 | /// Rate limit middleware. 21 | pub struct RateLimiter { 22 | backend: BA, 23 | input_fn: Rc, 24 | fail_open: bool, 25 | allowed_mutation: Option>>, 26 | denied_response: Rc>, 27 | rollback_condition: Option>, 28 | } 29 | 30 | impl Clone for RateLimiter 31 | where 32 | BA: Backend + 'static, 33 | BI: 'static, 34 | F: Fn(&ServiceRequest) -> O + 'static, 35 | O: Future>, 36 | { 37 | fn clone(&self) -> Self { 38 | Self { 39 | backend: self.backend.clone(), 40 | input_fn: self.input_fn.clone(), 41 | fail_open: self.fail_open, 42 | allowed_mutation: self.allowed_mutation.clone(), 43 | denied_response: self.denied_response.clone(), 44 | rollback_condition: self.rollback_condition.clone(), 45 | } 46 | } 47 | } 48 | 49 | impl RateLimiter 50 | where 51 | BA: Backend + 'static, 52 | BI: 'static, 53 | F: Fn(&ServiceRequest) -> O + 'static, 54 | O: Future>, 55 | { 56 | /// # Arguments 57 | /// 58 | /// * `backend`: A rate limiting algorithm and store implementation. 
59 | /// * `input_fn`: A function that returns a future producing the input to the backend based on the incoming request. 60 | pub fn builder(backend: BA, input_fn: F) -> RateLimiterBuilder { 61 | RateLimiterBuilder::new(backend, input_fn) 62 | } 63 | } 64 | 65 | impl Transform for RateLimiter 66 | where 67 | S: Service, Error = actix_web::Error> + 'static, 68 | S::Future: 'static, 69 | B: 'static, 70 | BA: Backend + 'static, 71 | BI: 'static, 72 | BO: 'static, 73 | BE: Into + std::fmt::Display + 'static, 74 | F: Fn(&ServiceRequest) -> O + 'static, 75 | O: Future>, 76 | { 77 | type Response = ServiceResponse>; 78 | type Error = actix_web::Error; 79 | type Transform = RateLimiterMiddleware; 80 | type InitError = (); 81 | type Future = Ready>; 82 | 83 | fn new_transform(&self, service: S) -> Self::Future { 84 | ok(RateLimiterMiddleware { 85 | service: Rc::new(RefCell::new(service)), 86 | backend: self.backend.clone(), 87 | input_fn: Rc::clone(&self.input_fn), 88 | fail_open: self.fail_open, 89 | allowed_transformation: self.allowed_mutation.clone(), 90 | denied_response: self.denied_response.clone(), 91 | rollback_condition: self.rollback_condition.clone(), 92 | }) 93 | } 94 | } 95 | 96 | pub struct RateLimiterMiddleware { 97 | service: Rc>, 98 | backend: BE, 99 | input_fn: Rc, 100 | fail_open: bool, 101 | allowed_transformation: Option>>, 102 | denied_response: Rc>, 103 | rollback_condition: Option>, 104 | } 105 | 106 | impl Service for RateLimiterMiddleware 107 | where 108 | S: Service, Error = actix_web::Error> + 'static, 109 | S::Future: 'static, 110 | B: 'static, 111 | BA: Backend + 'static, 112 | BI: 'static, 113 | BO: 'static, 114 | BE: Into + std::fmt::Display + 'static, 115 | F: Fn(&ServiceRequest) -> O + 'static, 116 | O: Future>, 117 | { 118 | type Response = ServiceResponse>; 119 | type Error = actix_web::Error; 120 | type Future = LocalBoxFuture<'static, Result>; 121 | 122 | forward_ready!(service); 123 | 124 | fn call(&self, req: ServiceRequest) -> Self::Future { 125 | let
service = self.service.clone(); 126 | let backend = self.backend.clone(); 127 | let input_fn = self.input_fn.clone(); 128 | let fail_open = self.fail_open; 129 | let allowed_transformation = self.allowed_transformation.clone(); 130 | let denied_response = self.denied_response.clone(); 131 | let rollback_condition = self.rollback_condition.clone(); 132 | 133 | Box::pin(async move { 134 | let input = match input_fn(&req).await { 135 | Ok(input) => input, 136 | Err(e) => { 137 | log::error!("Rate limiter input function failed: {e}"); 138 | return Ok(req.into_response(e.error_response()).map_into_right_body()); 139 | } 140 | }; 141 | 142 | let (output, rollback) = match backend.request(input).await { 143 | // Able to successfully query rate limiter backend 144 | Ok((decision, output, rollback)) => { 145 | if decision.is_denied() { 146 | let response: HttpResponse = denied_response(&output); 147 | return Ok(req.into_response(response).map_into_right_body()); 148 | } 149 | (Some(output), Some(rollback)) 150 | } 151 | // Unable to query rate limiter backend 152 | Err(e) => { 153 | if fail_open { 154 | log::warn!("Rate limiter failed: {}, allowing the request anyway", e); 155 | (None, None) 156 | } else { 157 | log::error!("Rate limiter failed: {}", e); 158 | return Ok(req 159 | .into_response(e.into().error_response()) 160 | .map_into_right_body()); 161 | } 162 | } 163 | }; 164 | 165 | let mut service_response = service.call(req).await?; 166 | 167 | let mut rolled_back = false; 168 | if let Some(token) = rollback { 169 | if let Some(rollback_condition) = rollback_condition { 170 | let status = service_response.status(); 171 | if rollback_condition(status) { 172 | if let Err(e) = backend.rollback(token).await { 173 | log::error!("Unable to rollback rate-limit count for response: {:?}, error: {e}", status); 174 | } else { 175 | rolled_back = true; 176 | }; 177 | } 178 | } 179 | } 180 | 181 | if let Some(transformation) = allowed_transformation { 182 | 
transformation(service_response.headers_mut(), output.as_ref(), rolled_back); 183 | } 184 | 185 | Ok(service_response.map_into_left_body()) 186 | }) 187 | } 188 | } 189 | -------------------------------------------------------------------------------- /src/middleware/tests.rs: --------------------------------------------------------------------------------
use crate::backend::Decision;
use crate::middleware::*;
use actix_web::http::header::{HeaderName, HeaderValue};
use actix_web::http::StatusCode;
use actix_web::test::{read_body, TestRequest};
use actix_web::{get, test, App, HttpResponse, Responder, ResponseError};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use thiserror::Error;

/// Test route that always answers `200 OK` with a fixed body.
#[get("/200")]
async fn route_200() -> impl Responder {
    HttpResponse::Ok().body("Hello world!")
}

/// Test route that always answers `500 Internal Server Error`; used to
/// trigger the rollback path in `test_rollback`.
#[get("/500")]
async fn route_500() -> impl Responder {
    HttpResponse::InternalServerError().body("Internal error")
}

/// In-memory backend used by the tests.
///
/// Clones share the same counter through the `Arc`, so a test can keep one
/// handle for assertions while the middleware owns another.
#[derive(Clone, Default)]
struct MockBackend(Arc<MockBackendInner>);

/// Shared state behind [`MockBackend`].
#[derive(Default)]
struct MockBackendInner {
    /// Number of requests counted so far (incremented by `request`,
    /// decremented by `rollback`).
    counter: AtomicU64,
}

/// Per-request input handed to [`MockBackend`].
struct MockBackendInput<T> {
    /// A request is allowed only while the counter value *before* this
    /// request is strictly below `max`.
    max: u64,
    /// Value returned verbatim as the backend output.
    output: T,
    /// When set, `request` fails with this error without touching the counter.
    backend_error: Option<MockError>,
}

// NOTE(review): the `'static` bound mirrors what the `Backend` trait is assumed
// to require of its input type — confirm against the trait definition.
impl<T: 'static> Backend<MockBackendInput<T>> for MockBackend {
    type Output = T;
    type RollbackToken = ();
    type Error = MockError;

    /// Returns the configured error if there is one; otherwise increments the
    /// counter and allows the request when the previous count was below
    /// `input.max`.
    async fn request(
        &self,
        input: MockBackendInput<T>,
    ) -> Result<(Decision, Self::Output, Self::RollbackToken), Self::Error> {
        if let Some(e) = input.backend_error {
            return Err(e);
        }
        let allow = self.0.counter.fetch_add(1, Ordering::Relaxed) < input.max;
        Ok((Decision::from_allowed(allow), input.output, ()))
    }

    /// Undoes one counted request by decrementing the counter.
    async fn rollback(&self, _: Self::RollbackToken) -> Result<(), Self::Error> {
        self.0.counter.fetch_sub(1, Ordering::Relaxed);
        Ok(())
    }
}

/// Error type produced by the mock backend; carries the HTTP status it
/// should be rendered with.
#[derive(Debug, Clone, Error)]
#[error("MockError: {message}")]
struct MockError {
    code: StatusCode,
    message: String,
}

impl Default for MockError {
    /// A `500 Internal Server Error` carrying the message `"Mock Error"`.
    fn default() -> Self {
        MockError {
            code: StatusCode::INTERNAL_SERVER_ERROR,
            message: "Mock Error".to_string(),
        }
    }
}

impl ResponseError for MockError {
    /// Renders the error with the status code stored in `code`.
    fn status_code(&self) -> StatusCode {
        self.code
    }
}

/// With `max: 1` the first request succeeds and the second is rejected with
/// `429 Too Many Requests`.
#[actix_web::test]
async fn test_allow_deny() {
    let backend = MockBackend::default();
    let limiter = RateLimiter::builder(backend, |_req| async {
        Ok(MockBackendInput {
            max: 1,
            output: (),
            backend_error: None,
        })
    })
    .build();
    let app = test::init_service(App::new().service(route_200).wrap(limiter)).await;
    assert!(
        test::call_service(&app, TestRequest::get().uri("/200").to_request())
            .await
            .status()
            .is_success()
    );
    assert_eq!(
        test::call_service(&app, TestRequest::get().uri("/200").to_request())
            .await
            .status(),
        StatusCode::TOO_MANY_REQUESTS
    );
}

/// A custom denied-response callback receives the backend output and its
/// response (status and body) is what the client sees.
#[actix_web::test]
async fn test_custom_deny_response() {
    let backend = MockBackend::default();
    let limiter = RateLimiter::builder(backend, |_req| async {
        Ok(MockBackendInput {
            // `max: 0` denies the very first request.
            max: 0,
            output: StatusCode::IM_A_TEAPOT,
            backend_error: None,
        })
    })
    .request_denied_response(|output| HttpResponse::build(*output).body("Custom denied response"))
    .build();
    let app = test::init_service(App::new().service(route_200).wrap(limiter)).await;
    let response = test::call_service(&app, TestRequest::get().uri("/200").to_request()).await;
    assert_eq!(response.status(), StatusCode::IM_A_TEAPOT);
    let body = String::from_utf8(read_body(response).await.to_vec()).unwrap();
    assert_eq!(body, "Custom denied response");
}

/// The allowed-request transformation sees the backend output and can add
/// headers to the successful response.
#[actix_web::test]
async fn test_header_transformation() {
    let backend = MockBackend::default();
    let limiter = RateLimiter::builder(backend, |_req| async {
        Ok(MockBackendInput {
            max: u64::MAX,
            output: "abc".to_string(),
            backend_error: None,
        })
    })
    .request_allowed_transformation(Some(
        |headers: &mut HeaderMap, output: Option<&String>, rolled_back: bool| {
            assert!(!rolled_back);
            assert!(
                output.is_some(),
                "Backend is working so output should be some"
            );
            headers.insert(
                HeaderName::from_static("test-header"),
                HeaderValue::from_str(output.unwrap()).unwrap(),
            );
        },
    ))
    .build();
    let app = test::init_service(App::new().service(route_200).wrap(limiter)).await;
    let response = test::call_service(&app, TestRequest::get().uri("/200").to_request()).await;
    assert_eq!(response.status(), StatusCode::OK);
    assert_eq!(
        response
            .headers()
            .get("test-header")
            .unwrap()
            .to_str()
            .unwrap(),
        "abc"
    );
}

/// A backend error yields the error's response by default, but with
/// `fail_open(true)` the request is served and the transformation is called
/// with no backend output.
#[actix_web::test]
async fn test_fail_open() {
    let backend = MockBackend::default();

    // Test first without fail open
    let limiter = RateLimiter::builder(backend.clone(), |_req| async {
        Ok(MockBackendInput {
            max: u64::MAX,
            output: (),
            backend_error: Some(MockError::default()),
        })
    })
    .build();
    let app = test::init_service(App::new().service(route_200).wrap(limiter)).await;
    let response = test::call_service(&app, TestRequest::get().uri("/200").to_request()).await;
    assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);

    // Test again with fail open enabled
    let limiter = RateLimiter::builder(backend, |_req| async {
        Ok(MockBackendInput {
            max: u64::MAX,
            output: (),
            backend_error: Some(MockError::default()),
        })
    })
    .request_allowed_transformation(Some(
        |map: &mut HeaderMap, output: Option<&()>, rolled_back: bool| {
            assert!(!rolled_back);
            map.insert(
                HeaderName::from_static("custom-header"),
                HeaderValue::from_static(""),
            );
            // The backend failed, so there is no output to pass along.
            assert!(output.is_none());
        },
    ))
    .fail_open(true)
    .build();
    let app = test::init_service(App::new().service(route_200).wrap(limiter)).await;
    let response = test::call_service(&app, TestRequest::get().uri("/200").to_request()).await;
    assert_eq!(response.status(), StatusCode::OK);
    assert!(response.headers().contains_key("custom-header"));
}

/// With `rollback_server_errors()` a 500 response does not consume a slot:
/// the counter is unchanged after the failing request.
#[actix_web::test]
async fn test_rollback() {
    let backend = MockBackend::default();
    let limiter = RateLimiter::builder(backend.clone(), |_req| async {
        Ok(MockBackendInput {
            max: u64::MAX,
            output: (),
            backend_error: None,
        })
    })
    .rollback_server_errors()
    .build();
    let app = test::init_service(
        App::new()
            .service(route_200)
            .service(route_500)
            .wrap(limiter),
    )
    .await;

    // Confirm count increases for a 200 response
    let response = test::call_service(&app, TestRequest::get().uri("/200").to_request()).await;
    assert_eq!(response.status(), StatusCode::OK);
    assert_eq!(backend.0.counter.load(Ordering::Relaxed), 1);

    // Confirm count hasn't increased because of rollback
    let response = test::call_service(&app, TestRequest::get().uri("/500").to_request()).await;
    assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    assert_eq!(backend.0.counter.load(Ordering::Relaxed), 1);
}
--------------------------------------------------------------------------------