├── .github └── workflows │ ├── release.yaml │ └── tests.yml ├── .gitignore ├── LICENSE ├── README.md ├── examples ├── aws-lambda │ ├── README.md │ ├── main.py │ └── script.sh ├── flask-server │ ├── README.md │ └── hello.py └── vercel │ ├── .gitignore │ ├── Pipfile │ ├── api │ └── index.py │ └── vercel.json ├── pyproject.toml ├── tests ├── __init__.py ├── asyncio │ ├── __init__.py │ ├── test_block_until_ready.py │ ├── test_fixed_window.py │ ├── test_sliding_window.py │ └── test_token_bucket.py ├── conftest.py ├── test_block_until_ready.py ├── test_fixed_window.py ├── test_sliding_window.py ├── test_token_bucket.py ├── test_utils.py └── utils.py └── upstash_ratelimit ├── __init__.py ├── asyncio ├── __init__.py └── ratelimit.py ├── limiter.py ├── py.typed ├── ratelimit.py ├── typing.py └── utils.py /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: workflow_dispatch 4 | 5 | jobs: 6 | release: 7 | runs-on: ubuntu-latest 8 | 9 | steps: 10 | - name: Checkout repository 11 | uses: actions/checkout@v2 12 | 13 | - name: Set up Python 14 | uses: actions/setup-python@v2 15 | with: 16 | python-version: 3.8 17 | 18 | - name: Install Poetry 19 | run: curl -sSL https://install.python-poetry.org | python3 - 20 | 21 | - name: Build and publish 22 | run: | 23 | poetry config pypi-token.pypi "${{ secrets.PYPI_TOKEN }}" 24 | poetry build 25 | poetry publish --no-interaction 26 | 27 | - name: Generate release tag 28 | run: echo "RELEASE_TAG=v$(poetry version | awk '{print $2}')" >> $GITHUB_ENV 29 | 30 | - name: Create GitHub Release 31 | uses: actions/create-release@v1 32 | env: 33 | GITHUB_TOKEN: ${{ secrets.PERSONAL_ACCESS_TOKEN }} 34 | with: 35 | tag_name: ${{ env.RELEASE_TAG }} 36 | release_name: Release ${{ env.RELEASE_TAG }} 37 | draft: false 38 | prerelease: false 39 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: 
-------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | test: 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - name: Checkout repository 14 | uses: actions/checkout@v2 15 | 16 | - name: Set up Python 17 | uses: actions/setup-python@v2 18 | with: 19 | python-version: 3.8 20 | 21 | - name: Install Poetry 22 | run: curl -sSL https://install.python-poetry.org | python3 - 23 | 24 | - name: Set up Poetry environment 25 | run: | 26 | poetry cache clear PyPI --all 27 | poetry install --no-root 28 | 29 | - name: Run mypy 30 | run: | 31 | poetry run mypy --show-error-codes . 32 | 33 | - name: Run tests 34 | run: | 35 | export UPSTASH_REDIS_REST_URL="${{ secrets.UPSTASH_REDIS_REST_URL }}" 36 | export UPSTASH_REDIS_REST_TOKEN="${{ secrets.UPSTASH_REDIS_REST_TOKEN }}" 37 | poetry run pytest 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | .DS_Store 163 | 164 | .vscode 165 | 166 | poetry.lock 167 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Tudor Zgîmbău 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Upstash Ratelimit Python SDK 2 | 3 | > [!NOTE] 4 | > **This project is in GA Stage.** 5 | > The Upstash Professional Support fully covers this project. It receives regular updates, and bug fixes. The Upstash team is committed to maintaining and improving its functionality. 
6 | 7 | `upstash-ratelimit` is a connectionless rate limiting library for Python, designed to be used in serverless environments such as: 8 | 9 | - AWS Lambda 10 | - Vercel Serverless 11 | - Google Cloud Functions 12 | - and other environments where HTTP is preferred over TCP. 13 | 14 | The SDK is currently compatible with Python 3.8 and above. 15 | 16 | 17 | - [Upstash Ratelimit Python SDK](#upstash-ratelimit-python-sdk) 18 | - [Quick Start](#quick-start) 19 | - [Install](#install) 20 | - [Create database](#create-database) 21 | - [Usage](#usage) 22 | - [Block until ready](#block-until-ready) 23 | - [Using multiple limits](#using-multiple-limits) 24 | - [Ratelimiting algorithms](#ratelimiting-algorithms) 25 | - [Fixed Window](#fixed-window) 26 | - [Pros](#pros) 27 | - [Cons](#cons) 28 | - [Usage](#usage-1) 29 | - [Sliding Window](#sliding-window) 30 | - [Pros](#pros-1) 31 | - [Cons](#cons-1) 32 | - [Usage](#usage-2) 33 | - [Token Bucket](#token-bucket) 34 | - [Pros](#pros-2) 35 | - [Cons](#cons-2) 36 | - [Usage](#usage-3) 37 | - [Contributing](#contributing) 38 | - [Preparing the environment](#preparing-the-environment) 39 | - [Running tests](#running-tests) 40 | 41 | 42 | # Quick Start 43 | 44 | ## Install 45 | 46 | ```bash 47 | pip install upstash-ratelimit 48 | ``` 49 | 50 | ## Create database 51 | To be able to use upstash-ratelimit, you need to create a database on [Upstash](https://console.upstash.com/). 52 | 53 | ## Usage 54 | 55 | For possible Redis client configurations, have a look at the [Redis SDK repository](https://github.com/upstash/redis-python). 56 | 57 | > This library supports asyncio as well. To use it, import the asyncio-based 58 | variant from the `upstash_ratelimit.asyncio` module. 
59 | 60 | ```python 61 | from upstash_ratelimit import Ratelimit, FixedWindow 62 | from upstash_redis import Redis 63 | 64 | # Create a new ratelimiter, that allows 10 requests per 10 seconds 65 | ratelimit = Ratelimit( 66 | redis=Redis.from_env(), 67 | limiter=FixedWindow(max_requests=10, window=10), 68 | # Optional prefix for the keys used in Redis. This is useful 69 | # if you want to share a Redis instance with other applications 70 | # and want to avoid key collisions. The default prefix is 71 | # "@upstash/ratelimit" 72 | prefix="@upstash/ratelimit", 73 | ) 74 | 75 | # Use a constant string to limit all requests with a single ratelimit 76 | # Or use a user ID, API key or IP address for individual limits. 77 | identifier = "api" 78 | response = ratelimit.limit(identifier) 79 | 80 | if not response.allowed: 81 | print("Unable to process at this time") 82 | else: 83 | do_expensive_calculation() 84 | print("Here you go!") 85 | 86 | ``` 87 | 88 | The `limit` method also returns the following metadata: 89 | 90 | 91 | ```python 92 | @dataclasses.dataclass 93 | class Response: 94 | allowed: bool 95 | """ 96 | Whether the request may pass(`True`) or exceeded the limit(`False`) 97 | """ 98 | 99 | limit: int 100 | """ 101 | Maximum number of requests allowed within a window. 102 | """ 103 | 104 | remaining: int 105 | """ 106 | How many requests the user has left within the current window. 107 | """ 108 | 109 | reset: float 110 | """ 111 | Unix timestamp in seconds when the limits are reset 112 | """ 113 | ``` 114 | 115 | ## Block until ready 116 | 117 | You also have the option to try and wait for a request to pass in the given timeout. 118 | 119 | It is very similar to the `limit` method and takes an identifier and returns the same 120 | response. However if the current limit has already been exceeded, it will automatically 121 | wait until the next window starts and will try again. 
Setting the `timeout` parameter (in seconds) limits how long the method will block. 122 | 123 | ```python 124 | from upstash_ratelimit import Ratelimit, SlidingWindow 125 | from upstash_redis import Redis 126 | 127 | # Create a new ratelimiter that allows 10 requests per 10 seconds 128 | ratelimit = Ratelimit( 129 | redis=Redis.from_env(), 130 | limiter=SlidingWindow(max_requests=10, window=10), 131 | ) 132 | 133 | response = ratelimit.block_until_ready("id", timeout=30) 134 | 135 | if not response.allowed: 136 | print("Unable to process, even after 30 seconds") 137 | else: 138 | do_expensive_calculation() 139 | print("Here you go!") 140 | ``` 141 | 142 | 143 | ## Using multiple limits 144 | Sometimes you might want to apply different limits to different users. For example, you might want to allow 10 requests per 10 seconds for free users, but 60 requests per 10 seconds for paid users. 145 | 146 | Here's how you could do that: 147 | 148 | ```python 149 | from upstash_ratelimit import Ratelimit, SlidingWindow 150 | from upstash_redis import Redis 151 | 152 | class MultiRL: 153 | def __init__(self) -> None: 154 | redis = Redis.from_env() 155 | self.free = Ratelimit( 156 | redis=redis, 157 | limiter=SlidingWindow(max_requests=10, window=10), 158 | prefix="ratelimit:free", 159 | ) 160 | 161 | self.paid = Ratelimit( 162 | redis=redis, 163 | limiter=SlidingWindow(max_requests=60, window=10), 164 | prefix="ratelimit:paid", 165 | ) 166 | 167 | # Free users get 10 requests and paid users 60 requests per 10 seconds 168 | ratelimit = MultiRL() 169 | 170 | ratelimit.free.limit("userIP") 171 | ratelimit.paid.limit("userIP") 172 | ``` 173 | 174 | # Ratelimiting algorithms 175 | 176 | ## Fixed Window 177 | 178 | This algorithm divides time into fixed durations/windows. For example, each window is 10 seconds long. When a new request comes in, the current time is used to determine the window and a counter is increased.
If the counter is larger than the set limit, the request is rejected. 179 | 180 | ### Pros 181 | - Very cheap in terms of data size and computation 182 | - Newer requests are not starved due to a high burst in the past 183 | 184 | ### Cons 185 | - Can cause high bursts at the window boundaries to leak through 186 | - Causes request stampedes if many users are trying to access your server, whenever a new window begins 187 | 188 | ### Usage 189 | 190 | ```python 191 | from upstash_ratelimit import Ratelimit, FixedWindow 192 | from upstash_redis import Redis 193 | 194 | ratelimit = Ratelimit( 195 | redis=Redis.from_env(), 196 | limiter=FixedWindow(max_requests=10, window=10), 197 | ) 198 | ``` 199 | 200 | ## Sliding Window 201 | 202 | Builds on top of the fixed window algorithm, but uses a rolling window instead of a fixed one. Take this example: We have a rate limit of 10 requests per 1 minute. We divide time into 1-minute slices, just like in the fixed window algorithm. Window 1 will be from 00:00:00 to 00:01:00 (HH:MM:SS). Let's assume it is currently 00:01:15 and we have received 4 requests in the first window and 5 requests so far in the current window. The approximation to determine if the request should pass works like this: 203 | 204 | ```python 205 | limit = 10 206 | 207 | # 4 requests from the old window, weighted + requests in current window 208 | rate = 4 * ((60 - 15) / 60) + 5 # = 8 209 | 210 | return rate < limit # True means we should allow the request 211 | ``` 212 | 213 | ### Pros 214 | - Solves the boundary-burst issue of the fixed window algorithm.
215 | 216 | ### Cons 217 | - More expensive in terms of storage and computation 218 | - It's only an approximation because it assumes a uniform request flow in the previous window 219 | 220 | ### Usage 221 | 222 | ```python 223 | from upstash_ratelimit import Ratelimit, SlidingWindow 224 | from upstash_redis import Redis 225 | 226 | ratelimit = Ratelimit( 227 | redis=Redis.from_env(), 228 | limiter=SlidingWindow(max_requests=10, window=10), 229 | ) 230 | ``` 231 | 232 | ## Token Bucket 233 | 234 | Consider a bucket filled with maximum number of tokens that refills constantly at a rate per interval. Every request will remove one token from the bucket and if there is no token to take, the request is rejected. 235 | 236 | ### Pros 237 | - Bursts of requests are smoothed out and you can process them at a constant rate. 238 | - Allows setting a higher initial burst limit by setting maximum number of tokens higher than the refill rate 239 | 240 | ### Cons 241 | - Expensive in terms of computation 242 | 243 | ### Usage 244 | 245 | ```python 246 | from upstash_ratelimit import Ratelimit, TokenBucket 247 | from upstash_redis import Redis 248 | 249 | ratelimit = Ratelimit( 250 | redis=Redis.from_env(), 251 | limiter=TokenBucket(max_tokens=10, refill_rate=5, interval=10), 252 | ) 253 | ``` 254 | 255 | # Custom Rates 256 | 257 | When rate limiting, you may want different requests to consume different amounts of tokens. 258 | This could be useful when processing batches of requests where you want to rate limit based 259 | on items in the batch or when you want to rate limit based on the number of tokens. 
260 | 261 | To achieve this, pass the `rate` parameter when calling the `limit` method: 262 | 263 | ```python 264 | 265 | from upstash_ratelimit import Ratelimit, FixedWindow 266 | from upstash_redis import Redis 267 | 268 | ratelimit = Ratelimit( 269 | redis=Redis.from_env(), 270 | limiter=FixedWindow(max_requests=10, window=10), 271 | ) 272 | 273 | # pass rate as 5 to subtract 5 from the number of 274 | # allowed requests in the window: 275 | identifier = "api" 276 | response = ratelimit.limit(identifier, rate=5) 277 | ``` 278 | 279 | # Contributing 280 | 281 | ## Preparing the environment 282 | This project uses [Poetry](https://python-poetry.org) for packaging and dependency management. Make sure you are able to create the Poetry shell with the relevant dependencies. 283 | 284 | You will also need a database on [Upstash](https://console.upstash.com/). 285 | 286 | ## Running tests 287 | To run all the tests, make sure the Poetry virtual environment is activated with all 288 | the necessary dependencies. Set the `UPSTASH_REDIS_REST_URL` and `UPSTASH_REDIS_REST_TOKEN` environment variables and run: 289 | 290 | ```bash 291 | poetry run pytest 292 | ``` 293 | -------------------------------------------------------------------------------- /examples/aws-lambda/README.md: -------------------------------------------------------------------------------- 1 | ### AWS Lambda Example 2 | 3 | #### Run 4 | ```bash 5 | ./script.sh 6 | ``` 7 | This script will install the necessary dependencies and generate a .zip file ready for upload. 8 | #### Test 9 | Upload the `lambda.zip` file to AWS Lambda and test it by calling the endpoint.
-------------------------------------------------------------------------------- /examples/aws-lambda/main.py: -------------------------------------------------------------------------------- 1 | from upstash_redis import Redis 2 | 3 | from upstash_ratelimit import FixedWindow, Ratelimit 4 | 5 | ratelimit = Ratelimit( 6 | redis=Redis.from_env(allow_telemetry=False), 7 | limiter=FixedWindow(max_requests=1, window=10), 8 | ) 9 | 10 | 11 | def lambda_handler(event, context): 12 | response = ratelimit.limit("id") 13 | print(response) 14 | -------------------------------------------------------------------------------- /examples/aws-lambda/script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | rm -rf dist && rm lambda.zip 3 | pip3 install --target ./dist upstash-ratelimit 4 | cp main.py ./dist/lambda_function.py 5 | cd dist && zip -r lambda.zip . && cd - 6 | mv ./dist/lambda.zip ./ 7 | # rm -rf dist -------------------------------------------------------------------------------- /examples/flask-server/README.md: -------------------------------------------------------------------------------- 1 | ### Simple Flask Server 2 | 3 | #### Run 4 | ```bash 5 | pip3 install flask 6 | flask --app hello run 7 | ``` 8 | 9 | #### Test 10 | Send request to localhost:5000/request to see the rate limiter in action. -------------------------------------------------------------------------------- /examples/flask-server/hello.py: -------------------------------------------------------------------------------- 1 | from flask import Flask # type: ignore 2 | from upstash_redis import Redis 3 | 4 | from upstash_ratelimit import FixedWindow, Ratelimit 5 | 6 | ratelimit = Ratelimit( 7 | redis=Redis.from_env(allow_telemetry=False), 8 | limiter=FixedWindow(max_requests=2, window=40), 9 | ) 10 | 11 | app = Flask(__name__) 12 | 13 | 14 | @app.route("/") 15 | def hello_world(): 16 | return "

Hello, World!

" 17 | 18 | 19 | @app.route("/request") 20 | def request(): 21 | response = ratelimit.block_until_ready("timeout_1", 10) 22 | return f"

{response}

" 23 | -------------------------------------------------------------------------------- /examples/vercel/.gitignore: -------------------------------------------------------------------------------- 1 | .vercel 2 | -------------------------------------------------------------------------------- /examples/vercel/Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | url = "https://pypi.org/simple" 3 | verify_ssl = true 4 | name = "pypi" 5 | 6 | [packages] 7 | flask = "*" 8 | upstash-ratelimit = "*" 9 | 10 | [requires] 11 | python_version = "3.9" 12 | -------------------------------------------------------------------------------- /examples/vercel/api/index.py: -------------------------------------------------------------------------------- 1 | from http.server import BaseHTTPRequestHandler 2 | 3 | from upstash_redis import Redis 4 | 5 | from upstash_ratelimit import FixedWindow, Ratelimit 6 | 7 | ratelimit = Ratelimit( 8 | redis=Redis.from_env(allow_telemetry=False), 9 | limiter=FixedWindow(max_requests=5, window=5), 10 | ) 11 | 12 | 13 | class handler(BaseHTTPRequestHandler): 14 | def do_GET(self): 15 | response = ratelimit.limit("global") 16 | self.send_header("Content-type", "text/plain") 17 | self.send_header("X-Ratelimit-Limit", response.limit) 18 | self.send_header("X-Ratelimit-Remaining", response.remaining) 19 | self.send_header("X-Ratelimit-Reset", response.reset) 20 | 21 | self.end_headers() 22 | if not response.allowed: 23 | self.send_response(429) 24 | self.wfile.write("Come back later!".encode("utf-8")) 25 | else: 26 | self.send_response(200) 27 | self.wfile.write("Hello!".encode("utf-8")) 28 | -------------------------------------------------------------------------------- /examples/vercel/vercel.json: -------------------------------------------------------------------------------- 1 | { 2 | "rewrites": [ 3 | { 4 | "source": "/(.*)", 5 | "destination": "/api/index" 6 | } 7 | ] 8 | } 
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "upstash-ratelimit" 3 | version = "1.1.0" 4 | description = "Serverless ratelimiting package from Upstash" 5 | authors = ["Upstash ", "Zgîmbău Tudor "] 6 | maintainers = ["Upstash "] 7 | readme = "README.md" 8 | repository = "https://github.com/upstash/ratelimit-python" 9 | keywords = ["ratelimit", "rate limit", "Upstash rate limit", "Redis rate limit"] 10 | classifiers = [ 11 | "Development Status :: 5 - Production/Stable", 12 | "Intended Audience :: Developers", 13 | "License :: OSI Approved :: MIT License", 14 | "Operating System :: OS Independent", 15 | "Programming Language :: Python", 16 | "Programming Language :: Python :: 3", 17 | "Programming Language :: Python :: 3 :: Only", 18 | "Programming Language :: Python :: 3.8", 19 | "Programming Language :: Python :: 3.9", 20 | "Programming Language :: Python :: 3.10", 21 | "Programming Language :: Python :: 3.11", 22 | "Programming Language :: Python :: 3.12", 23 | "Programming Language :: Python :: Implementation :: CPython", 24 | "Topic :: Software Development :: Libraries", 25 | ] 26 | packages = [{include = "upstash_ratelimit"}] 27 | 28 | [tool.poetry.dependencies] 29 | python = "^3.8" 30 | upstash-redis = "^1.0.0" 31 | 32 | [tool.poetry.group.dev.dependencies] 33 | pytest = "^7.3.0" 34 | pytest-asyncio = "^0.21.0" 35 | mypy = "^1.4.1" 36 | 37 | [build-system] 38 | requires = ["poetry-core"] 39 | build-backend = "poetry.core.masonry.api" 40 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/upstash/ratelimit-py/bd70dfdb93ccfc0fea7a3d854c7942f8dce21f0a/tests/__init__.py 
-------------------------------------------------------------------------------- /tests/asyncio/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/upstash/ratelimit-py/bd70dfdb93ccfc0fea7a3d854c7942f8dce21f0a/tests/asyncio/__init__.py -------------------------------------------------------------------------------- /tests/asyncio/test_block_until_ready.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from pytest import mark, raises 4 | from upstash_redis.asyncio import Redis 5 | 6 | from tests.utils import random_id 7 | from upstash_ratelimit.asyncio import FixedWindow, Ratelimit 8 | 9 | 10 | @mark.asyncio() 11 | async def test_invalid_timeout(async_redis: Redis) -> None: 12 | ratelimit = Ratelimit( 13 | redis=async_redis, 14 | limiter=FixedWindow(max_requests=1, window=1), 15 | ) 16 | 17 | with raises(ValueError): 18 | await ratelimit.block_until_ready(random_id(), -1) 19 | 20 | 21 | @mark.asyncio() 22 | async def test_resolve_before_timeout(async_redis: Redis) -> None: 23 | ratelimit = Ratelimit( 24 | redis=async_redis, 25 | limiter=FixedWindow(max_requests=5, window=100), 26 | ) 27 | 28 | timeout = 50 29 | 30 | start = time.time() 31 | response = await ratelimit.block_until_ready(random_id(), timeout) 32 | elapsed = time.time() - start 33 | 34 | assert elapsed < timeout 35 | assert response.allowed is True 36 | 37 | 38 | @mark.asyncio() 39 | async def test_resolve_before_timeout_when_window_resets( 40 | async_redis: Redis, 41 | ) -> None: 42 | ratelimit = Ratelimit( 43 | redis=async_redis, 44 | limiter=FixedWindow(max_requests=1, window=3), 45 | ) 46 | 47 | id = random_id() 48 | timeout = 100 49 | 50 | await ratelimit.limit(id) 51 | 52 | start = time.time() 53 | response = await ratelimit.block_until_ready(id, timeout) 54 | elapsed = time.time() - start 55 | 56 | assert elapsed < timeout 57 | assert response.allowed is True 58 | 
59 | 60 | @mark.asyncio() 61 | async def test_reaching_timeout(async_redis: Redis) -> None: 62 | ratelimit = Ratelimit( 63 | redis=async_redis, 64 | limiter=FixedWindow(max_requests=1, window=1, unit="d"), 65 | ) 66 | 67 | id = random_id() 68 | timeout = 2 69 | 70 | await ratelimit.limit(id) 71 | 72 | start = time.time() 73 | response = await ratelimit.block_until_ready(id, timeout) 74 | elapsed = time.time() - start 75 | 76 | assert elapsed >= timeout 77 | assert response.allowed is False 78 | -------------------------------------------------------------------------------- /tests/asyncio/test_fixed_window.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from unittest.mock import patch 3 | 4 | from pytest import approx, mark 5 | from upstash_redis.asyncio import Redis 6 | 7 | from tests.utils import random_id 8 | from upstash_ratelimit.asyncio import FixedWindow, Ratelimit 9 | from upstash_ratelimit.utils import now_s 10 | 11 | 12 | @mark.asyncio() 13 | async def test_max_requests_are_not_reached(async_redis: Redis) -> None: 14 | ratelimit = Ratelimit( 15 | redis=async_redis, 16 | limiter=FixedWindow(max_requests=5, window=10), 17 | ) 18 | 19 | now = now_s() 20 | response = await ratelimit.limit(random_id()) 21 | 22 | assert response.allowed is True 23 | assert response.limit == 5 24 | assert response.remaining == 4 25 | assert response.reset >= now 26 | 27 | 28 | @mark.asyncio() 29 | async def test_max_requests_are_reached(async_redis: Redis) -> None: 30 | ratelimit = Ratelimit( 31 | redis=async_redis, 32 | limiter=FixedWindow(max_requests=1, window=1, unit="d"), 33 | ) 34 | 35 | id = random_id() 36 | 37 | await ratelimit.limit(id) 38 | 39 | now = now_s() 40 | response = await ratelimit.limit(id) 41 | 42 | assert response.allowed is False 43 | assert response.limit == 1 44 | assert response.remaining == 0 45 | assert response.reset >= now 46 | 47 | 48 | @mark.asyncio() 49 | async def 
test_window_reset(async_redis: Redis) -> None: 50 | ratelimit = Ratelimit( 51 | redis=async_redis, 52 | limiter=FixedWindow(max_requests=1, window=3), 53 | ) 54 | 55 | id = random_id() 56 | 57 | await ratelimit.limit(id) 58 | 59 | await asyncio.sleep(3) 60 | 61 | now = now_s() 62 | response = await ratelimit.limit(id) 63 | 64 | assert response.allowed is True 65 | assert response.limit == 1 66 | assert response.remaining == 0 67 | assert response.reset >= now 68 | 69 | 70 | @mark.asyncio() 71 | async def test_get_remaining(async_redis: Redis) -> None: 72 | ratelimit = Ratelimit( 73 | redis=async_redis, 74 | limiter=FixedWindow(max_requests=10, window=1, unit="d"), 75 | ) 76 | 77 | id = random_id() 78 | assert await ratelimit.get_remaining(id) == 10 79 | await ratelimit.limit(id) 80 | assert await ratelimit.get_remaining(id) == 9 81 | 82 | 83 | @mark.asyncio() 84 | async def test_get_reset(async_redis: Redis) -> None: 85 | ratelimit = Ratelimit( 86 | redis=async_redis, 87 | limiter=FixedWindow(max_requests=10, window=5), 88 | ) 89 | 90 | with patch("time.time", return_value=1688910786.167): 91 | assert await ratelimit.get_reset(random_id()) == approx(1688910790.0) 92 | -------------------------------------------------------------------------------- /tests/asyncio/test_sliding_window.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import List 3 | from unittest.mock import patch 4 | 5 | from pytest import approx, mark 6 | from upstash_redis.asyncio import Redis 7 | 8 | from tests.utils import random_id 9 | from upstash_ratelimit.asyncio import Ratelimit, Response, SlidingWindow 10 | from upstash_ratelimit.utils import now_s 11 | 12 | 13 | @mark.asyncio() 14 | async def test_max_requests_are_not_reached(async_redis: Redis) -> None: 15 | ratelimit = Ratelimit( 16 | redis=async_redis, 17 | limiter=SlidingWindow(max_requests=5, window=10), 18 | ) 19 | 20 | now = now_s() 21 | response = await 
ratelimit.limit(random_id()) 22 | 23 | assert response.allowed is True 24 | assert response.limit == 5 25 | assert response.remaining == 4 26 | assert response.reset >= now 27 | 28 | 29 | @mark.asyncio() 30 | async def test_max_requests_are_reached(async_redis: Redis) -> None: 31 | ratelimit = Ratelimit( 32 | redis=async_redis, 33 | limiter=SlidingWindow(max_requests=1, window=1, unit="d"), 34 | ) 35 | 36 | id = random_id() 37 | 38 | await ratelimit.limit(id) 39 | 40 | now = now_s() 41 | response = await ratelimit.limit(id) 42 | 43 | assert response.allowed is False 44 | assert response.limit == 1 45 | assert response.remaining == 0 46 | assert response.reset >= now 47 | 48 | 49 | @mark.asyncio() 50 | async def test_window_reset(async_redis: Redis) -> None: 51 | ratelimit = Ratelimit( 52 | redis=async_redis, 53 | limiter=SlidingWindow(max_requests=1, window=3), 54 | ) 55 | 56 | id = random_id() 57 | 58 | await ratelimit.limit(id) 59 | 60 | await asyncio.sleep(3) 61 | 62 | now = now_s() 63 | response = await ratelimit.limit(id) 64 | 65 | assert response.allowed is True 66 | assert response.limit == 1 67 | assert response.remaining == 0 68 | assert response.reset >= now 69 | 70 | 71 | @mark.asyncio() 72 | async def test_sliding(async_redis: Redis) -> None: 73 | ratelimit = Ratelimit( 74 | redis=async_redis, 75 | limiter=SlidingWindow(max_requests=5, window=3), 76 | ) 77 | 78 | id = random_id() 79 | 80 | responses: List[Response] = [] 81 | 82 | while True: 83 | response = await ratelimit.limit(id) 84 | responses.append(response) 85 | 86 | if len(responses) > 5 and responses[-2].reset != response.reset: 87 | break 88 | 89 | last_response = responses[-1] 90 | 91 | # We should consider some items from the last window so the 92 | # remaining should not be equal to max_requests - 1 93 | assert last_response.remaining != 4 94 | 95 | 96 | @mark.asyncio() 97 | async def test_get_remaining(async_redis: Redis) -> None: 98 | ratelimit = Ratelimit( 99 | redis=async_redis, 100 | 
limiter=SlidingWindow(max_requests=10, window=1, unit="d"), 101 | ) 102 | 103 | id = random_id() 104 | assert await ratelimit.get_remaining(id) == 10 105 | await ratelimit.limit(id) 106 | assert await ratelimit.get_remaining(id) == 9 107 | 108 | 109 | @mark.asyncio() 110 | async def test_get_remaining_with_sliding(async_redis: Redis) -> None: 111 | ratelimit = Ratelimit( 112 | redis=async_redis, 113 | limiter=SlidingWindow(max_requests=5, window=3), 114 | ) 115 | 116 | id = random_id() 117 | 118 | responses: List[Response] = [] 119 | 120 | while True: 121 | response = await ratelimit.limit(id) 122 | responses.append(response) 123 | 124 | if len(responses) > 5 and responses[-2].reset != response.reset: 125 | break 126 | 127 | last_response = responses[-1] 128 | assert await ratelimit.get_remaining(id) == last_response.remaining 129 | 130 | 131 | @mark.asyncio() 132 | async def test_get_reset(async_redis: Redis) -> None: 133 | ratelimit = Ratelimit( 134 | redis=async_redis, 135 | limiter=SlidingWindow(max_requests=10, window=5), 136 | ) 137 | 138 | with patch("time.time", return_value=1688910786.167): 139 | assert await ratelimit.get_reset(random_id()) == approx(1688910790.0) 140 | -------------------------------------------------------------------------------- /tests/asyncio/test_token_bucket.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from pytest import mark 4 | from upstash_redis.asyncio import Redis 5 | 6 | from tests.utils import random_id 7 | from upstash_ratelimit.asyncio import Ratelimit, TokenBucket 8 | from upstash_ratelimit.utils import now_s 9 | 10 | 11 | @mark.asyncio() 12 | async def test_max_tokens_are_not_reached(async_redis: Redis) -> None: 13 | ratelimit = Ratelimit( 14 | redis=async_redis, 15 | limiter=TokenBucket(max_tokens=5, refill_rate=5, interval=1, unit="d"), 16 | ) 17 | 18 | now = now_s() 19 | response = await ratelimit.limit(random_id()) 20 | 21 | assert response.allowed is 
True 22 | assert response.limit == 5 23 | assert response.remaining == 4 24 | assert response.reset >= now 25 | 26 | 27 | @mark.asyncio() 28 | async def test_max_tokens_are_reached(async_redis: Redis) -> None: 29 | ratelimit = Ratelimit( 30 | redis=async_redis, 31 | limiter=TokenBucket(max_tokens=1, refill_rate=1, interval=1, unit="d"), 32 | ) 33 | 34 | id = random_id() 35 | 36 | await ratelimit.limit(id) 37 | 38 | now = now_s() 39 | response = await ratelimit.limit(id) 40 | 41 | assert response.allowed is False 42 | assert response.limit == 1 43 | assert response.remaining == 0 44 | assert response.reset >= now 45 | 46 | 47 | @mark.asyncio() 48 | async def test_refill(async_redis: Redis) -> None: 49 | ratelimit = Ratelimit( 50 | redis=async_redis, 51 | limiter=TokenBucket(max_tokens=1, refill_rate=1, interval=3), 52 | ) 53 | 54 | id = random_id() 55 | 56 | await ratelimit.limit(id) 57 | 58 | await asyncio.sleep(3) 59 | 60 | now = now_s() 61 | response = await ratelimit.limit(id) 62 | 63 | assert response.allowed is True 64 | assert response.limit == 1 65 | assert response.remaining == 0 66 | assert response.reset >= now 67 | 68 | 69 | @mark.asyncio() 70 | async def test_refill_multiple_times(async_redis: Redis) -> None: 71 | ratelimit = Ratelimit( 72 | redis=async_redis, 73 | limiter=TokenBucket(max_tokens=1000, refill_rate=1, interval=1), 74 | ) 75 | 76 | id = random_id() 77 | 78 | last_response = None 79 | for _ in range(10): 80 | last_response = await ratelimit.limit(id) 81 | 82 | assert last_response is not None 83 | last_remaining = last_response.remaining 84 | 85 | await asyncio.sleep(3) 86 | 87 | response = await ratelimit.limit(id) 88 | assert response.remaining >= last_remaining + 2 89 | 90 | 91 | @mark.asyncio() 92 | async def test_get_remaining(async_redis: Redis) -> None: 93 | ratelimit = Ratelimit( 94 | redis=async_redis, 95 | limiter=TokenBucket(max_tokens=10, refill_rate=10, interval=1, unit="d"), 96 | ) 97 | 98 | id = random_id() 99 | assert await 
ratelimit.get_remaining(id) == 10 100 | await ratelimit.limit(id) 101 | assert await ratelimit.get_remaining(id) == 9 102 | 103 | 104 | @mark.asyncio() 105 | async def test_get_remaining_with_refills_that_should_be_made( 106 | async_redis: Redis, 107 | ) -> None: 108 | ratelimit = Ratelimit( 109 | redis=async_redis, 110 | limiter=TokenBucket(max_tokens=1000, refill_rate=1, interval=1), 111 | ) 112 | 113 | id = random_id() 114 | 115 | last_response = None 116 | for _ in range(10): 117 | last_response = await ratelimit.limit(id) 118 | 119 | assert last_response is not None 120 | last_remaining = last_response.remaining 121 | 122 | await asyncio.sleep(3) 123 | 124 | assert await ratelimit.get_remaining(id) >= last_remaining + 2 125 | 126 | 127 | @mark.asyncio() 128 | async def test_get_reset(async_redis: Redis) -> None: 129 | ratelimit = Ratelimit( 130 | redis=async_redis, 131 | limiter=TokenBucket(max_tokens=1000, refill_rate=1, interval=1), 132 | ) 133 | 134 | id = random_id() 135 | now = now_s() 136 | await ratelimit.limit(id) 137 | 138 | assert await ratelimit.get_reset(id) >= now + 0.9 139 | 140 | 141 | @mark.asyncio() 142 | async def test_get_reset_with_refills_that_should_be_made( 143 | async_redis: Redis, 144 | ) -> None: 145 | ratelimit = Ratelimit( 146 | redis=async_redis, 147 | limiter=TokenBucket(max_tokens=1000, refill_rate=1, interval=1), 148 | ) 149 | 150 | id = random_id() 151 | 152 | last_response = None 153 | for _ in range(10): 154 | last_response = await ratelimit.limit(id) 155 | 156 | assert last_response is not None 157 | last_reset = last_response.reset 158 | 159 | await asyncio.sleep(3) 160 | 161 | assert await ratelimit.get_reset(id) >= last_reset + 2 162 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest_asyncio 2 | from pytest import fixture 3 | from upstash_redis import Redis 4 | from 
upstash_redis.asyncio import Redis as AsyncRedis 5 | 6 | 7 | @fixture 8 | def redis(): 9 | with Redis.from_env() as redis: 10 | yield redis 11 | 12 | 13 | @pytest_asyncio.fixture 14 | async def async_redis(): 15 | async with AsyncRedis.from_env() as redis: 16 | yield redis 17 | 18 | 19 | @fixture(scope="session", autouse=True) 20 | def setup_cleanup(): 21 | with Redis.from_env() as redis: 22 | redis.flushdb() 23 | yield 24 | redis.flushdb() 25 | -------------------------------------------------------------------------------- /tests/test_block_until_ready.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from pytest import raises 4 | from upstash_redis import Redis 5 | 6 | from tests.utils import random_id 7 | from upstash_ratelimit import FixedWindow, Ratelimit 8 | 9 | 10 | def test_invalid_timeout(redis: Redis) -> None: 11 | ratelimit = Ratelimit( 12 | redis=redis, 13 | limiter=FixedWindow(max_requests=1, window=1), 14 | ) 15 | 16 | with raises(ValueError): 17 | ratelimit.block_until_ready(random_id(), -1) 18 | 19 | 20 | def test_resolve_before_timeout(redis: Redis) -> None: 21 | ratelimit = Ratelimit( 22 | redis=redis, 23 | limiter=FixedWindow(max_requests=5, window=100), 24 | ) 25 | 26 | timeout = 50 27 | 28 | start = time.time() 29 | response = ratelimit.block_until_ready(random_id(), timeout) 30 | elapsed = time.time() - start 31 | 32 | assert elapsed < timeout 33 | assert response.allowed is True 34 | 35 | 36 | def test_resolve_before_timeout_when_window_resets(redis: Redis) -> None: 37 | ratelimit = Ratelimit( 38 | redis=redis, 39 | limiter=FixedWindow(max_requests=1, window=3), 40 | ) 41 | 42 | id = random_id() 43 | timeout = 100 44 | 45 | ratelimit.limit(id) 46 | 47 | start = time.time() 48 | response = ratelimit.block_until_ready(id, timeout) 49 | elapsed = time.time() - start 50 | 51 | assert elapsed < timeout 52 | assert response.allowed is True 53 | 54 | 55 | def test_reaching_timeout(redis: 
Redis) -> None: 56 | ratelimit = Ratelimit( 57 | redis=redis, 58 | limiter=FixedWindow(max_requests=1, window=1, unit="d"), 59 | ) 60 | 61 | id = random_id() 62 | timeout = 2 63 | 64 | ratelimit.limit(id) 65 | 66 | start = time.time() 67 | response = ratelimit.block_until_ready(id, timeout) 68 | elapsed = time.time() - start 69 | 70 | assert elapsed >= timeout 71 | assert response.allowed is False 72 | -------------------------------------------------------------------------------- /tests/test_fixed_window.py: -------------------------------------------------------------------------------- 1 | import time 2 | from unittest.mock import patch 3 | 4 | from pytest import approx 5 | from upstash_redis import Redis 6 | 7 | from tests.utils import random_id 8 | from upstash_ratelimit import FixedWindow, Ratelimit 9 | from upstash_ratelimit.utils import now_s 10 | 11 | 12 | def test_max_requests_are_not_reached(redis: Redis) -> None: 13 | ratelimit = Ratelimit( 14 | redis=redis, 15 | limiter=FixedWindow(max_requests=5, window=10), 16 | ) 17 | 18 | now = now_s() 19 | response = ratelimit.limit(random_id()) 20 | 21 | assert response.allowed is True 22 | assert response.limit == 5 23 | assert response.remaining == 4 24 | assert response.reset >= now 25 | 26 | 27 | def test_max_requests_are_reached(redis: Redis) -> None: 28 | ratelimit = Ratelimit( 29 | redis=redis, 30 | limiter=FixedWindow(max_requests=1, window=1, unit="d"), 31 | ) 32 | 33 | id = random_id() 34 | 35 | ratelimit.limit(id) 36 | 37 | now = now_s() 38 | response = ratelimit.limit(id) 39 | 40 | assert response.allowed is False 41 | assert response.limit == 1 42 | assert response.remaining == 0 43 | assert response.reset >= now 44 | 45 | 46 | def test_window_reset(redis: Redis) -> None: 47 | ratelimit = Ratelimit( 48 | redis=redis, 49 | limiter=FixedWindow(max_requests=1, window=3), 50 | ) 51 | 52 | id = random_id() 53 | 54 | ratelimit.limit(id) 55 | 56 | time.sleep(3) 57 | 58 | now = now_s() 59 | response = 
ratelimit.limit(id) 60 | 61 | assert response.allowed is True 62 | assert response.limit == 1 63 | assert response.remaining == 0 64 | assert response.reset >= now 65 | 66 | 67 | def test_get_remaining(redis: Redis) -> None: 68 | ratelimit = Ratelimit( 69 | redis=redis, 70 | limiter=FixedWindow(max_requests=10, window=1, unit="d"), 71 | ) 72 | 73 | id = random_id() 74 | assert ratelimit.get_remaining(id) == 10 75 | ratelimit.limit(id) 76 | assert ratelimit.get_remaining(id) == 9 77 | 78 | 79 | def test_get_reset(redis: Redis) -> None: 80 | ratelimit = Ratelimit( 81 | redis=redis, 82 | limiter=FixedWindow(max_requests=10, window=5), 83 | ) 84 | 85 | with patch("time.time", return_value=1688910786.167): 86 | assert ratelimit.get_reset(random_id()) == approx(1688910790.0) 87 | 88 | 89 | def test_custom_rate(redis: Redis) -> None: 90 | ratelimit = Ratelimit( 91 | redis=redis, 92 | limiter=FixedWindow(max_requests=10, window=1, unit="d"), 93 | ) 94 | rate = 2 95 | 96 | id = random_id() 97 | 98 | ratelimit.limit(id) 99 | ratelimit.limit(id, rate) 100 | assert ratelimit.get_remaining(id) == 7 101 | 102 | ratelimit.limit(id, rate) 103 | assert ratelimit.get_remaining(id) == 5 104 | -------------------------------------------------------------------------------- /tests/test_sliding_window.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import List 3 | from unittest.mock import patch 4 | 5 | from pytest import approx 6 | from upstash_redis import Redis 7 | 8 | from tests.utils import random_id 9 | from upstash_ratelimit import Ratelimit, Response, SlidingWindow 10 | from upstash_ratelimit.utils import now_s 11 | 12 | 13 | def test_max_requests_are_not_reached(redis: Redis) -> None: 14 | ratelimit = Ratelimit( 15 | redis=redis, 16 | limiter=SlidingWindow(max_requests=5, window=10), 17 | ) 18 | 19 | now = now_s() 20 | response = ratelimit.limit(random_id()) 21 | 22 | assert response.allowed is True 23 | 
assert response.limit == 5 24 | assert response.remaining == 4 25 | assert response.reset >= now 26 | 27 | 28 | def test_max_requests_are_reached(redis: Redis) -> None: 29 | ratelimit = Ratelimit( 30 | redis=redis, 31 | limiter=SlidingWindow(max_requests=1, window=1, unit="d"), 32 | ) 33 | 34 | id = random_id() 35 | 36 | ratelimit.limit(id) 37 | 38 | now = now_s() 39 | response = ratelimit.limit(id) 40 | 41 | assert response.allowed is False 42 | assert response.limit == 1 43 | assert response.remaining == 0 44 | assert response.reset >= now 45 | 46 | 47 | def test_window_reset(redis: Redis) -> None: 48 | ratelimit = Ratelimit( 49 | redis=redis, 50 | limiter=SlidingWindow(max_requests=1, window=3), 51 | ) 52 | 53 | id = random_id() 54 | 55 | ratelimit.limit(id) 56 | 57 | time.sleep(3) 58 | 59 | now = now_s() 60 | response = ratelimit.limit(id) 61 | 62 | assert response.allowed is True 63 | assert response.limit == 1 64 | assert response.remaining == 0 65 | assert response.reset >= now 66 | 67 | 68 | def test_sliding(redis: Redis) -> None: 69 | ratelimit = Ratelimit( 70 | redis=redis, 71 | limiter=SlidingWindow(max_requests=5, window=3), 72 | ) 73 | 74 | id = random_id() 75 | 76 | responses: List[Response] = [] 77 | 78 | while True: 79 | response = ratelimit.limit(id) 80 | responses.append(response) 81 | 82 | if len(responses) > 5 and responses[-2].reset != response.reset: 83 | break 84 | 85 | last_response = responses[-1] 86 | 87 | # We should consider some items from the last window so the 88 | # remaining should not be equal to max_requests - 1 89 | assert last_response.remaining != 4 90 | 91 | 92 | def test_get_remaining(redis: Redis) -> None: 93 | ratelimit = Ratelimit( 94 | redis=redis, 95 | limiter=SlidingWindow(max_requests=10, window=1, unit="d"), 96 | ) 97 | 98 | id = random_id() 99 | assert ratelimit.get_remaining(id) == 10 100 | ratelimit.limit(id) 101 | assert ratelimit.get_remaining(id) == 9 102 | 103 | 104 | def test_get_remaining_with_sliding(redis: 
Redis) -> None: 105 | ratelimit = Ratelimit( 106 | redis=redis, 107 | limiter=SlidingWindow(max_requests=5, window=3), 108 | ) 109 | 110 | id = random_id() 111 | 112 | responses: List[Response] = [] 113 | 114 | while True: 115 | response = ratelimit.limit(id) 116 | responses.append(response) 117 | 118 | if len(responses) > 5 and responses[-2].reset != response.reset: 119 | break 120 | 121 | last_response = responses[-1] 122 | assert ratelimit.get_remaining(id) >= last_response.remaining 123 | 124 | 125 | def test_get_reset(redis: Redis) -> None: 126 | ratelimit = Ratelimit( 127 | redis=redis, 128 | limiter=SlidingWindow(max_requests=10, window=5), 129 | ) 130 | 131 | with patch("time.time", return_value=1688910786.167): 132 | assert ratelimit.get_reset(random_id()) == approx(1688910790.0) 133 | 134 | 135 | def test_custom_rate(redis: Redis) -> None: 136 | ratelimit = Ratelimit( 137 | redis=redis, 138 | limiter=SlidingWindow(max_requests=10, window=5), 139 | ) 140 | rate = 2 141 | 142 | id = random_id() 143 | 144 | ratelimit.limit(id) 145 | ratelimit.limit(id, rate) 146 | assert ratelimit.get_remaining(id) == 7 147 | 148 | ratelimit.limit(id, rate) 149 | assert ratelimit.get_remaining(id) == 5 150 | -------------------------------------------------------------------------------- /tests/test_token_bucket.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from upstash_redis import Redis 4 | 5 | from tests.utils import random_id 6 | from upstash_ratelimit import Ratelimit, TokenBucket 7 | from upstash_ratelimit.utils import now_s 8 | 9 | 10 | def test_max_tokens_are_not_reached(redis: Redis) -> None: 11 | ratelimit = Ratelimit( 12 | redis=redis, 13 | limiter=TokenBucket(max_tokens=5, refill_rate=5, interval=1, unit="d"), 14 | ) 15 | 16 | now = now_s() 17 | response = ratelimit.limit(random_id()) 18 | 19 | assert response.allowed is True 20 | assert response.limit == 5 21 | assert response.remaining == 4 22 | 
assert response.reset >= now 23 | 24 | 25 | def test_max_tokens_are_reached(redis: Redis) -> None: 26 | ratelimit = Ratelimit( 27 | redis=redis, 28 | limiter=TokenBucket(max_tokens=1, refill_rate=1, interval=1, unit="d"), 29 | ) 30 | 31 | id = random_id() 32 | 33 | ratelimit.limit(id) 34 | 35 | now = now_s() 36 | response = ratelimit.limit(id) 37 | 38 | assert response.allowed is False 39 | assert response.limit == 1 40 | assert response.remaining == 0 41 | assert response.reset >= now 42 | 43 | 44 | def test_refill(redis: Redis) -> None: 45 | ratelimit = Ratelimit( 46 | redis=redis, 47 | limiter=TokenBucket(max_tokens=1, refill_rate=1, interval=3), 48 | ) 49 | 50 | id = random_id() 51 | 52 | ratelimit.limit(id) 53 | 54 | time.sleep(3) 55 | 56 | now = now_s() 57 | response = ratelimit.limit(id) 58 | 59 | assert response.allowed is True 60 | assert response.limit == 1 61 | assert response.remaining == 0 62 | assert response.reset >= now 63 | 64 | 65 | def test_refill_multiple_times(redis: Redis) -> None: 66 | ratelimit = Ratelimit( 67 | redis=redis, 68 | limiter=TokenBucket(max_tokens=1000, refill_rate=1, interval=1), 69 | ) 70 | 71 | id = random_id() 72 | 73 | last_response = None 74 | for _ in range(10): 75 | last_response = ratelimit.limit(id) 76 | 77 | assert last_response is not None 78 | last_remaining = last_response.remaining 79 | 80 | time.sleep(3) 81 | 82 | response = ratelimit.limit(id) 83 | assert response.remaining >= last_remaining + 2 84 | 85 | 86 | def test_get_remaining(redis: Redis) -> None: 87 | ratelimit = Ratelimit( 88 | redis=redis, 89 | limiter=TokenBucket(max_tokens=10, refill_rate=10, interval=1, unit="d"), 90 | ) 91 | 92 | id = random_id() 93 | assert ratelimit.get_remaining(id) == 10 94 | ratelimit.limit(id) 95 | assert ratelimit.get_remaining(id) == 9 96 | 97 | 98 | def test_get_remaining_with_refills_that_should_be_made(redis: Redis) -> None: 99 | ratelimit = Ratelimit( 100 | redis=redis, 101 | limiter=TokenBucket(max_tokens=1000, 
refill_rate=1, interval=1), 102 | ) 103 | 104 | id = random_id() 105 | 106 | last_response = None 107 | for _ in range(10): 108 | last_response = ratelimit.limit(id) 109 | 110 | assert last_response is not None 111 | last_remaining = last_response.remaining 112 | 113 | time.sleep(3) 114 | 115 | assert ratelimit.get_remaining(id) >= last_remaining + 2 116 | 117 | 118 | def test_get_reset(redis: Redis) -> None: 119 | ratelimit = Ratelimit( 120 | redis=redis, 121 | limiter=TokenBucket(max_tokens=1000, refill_rate=1, interval=1), 122 | ) 123 | 124 | id = random_id() 125 | now = now_s() 126 | ratelimit.limit(id) 127 | 128 | assert ratelimit.get_reset(id) >= now + 0.9 129 | 130 | 131 | def test_get_reset_with_refills_that_should_be_made(redis: Redis) -> None: 132 | ratelimit = Ratelimit( 133 | redis=redis, 134 | limiter=TokenBucket(max_tokens=1000, refill_rate=1, interval=1), 135 | ) 136 | 137 | id = random_id() 138 | 139 | last_response = None 140 | for _ in range(10): 141 | last_response = ratelimit.limit(id) 142 | 143 | assert last_response is not None 144 | last_reset = last_response.reset 145 | 146 | time.sleep(3) 147 | 148 | assert ratelimit.get_reset(id) >= last_reset + 2 149 | 150 | 151 | def test_custom_rate(redis: Redis) -> None: 152 | ratelimit = Ratelimit( 153 | redis=redis, 154 | limiter=TokenBucket(max_tokens=10, refill_rate=1, interval=1), 155 | ) 156 | rate = 2 157 | 158 | id = random_id() 159 | 160 | ratelimit.limit(id) 161 | ratelimit.limit(id, rate) 162 | assert ratelimit.get_remaining(id) == 7 163 | 164 | ratelimit.limit(id, rate) 165 | assert ratelimit.get_remaining(id) == 5 166 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch 2 | 3 | from pytest import approx, mark, raises 4 | 5 | from upstash_ratelimit.typing import UnitT 6 | from upstash_ratelimit.utils import ms_to_s, now_ms, 
now_s, s_to_ms, to_ms 7 | 8 | 9 | @mark.parametrize( 10 | "value,unit,expected", 11 | [ 12 | (1, "d", 86_400_000), 13 | (4, "h", 14_400_000), 14 | (2, "m", 120_000), 15 | (33, "s", 33_000), 16 | (42, "ms", 42), 17 | ], 18 | ) 19 | def test_to_ms(value: int, unit: UnitT, expected: int) -> None: 20 | assert to_ms(value, unit) == expected 21 | 22 | 23 | def test_to_ms_invalid_unit() -> None: 24 | with raises(ValueError): 25 | to_ms(42, "invalid") # type: ignore[arg-type] 26 | 27 | 28 | def test_now_ms() -> None: 29 | with patch("time.time", return_value=42.5): 30 | assert now_ms() == 42_500 31 | 32 | 33 | def test_now_s() -> None: 34 | with patch("time.time", return_value=42.5): 35 | assert now_s() == 42.5 36 | 37 | 38 | def test_s_to_ms() -> None: 39 | assert s_to_ms(12.5) == 12_500 40 | 41 | 42 | def test_ms_to_s() -> None: 43 | assert ms_to_s(44_123) == approx(44.123) 44 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | 4 | def random_id() -> str: 5 | return str(uuid.uuid4()) 6 | -------------------------------------------------------------------------------- /upstash_ratelimit/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.1.0" 2 | 3 | from upstash_ratelimit.limiter import FixedWindow, Response, SlidingWindow, TokenBucket 4 | from upstash_ratelimit.ratelimit import Ratelimit 5 | 6 | __all__ = [ 7 | "Ratelimit", 8 | "FixedWindow", 9 | "SlidingWindow", 10 | "TokenBucket", 11 | "Response", 12 | ] 13 | -------------------------------------------------------------------------------- /upstash_ratelimit/asyncio/__init__.py: -------------------------------------------------------------------------------- 1 | from upstash_ratelimit.asyncio.ratelimit import Ratelimit 2 | from upstash_ratelimit.limiter import FixedWindow, Response, SlidingWindow, 
TokenBucket 3 | 4 | __all__ = [ 5 | "Ratelimit", 6 | "FixedWindow", 7 | "SlidingWindow", 8 | "TokenBucket", 9 | "Response", 10 | ] 11 | -------------------------------------------------------------------------------- /upstash_ratelimit/asyncio/ratelimit.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Optional 3 | 4 | from upstash_redis.asyncio import Redis 5 | 6 | from upstash_ratelimit.limiter import Limiter, Response 7 | from upstash_ratelimit.utils import merge_telemetry, now_s 8 | 9 | 10 | class Ratelimit: 11 | """ 12 | Provides means of ratelimitting over the HTTP-based 13 | Upstash Redis client. 14 | """ 15 | 16 | def __init__( 17 | self, redis: Redis, limiter: Limiter, prefix: str = "@upstash/ratelimit" 18 | ) -> None: 19 | """ 20 | :param redis: Upstash Redis instance to use. 21 | :param limiter: Ratelimiter to use. Available limiters are \ 22 | `FixedWindow`, `SlidingWindow`, and `TokenBucket` which are provided \ 23 | in the `limiter` module. 24 | :param prefix: Prefix to distinguish the keys used in the ratelimit \ 25 | logic from others, in case the same Redis instance is reused between \ 26 | different applications. 27 | """ 28 | 29 | self._redis = redis 30 | merge_telemetry(redis) 31 | 32 | self._limiter = limiter 33 | self._prefix = prefix 34 | 35 | async def limit(self, identifier: str, rate: int = 1) -> Response: 36 | """ 37 | Determines if a request should pass or be rejected based on the identifier 38 | and previously chosen ratelimit. 39 | 40 | Use this if you want to reject all requests that you can not handle 41 | right now. 42 | 43 | .. 
code-block:: python 44 | 45 | from upstash_ratelimit.asyncio import Ratelimit, SlidingWindow 46 | from upstash_redis.asyncio import Redis 47 | 48 | ratelimit = Ratelimit( 49 | redis=Redis.from_env(), 50 | limiter=SlidingWindow(max_requests=10, window=10, unit="s"), 51 | ) 52 | 53 | async def main() -> None: 54 | response = await ratelimit.limit("some-id") 55 | if not response.allowed: 56 | print("Ratelimited!") 57 | return 58 | print("Good to go!") 59 | 60 | :param identifier: Identifier to ratelimit. Use a constant string to \ 61 | limit all requests, or user ids, API keys, or IP addresses for \ 62 | individual limits. 63 | :param rate: Rate with which to subtract from the limit of the \ 64 | identifier. 65 | """ 66 | 67 | key = f"{self._prefix}:{identifier}" 68 | return await self._limiter.limit_async(self._redis, key, rate) 69 | 70 | async def block_until_ready(self, identifier: str, timeout: float, rate: int = 1) -> Response: 71 | """ 72 | Blocks until the request may pass or timeout is reached. 73 | 74 | This method blocks until the request may be processed or the timeout 75 | has been reached. 76 | 77 | Use this if you want to delay the request until it is ready to get 78 | processed. 79 | 80 | .. code-block:: python 81 | 82 | from upstash_ratelimit.asyncio import Ratelimit, SlidingWindow 83 | from upstash_redis.asyncio import Redis 84 | 85 | ratelimit = Ratelimit( 86 | redis=Redis.from_env(), 87 | limiter=SlidingWindow(max_requests=10, window=10, unit="s"), 88 | ) 89 | 90 | async def main() -> None: 91 | response = await ratelimit.block_until_ready("some-id", 60) 92 | if not response.allowed: 93 | print("Ratelimited!") 94 | return 95 | print("Good to go!") 96 | 97 | :param identifier: Identifier to ratelimit. Use a constant string to \ 98 | limit all requests, or user ids, API keys, or IP addresses for \ 99 | individual limits. 100 | :param timeout: Maximum time in seconds to wait until the request \ 101 | may pass.
102 | :param rate: Rate with which to subtract from the limit of the \ 103 | identifier. 104 | """ 105 | 106 | if timeout <= 0: 107 | raise ValueError("Timeout must be positive") 108 | 109 | response: Optional[Response] = None 110 | deadline = now_s() + timeout 111 | 112 | while True: 113 | response = await self.limit(identifier, rate) 114 | if response.allowed: 115 | break 116 | 117 | wait = max(0, min(response.reset, deadline) - now_s()) 118 | await asyncio.sleep(wait) 119 | 120 | if now_s() > deadline: 121 | break 122 | 123 | return response 124 | 125 | async def get_remaining(self, identifier: str) -> int: 126 | """ 127 | Returns the number of requests left for the given identifier. 128 | """ 129 | 130 | key = f"{self._prefix}:{identifier}" 131 | return await self._limiter.get_remaining_async(self._redis, key) 132 | 133 | async def get_reset(self, identifier: str) -> float: 134 | """ 135 | Returns the UNIX timestamp in seconds when the remaining 136 | requests will be reset or replenished. 137 | """ 138 | 139 | key = f"{self._prefix}:{identifier}" 140 | return await self._limiter.get_reset_async(self._redis, key) 141 | -------------------------------------------------------------------------------- /upstash_ratelimit/limiter.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import dataclasses 3 | from collections.abc import Generator 4 | from typing import Any, Callable 5 | 6 | from upstash_redis import Redis 7 | from upstash_redis.asyncio import Redis as AsyncRedis 8 | 9 | from upstash_ratelimit.typing import UnitT 10 | from upstash_ratelimit.utils import ms_to_s, now_ms, to_ms 11 | 12 | 13 | @dataclasses.dataclass 14 | class Response: 15 | allowed: bool 16 | """ 17 | Whether the request may pass(`True`) or exceeded the limit(`False`) 18 | """ 19 | 20 | limit: int 21 | """ 22 | Maximum number of requests allowed within a window. 
23 | """ 24 | 25 | remaining: int 26 | """ 27 | How many requests the user has left within the current window. 28 | """ 29 | 30 | reset: float 31 | """ 32 | Unix timestamp in seconds when the limits are reset 33 | """ 34 | 35 | 36 | class Limiter(abc.ABC): 37 | @abc.abstractmethod 38 | def limit(self, redis: Redis, identifier: str, rate: int = 1) -> Response: 39 | pass 40 | 41 | @abc.abstractmethod 42 | async def limit_async(self, redis: AsyncRedis, identifier: str, rate: int = 1) -> Response: 43 | pass 44 | 45 | @abc.abstractmethod 46 | def get_remaining(self, redis: Redis, identifier: str) -> int: 47 | pass 48 | 49 | @abc.abstractmethod 50 | async def get_remaining_async(self, redis: AsyncRedis, identifier: str) -> int: 51 | pass 52 | 53 | @abc.abstractmethod 54 | def get_reset(self, redis: Redis, identifier: str) -> float: 55 | pass 56 | 57 | @abc.abstractmethod 58 | async def get_reset_async(self, redis: AsyncRedis, identifier: str) -> float: 59 | pass 60 | 61 | 62 | def _with_at_most_one_request(redis: Redis, generator: Generator) -> Any: 63 | """ 64 | A function that makes at most one HTTP request over the 65 | given Redis instance. 66 | 67 | If the generator needs to execute a command, this function 68 | takes the command name and args from the generator, 69 | executes the given command, and passes the result 70 | of the command back to the generator. Then, it 71 | takes the final response from the generator and 72 | returns it. 73 | 74 | If the generator does not need to execute a command, 75 | it returns the result directly. 
76 | """ 77 | 78 | command_name, command_args = next(generator) 79 | if not command_name: 80 | # No need to execute a command 81 | response = next(generator) 82 | return response 83 | 84 | command: Callable = getattr(redis, command_name) 85 | command_response = command(*command_args) 86 | response = generator.send(command_response) 87 | return response 88 | 89 | 90 | async def _with_at_most_one_request_async( 91 | redis: AsyncRedis, generator: Generator 92 | ) -> Any: 93 | """ 94 | Async variant of `_with_at_most_one_request` defined above; the single Redis command, if any, is awaited. 95 | """ 96 | 97 | command_name, command_args = next(generator) 98 | if not command_name: 99 | # No need to execute a command 100 | response = next(generator) 101 | return response 102 | 103 | command: Callable = getattr(redis, command_name) 104 | command_response = await command(*command_args) 105 | response = generator.send(command_response) 106 | return response 107 | 108 | 109 | class AbstractLimiter(Limiter): 110 | @abc.abstractmethod 111 | def _limit(self, identifier: str, rate: int = 1) -> Generator: 112 | pass 113 | 114 | def limit(self, redis: Redis, identifier: str, rate: int = 1) -> Response: 115 | response: Response = _with_at_most_one_request(redis, self._limit(identifier, rate)) 116 | return response 117 | 118 | async def limit_async(self, redis: AsyncRedis, identifier: str, rate: int = 1) -> Response: 119 | response: Response = await _with_at_most_one_request_async( 120 | redis, self._limit(identifier, rate) 121 | ) 122 | return response 123 | 124 | @abc.abstractmethod 125 | def _get_remaining(self, identifier: str) -> Generator: 126 | pass 127 | 128 | def get_remaining(self, redis: Redis, identifier: str) -> int: 129 | remaining: int = _with_at_most_one_request( 130 | redis, self._get_remaining(identifier) 131 | ) 132 | return remaining 133 | 134 | async def get_remaining_async(self, redis: AsyncRedis, identifier: str) -> int: 135 | remaining: int = await _with_at_most_one_request_async( 136 | redis,
self._get_remaining(identifier) 137 | ) 138 | return remaining 139 | 140 | @abc.abstractmethod 141 | def _get_reset(self, identifier: str) -> Generator: 142 | pass 143 | 144 | def get_reset(self, redis: Redis, identifier: str) -> float: 145 | reset: float = _with_at_most_one_request(redis, self._get_reset(identifier)) 146 | return reset 147 | 148 | async def get_reset_async(self, redis: AsyncRedis, identifier: str) -> float: 149 | reset: float = await _with_at_most_one_request_async( 150 | redis, self._get_reset(identifier) 151 | ) 152 | return reset 153 | 154 | 155 | class FixedWindow(AbstractLimiter): 156 | """ 157 | The time is divided into windows of fixed length, and each request inside 158 | a window increases a counter. 159 | 160 | Once the counter reaches the maximum allowed number, all further requests 161 | are rejected. 162 | 163 | Pros: 164 | - Newer requests are not starved by old ones. 165 | - Low storage cost. 166 | 167 | Cons: 168 | - A burst of requests near the boundary of a window can result in twice the 169 | rate of requests being processed because two windows will be filled with 170 | requests quickly. 171 | """ 172 | 173 | SCRIPT = """ 174 | local key = KEYS[1] 175 | local window = ARGV[1] 176 | local increment_by = ARGV[2] -- increment rate per request at a given value, default is 1 177 | 178 | local r = redis.call("INCRBY", key, increment_by) 179 | if r == tonumber(increment_by) then 180 | -- The first time this key is set, the value will be equal to increment_by. 
181 | -- So we only need the expire command once 182 | redis.call("PEXPIRE", key, window) 183 | end 184 | 185 | return r 186 | """ 187 | 188 | def __init__(self, max_requests: int, window: int, unit: UnitT = "s") -> None: 189 | """ 190 | :param max_requests: Maximum number of requests allowed within a window 191 | :param window: The number of time units in a window 192 | :param unit: The unit of time 193 | """ 194 | 195 | assert max_requests > 0 196 | assert window > 0 197 | 198 | self._max_requests = max_requests 199 | self._window = to_ms(window, unit) 200 | 201 | def _limit(self, identifier: str, rate: int = 1) -> Generator: 202 | curr_window = now_ms() // self._window 203 | key = f"{identifier}:{curr_window}" 204 | 205 | num_requests = yield ( 206 | "eval", 207 | (FixedWindow.SCRIPT, [key], [self._window, rate]), 208 | ) 209 | 210 | yield Response( 211 | allowed=num_requests <= self._max_requests, 212 | limit=self._max_requests, 213 | remaining=max(0, self._max_requests - num_requests), 214 | reset=ms_to_s((curr_window + 1) * self._window), 215 | ) 216 | 217 | def _get_remaining(self, identifier: str) -> Generator: 218 | curr_window = now_ms() // self._window 219 | key = f"{identifier}:{curr_window}" 220 | 221 | num_requests = yield ( 222 | "get", 223 | (key,), 224 | ) 225 | 226 | if num_requests is None: 227 | yield self._max_requests 228 | 229 | yield max(0, self._max_requests - int(num_requests)) # type: ignore[arg-type] 230 | 231 | def _get_reset(self, _: str) -> Generator: 232 | yield (None, None) # Signal that we don't need to make a remote call 233 | 234 | curr_window = now_ms() // self._window 235 | yield ms_to_s((curr_window + 1) * self._window) 236 | 237 | 238 | class SlidingWindow(AbstractLimiter): 239 | """ 240 | Combined approach of sliding logs and fixed window with lower storage 241 | costs than sliding logs and improved boundary behavior by calculating a 242 | weighted score between two windows. 
243 | 244 | Pros: 245 | - Good performance allows this to scale to very high loads. 246 | """ 247 | 248 | SCRIPT = """ 249 | local current_key = KEYS[1] -- identifier including prefixes 250 | local previous_key = KEYS[2] -- key of the previous bucket 251 | local tokens = tonumber(ARGV[1]) -- tokens per window 252 | local now = ARGV[2] -- current timestamp in milliseconds 253 | local window = ARGV[3] -- interval in milliseconds 254 | local increment_by = ARGV[4] -- increment rate per request at a given value, default is 1 255 | 256 | local requests_in_current_window = redis.call("GET", current_key) 257 | if requests_in_current_window == false then 258 | requests_in_current_window = 0 259 | end 260 | 261 | local requests_in_previous_window = redis.call("GET", previous_key) 262 | if requests_in_previous_window == false then 263 | requests_in_previous_window = 0 264 | end 265 | local percentage_in_current = ( now % window ) / window 266 | -- weighted requests to consider from the previous window 267 | requests_in_previous_window = math.floor(( 1 - percentage_in_current ) * requests_in_previous_window) 268 | if requests_in_previous_window + requests_in_current_window >= tokens then 269 | return -1 270 | end 271 | 272 | local new_value = redis.call("INCRBY", current_key, increment_by) 273 | if new_value == tonumber(increment_by) then 274 | -- The first time this key is set, the value will be equal to increment_by. 
275 | -- So we only need the expire command once 276 | redis.call("PEXPIRE", current_key, window * 2 + 1000) -- Enough time to overlap with a new window + 1 second 277 | end 278 | return tokens - ( new_value + requests_in_previous_window ) 279 | """ 280 | 281 | def __init__(self, max_requests: int, window: int, unit: UnitT = "s") -> None: 282 | """ 283 | :param max_requests: Maximum number of requests allowed within a window 284 | :param window: The number of time units in a window 285 | :param unit: The unit of time 286 | """ 287 | 288 | assert max_requests > 0 289 | assert window > 0 290 | 291 | self._max_requests = max_requests 292 | self._window = to_ms(window, unit) 293 | 294 | def _limit(self, identifier: str, rate: int = 1) -> Generator: 295 | now = now_ms() 296 | 297 | curr_window = now // self._window 298 | key = f"{identifier}:{curr_window}" 299 | 300 | prev_window = curr_window - 1 301 | prev_key = f"{identifier}:{prev_window}" 302 | 303 | remaining = yield ( 304 | "eval", 305 | ( 306 | SlidingWindow.SCRIPT, 307 | [key, prev_key], 308 | [self._max_requests, now, self._window, rate], 309 | ), 310 | ) 311 | 312 | yield Response( 313 | allowed=remaining >= 0, 314 | limit=self._max_requests, 315 | remaining=max(0, remaining), 316 | reset=ms_to_s((curr_window + 1) * self._window), 317 | ) 318 | 319 | def _get_remaining(self, identifier: str) -> Generator: 320 | now = now_ms() 321 | 322 | window = now // self._window 323 | key = f"{identifier}:{window}" 324 | 325 | prev_window = window - 1 326 | prev_key = f"{identifier}:{prev_window}" 327 | 328 | num_requests_, prev_num_requests_ = yield ( 329 | "mget", 330 | (key, prev_key), 331 | ) 332 | num_requests = int(num_requests_ or 0) 333 | prev_num_requests = int(prev_num_requests_ or 0) 334 | 335 | prev_window_weight = 1 - ((now % self._window) / self._window) 336 | prev_num_requests = int(prev_num_requests * prev_window_weight) 337 | 338 | remaining = self._max_requests - (prev_num_requests + num_requests) 339 | 
yield max(0, remaining) 340 | 341 | def _get_reset(self, _: str) -> Generator: 342 | yield (None, None) # Signal that we don't need to make a remote call 343 | 344 | curr_window = now_ms() // self._window 345 | yield ms_to_s((curr_window + 1) * self._window) 346 | 347 | 348 | class TokenBucket(AbstractLimiter): 349 | """ 350 | A bucket is filled with maximum number of tokens that refill at a given 351 | rate per interval. 352 | 353 | Each request tries to consume one token and if the bucket is empty, 354 | the request is rejected. 355 | 356 | Pros: 357 | - Bursts of requests are smoothed out so that they can be processed at 358 | a constant rate. 359 | - Allows to set a higher initial burst limit by setting maximum number 360 | of tokens higher than the refill rate. 361 | """ 362 | 363 | SCRIPT = """ 364 | local key = KEYS[1] -- identifier including prefixes 365 | local max_tokens = tonumber(ARGV[1]) -- maximum number of tokens 366 | local interval = tonumber(ARGV[2]) -- size of the window in milliseconds 367 | local refill_rate = tonumber(ARGV[3]) -- how many tokens are refilled after each interval 368 | local now = tonumber(ARGV[4]) -- current timestamp in milliseconds 369 | local increment_by = tonumber(ARGV[5]) -- how many tokens to consume, default is 1 370 | 371 | local bucket = redis.call("HMGET", key, "refilled_at", "tokens") 372 | 373 | local refilled_at 374 | local tokens 375 | 376 | if bucket[1] == false then 377 | refilled_at = now 378 | tokens = max_tokens 379 | else 380 | refilled_at = tonumber(bucket[1]) 381 | tokens = tonumber(bucket[2]) 382 | end 383 | 384 | if now >= refilled_at + interval then 385 | local num_refills = math.floor((now - refilled_at) / interval) 386 | tokens = math.min(max_tokens, tokens + num_refills * refill_rate) 387 | 388 | refilled_at = refilled_at + num_refills * interval 389 | end 390 | 391 | if tokens == 0 then 392 | return {-1, refilled_at + interval} 393 | end 394 | 395 | local remaining = tokens - increment_by 396 | 
local expire_at = math.ceil(((max_tokens - remaining) / refill_rate)) * interval 397 | 398 | redis.call("HSET", key, "refilled_at", refilled_at, "tokens", remaining) 399 | redis.call("PEXPIRE", key, expire_at) 400 | return {remaining, refilled_at + interval} 401 | """ 402 | 403 | def __init__( 404 | self, max_tokens: int, refill_rate: int, interval: int, unit: UnitT = "s" 405 | ) -> None: 406 | """ 407 | :param max_tokens: Maximum number of tokens in a bucket. Since a newly 408 | created bucket starts with this many tokens, it can be used to 409 | allow higher burst limits. 410 | :param refill_rate: The number of tokens that are refilled per interval 411 | :param interval: The number of time units between each refill 412 | :param unit: The unit of time 413 | """ 414 | assert max_tokens > 0 415 | assert refill_rate > 0 416 | assert interval > 0 417 | 418 | self._max_tokens = max_tokens 419 | self._refill_rate = refill_rate 420 | self._interval = to_ms(interval, unit) 421 | 422 | def _limit(self, identifier: str, rate: int = 1) -> Generator: 423 | remaining, refill_at = yield ( 424 | "eval", 425 | ( 426 | TokenBucket.SCRIPT, 427 | [identifier], 428 | [self._max_tokens, self._interval, self._refill_rate, now_ms(), rate], 429 | ), 430 | ) 431 | 432 | yield Response( 433 | allowed=remaining >= 0, 434 | limit=self._max_tokens, 435 | remaining=max(0, remaining), 436 | reset=ms_to_s(refill_at), 437 | ) 438 | 439 | def _get_remaining(self, identifier: str) -> Generator: 440 | now = now_ms() 441 | 442 | refilled_at_, tokens_ = yield ( 443 | "hmget", 444 | (identifier, "refilled_at", "tokens"), 445 | ) 446 | 447 | if refilled_at_ is None: 448 | yield self._max_tokens 449 | 450 | refilled_at = int(refilled_at_) # type: ignore[arg-type] 451 | tokens = int(tokens_) # type: ignore[arg-type] 452 | 453 | if now >= refilled_at + self._interval: 454 | num_refills = (now - refilled_at) // self._interval 455 | tokens = min(self._max_tokens, tokens + num_refills * self._refill_rate) 456 
| 457 | yield tokens 458 | 459 | def _get_reset(self, identifier: str) -> Generator: 460 | now = now_ms() 461 | 462 | refilled_at_ = yield ( 463 | "hget", 464 | (identifier, "refilled_at"), 465 | ) 466 | 467 | if refilled_at_ is None: 468 | yield ms_to_s(now) 469 | 470 | refilled_at = int(refilled_at_) # type: ignore[arg-type] 471 | if now >= refilled_at + self._interval: 472 | num_refills = (now - refilled_at) // self._interval 473 | refilled_at = refilled_at + num_refills * self._interval 474 | 475 | yield ms_to_s(refilled_at + self._interval) 476 | -------------------------------------------------------------------------------- /upstash_ratelimit/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/upstash/ratelimit-py/bd70dfdb93ccfc0fea7a3d854c7942f8dce21f0a/upstash_ratelimit/py.typed -------------------------------------------------------------------------------- /upstash_ratelimit/ratelimit.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Optional 3 | 4 | from upstash_redis import Redis 5 | 6 | from upstash_ratelimit.limiter import Limiter, Response 7 | from upstash_ratelimit.utils import merge_telemetry, now_s 8 | 9 | 10 | class Ratelimit: 11 | """ 12 | Provides means of ratelimitting over the HTTP-based 13 | Upstash Redis client. 14 | """ 15 | 16 | def __init__( 17 | self, redis: Redis, limiter: Limiter, prefix: str = "@upstash/ratelimit" 18 | ) -> None: 19 | """ 20 | :param redis: Upstash Redis instance to use. 21 | :param limiter: Ratelimiter to use. Available limiters are \ 22 | `FixedWindow`, `SlidingWindow`, and `TokenBucket` which are provided \ 23 | in the `limiter` module. 24 | :param prefix: Prefix to distinguish the keys used in the ratelimit \ 25 | logic from others, in case the same Redis instance is reused between \ 26 | different applications. 
27 | """ 28 | 29 | self._redis = redis 30 | merge_telemetry(redis) 31 | 32 | self._limiter = limiter 33 | self._prefix = prefix 34 | 35 | def limit(self, identifier: str, rate: int = 1) -> Response: 36 | """ 37 | Determines if a request should pass or be rejected based on the identifier 38 | and previously chosen ratelimit. 39 | 40 | Use this if you want to reject all requests that you can not handle 41 | right now. 42 | 43 | .. code-block:: python 44 | 45 | from upstash_redis import Redis 46 | from upstash_ratelimit import Ratelimit, SlidingWindow 47 | 48 | ratelimit = Ratelimit( 49 | redis=Redis.from_env(), 50 | limiter=SlidingWindow(max_requests=10, window=10, unit="s"), 51 | ) 52 | 53 | response = ratelimit.limit("some-id") 54 | if not response.allowed: 55 | print("Ratelimitted!") 56 | 57 | print("Good to go!") 58 | 59 | :param identifier: Identifier to ratelimit. Use a constant string to \ 60 | limit all requests, or user ids, API keys, or IP addresses for \ 61 | individual limits. 62 | :param rate: Rate with which to subtract from the limit of the \ 63 | identifier. 64 | """ 65 | 66 | key = f"{self._prefix}:{identifier}" 67 | return self._limiter.limit(self._redis, key, rate) 68 | 69 | def block_until_ready(self, identifier: str, timeout: float, rate: int = 1) -> Response: 70 | """ 71 | Blocks until the request may pass or timeout is reached. 72 | 73 | This method blocks until the request may be processed or the timeout 74 | has been reached. 75 | 76 | Use this if you want to delay the request until it is ready to get 77 | processed. 78 | 79 | .. 
code-block:: python 80 | 81 | from upstash_redis import Redis 82 | from upstash_ratelimit import Ratelimit, SlidingWindow 83 | 84 | ratelimit = Ratelimit( 85 | redis=Redis.from_env(), 86 | limiter=SlidingWindow(max_requests=10, window=10, unit="s"), 87 | ) 88 | 89 | response = ratelimit.block_until_ready("some-id", 60) 90 | if not response.allowed: 91 | print("Ratelimitted!") 92 | 93 | print("Good to go!") 94 | 95 | :param identifier: Identifier to ratelimit. Use a constant string to \ 96 | limit all requests, or user ids, API keys, or IP addresses for \ 97 | individual limits. 98 | :param timeout: Maximum time in seconds to wait until the request \ 99 | may pass. 100 | :param rate: Rate with which to subtract from the limit of the \ 101 | identifier. 102 | """ 103 | 104 | if timeout <= 0: 105 | raise ValueError("Timeout must be positive") 106 | 107 | response: Optional[Response] = None 108 | deadline = now_s() + timeout 109 | 110 | while True: 111 | response = self.limit(identifier, rate) 112 | if response.allowed: 113 | break 114 | 115 | wait = max(0, min(response.reset, deadline) - now_s()) 116 | time.sleep(wait) 117 | 118 | if now_s() > deadline: 119 | break 120 | 121 | return response 122 | 123 | def get_remaining(self, identifier: str) -> int: 124 | """ 125 | Returns the number of requests left for the given identifier. 126 | """ 127 | 128 | key = f"{self._prefix}:{identifier}" 129 | return self._limiter.get_remaining(self._redis, key) 130 | 131 | def get_reset(self, identifier: str) -> float: 132 | """ 133 | Returns the UNIX timestamp in seconds when the remaining 134 | requests will be reset or replenished. 
135 | """ 136 | 137 | key = f"{self._prefix}:{identifier}" 138 | return self._limiter.get_reset(self._redis, key) 139 | -------------------------------------------------------------------------------- /upstash_ratelimit/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | UnitT = Literal["ms", "s", "m", "h", "d"] 4 | """ 5 | ms: milliseconds 6 | s: seconds 7 | m: minutes 8 | h: hours 9 | d: days 10 | """ 11 | -------------------------------------------------------------------------------- /upstash_ratelimit/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Any 3 | 4 | from upstash_ratelimit import __version__ 5 | from upstash_ratelimit.typing import UnitT 6 | 7 | 8 | def merge_telemetry(redis: Any) -> None: 9 | if ( 10 | not hasattr(redis, "_allow_telemetry") 11 | or not redis._allow_telemetry 12 | or not hasattr(redis, "_headers") 13 | ): 14 | return 15 | 16 | sdk = redis._headers.get("Upstash-Telemetry-Sdk") 17 | if not sdk: 18 | return 19 | 20 | sdk = f"{sdk}, py-upstash-ratelimit@v{__version__}" 21 | redis._headers["Upstash-Telemetry-Sdk"] = sdk 22 | 23 | 24 | def ms_to_s(value: int) -> float: 25 | return value / 1_000 26 | 27 | 28 | def s_to_ms(value: float) -> int: 29 | return int(value * 1_000) 30 | 31 | 32 | def now_s() -> float: 33 | return time.time() 34 | 35 | 36 | def now_ms() -> int: 37 | return int(time.time() * 1_000) 38 | 39 | 40 | def to_ms(value: int, unit: UnitT) -> int: 41 | if unit == "ms": 42 | return value 43 | elif unit == "s": 44 | return value * 1_000 45 | elif unit == "m": 46 | return value * 60 * 1_000 47 | elif unit == "h": 48 | return value * 60 * 60 * 1_000 49 | elif unit == "d": 50 | return value * 24 * 60 * 60 * 1_000 51 | else: 52 | raise ValueError("Unexpected unit") 53 | --------------------------------------------------------------------------------