├── .github
│   └── workflows
│       ├── release.yaml
│       └── tests.yml
├── .gitignore
├── LICENSE
├── README.md
├── examples
│   ├── aws-lambda
│   │   ├── README.md
│   │   ├── main.py
│   │   └── script.sh
│   ├── flask-server
│   │   ├── README.md
│   │   └── hello.py
│   └── vercel
│       ├── .gitignore
│       ├── Pipfile
│       ├── api
│       │   └── index.py
│       └── vercel.json
├── pyproject.toml
├── tests
│   ├── __init__.py
│   ├── asyncio
│   │   ├── __init__.py
│   │   ├── test_block_until_ready.py
│   │   ├── test_fixed_window.py
│   │   ├── test_sliding_window.py
│   │   └── test_token_bucket.py
│   ├── conftest.py
│   ├── test_block_until_ready.py
│   ├── test_fixed_window.py
│   ├── test_sliding_window.py
│   ├── test_token_bucket.py
│   ├── test_utils.py
│   └── utils.py
└── upstash_ratelimit
    ├── __init__.py
    ├── asyncio
    │   ├── __init__.py
    │   └── ratelimit.py
    ├── limiter.py
    ├── py.typed
    ├── ratelimit.py
    ├── typing.py
    └── utils.py
/.github/workflows/release.yaml:
--------------------------------------------------------------------------------
name: Release

# Manually triggered: builds the package, publishes it to PyPI and creates a
# GitHub release tagged with the version from pyproject.toml.
on: workflow_dispatch

jobs:
  release:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          # Quoted so YAML keeps it a string (an unquoted 3.10 would become 3.1).
          python-version: "3.8"

      - name: Install Poetry
        run: curl -sSL https://install.python-poetry.org | python3 -

      - name: Build and publish
        run: |
          poetry config pypi-token.pypi "${{ secrets.PYPI_TOKEN }}"
          poetry build
          poetry publish --no-interaction

      # `poetry version` prints "<name> <version>"; the tag becomes v<version>.
      - name: Generate release tag
        run: echo "RELEASE_TAG=v$(poetry version | awk '{print $2}')" >> "$GITHUB_ENV"

      - name: Create GitHub Release
        uses: actions/create-release@v1
        env:
          GITHUB_TOKEN: ${{ secrets.PERSONAL_ACCESS_TOKEN }}
        with:
          tag_name: ${{ env.RELEASE_TAG }}
          release_name: Release ${{ env.RELEASE_TAG }}
          draft: false
          prerelease: false
39 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
name: Test

# Type-checks and runs the test suite against a live Upstash Redis database
# for every pull request targeting main.
on:
  pull_request:
    branches:
      - main

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          # Quoted so YAML keeps it a string (an unquoted 3.10 would become 3.1).
          python-version: "3.8"

      - name: Install Poetry
        run: curl -sSL https://install.python-poetry.org | python3 -

      - name: Set up Poetry environment
        run: |
          poetry cache clear PyPI --all
          poetry install --no-root

      - name: Run mypy
        run: |
          poetry run mypy --show-error-codes .

      # The tests talk to a real database; credentials come from repo secrets.
      - name: Run tests
        env:
          UPSTASH_REDIS_REST_URL: ${{ secrets.UPSTASH_REDIS_REST_URL }}
          UPSTASH_REDIS_REST_TOKEN: ${{ secrets.UPSTASH_REDIS_REST_TOKEN }}
        run: poetry run pytest
38 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | .DS_Store
163 |
164 | .vscode
165 |
166 | poetry.lock
167 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Tudor Zgîmbău
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Upstash Ratelimit Python SDK
2 |
3 | > [!NOTE]
4 | > **This project is in GA Stage.**
5 | > The Upstash Professional Support fully covers this project. It receives regular updates and bug fixes. The Upstash team is committed to maintaining and improving its functionality.
6 |
7 | `upstash-ratelimit` is a connectionless rate limiting library for Python, designed to be used in serverless environments such as:
8 |
9 | - AWS Lambda
10 | - Vercel Serverless
11 | - Google Cloud Functions
12 | - and other environments where HTTP is preferred over TCP.
13 |
14 | The SDK is currently compatible with Python 3.8 and above.
15 |
16 |
17 | - [Upstash Ratelimit Python SDK](#upstash-ratelimit-python-sdk)
18 | - [Quick Start](#quick-start)
19 | - [Install](#install)
20 | - [Create database](#create-database)
21 | - [Usage](#usage)
22 | - [Block until ready](#block-until-ready)
23 | - [Using multiple limits](#using-multiple-limits)
24 | - [Ratelimiting algorithms](#ratelimiting-algorithms)
25 | - [Fixed Window](#fixed-window)
26 | - [Pros](#pros)
27 | - [Cons](#cons)
28 | - [Usage](#usage-1)
29 | - [Sliding Window](#sliding-window)
30 | - [Pros](#pros-1)
31 | - [Cons](#cons-1)
32 | - [Usage](#usage-2)
33 | - [Token Bucket](#token-bucket)
34 | - [Pros](#pros-2)
35 | - [Cons](#cons-2)
36 | - [Usage](#usage-3)
37 | - [Contributing](#contributing)
38 | - [Preparing the environment](#preparing-the-environment)
39 | - [Running tests](#running-tests)
40 |
41 |
42 | # Quick Start
43 |
44 | ## Install
45 |
46 | ```bash
47 | pip install upstash-ratelimit
48 | ```
49 |
50 | ## Create database
51 | To be able to use upstash-ratelimit, you need to create a database on [Upstash](https://console.upstash.com/).
52 |
53 | ## Usage
54 |
55 | For possible Redis client configurations, have a look at the [Redis SDK repository](https://github.com/upstash/redis-python).
56 |
57 | > This library supports asyncio as well. To use it, import the asyncio-based
58 | variant from the `upstash_ratelimit.asyncio` module.
59 |
60 | ```python
61 | from upstash_ratelimit import Ratelimit, FixedWindow
62 | from upstash_redis import Redis
63 |
64 | # Create a new ratelimiter, that allows 10 requests per 10 seconds
65 | ratelimit = Ratelimit(
66 | redis=Redis.from_env(),
67 | limiter=FixedWindow(max_requests=10, window=10),
68 | # Optional prefix for the keys used in Redis. This is useful
69 | # if you want to share a Redis instance with other applications
70 | # and want to avoid key collisions. The default prefix is
71 | # "@upstash/ratelimit"
72 | prefix="@upstash/ratelimit",
73 | )
74 |
75 | # Use a constant string to limit all requests with a single ratelimit
76 | # Or use a user ID, API key or IP address for individual limits.
77 | identifier = "api"
78 | response = ratelimit.limit(identifier)
79 |
80 | if not response.allowed:
81 | print("Unable to process at this time")
82 | else:
83 | do_expensive_calculation()
84 | print("Here you go!")
85 |
86 | ```
87 |
88 | The `limit` method also returns the following metadata:
89 |
90 |
91 | ```python
92 | @dataclasses.dataclass
93 | class Response:
94 | allowed: bool
95 | """
96 | Whether the request may pass(`True`) or exceeded the limit(`False`)
97 | """
98 |
99 | limit: int
100 | """
101 | Maximum number of requests allowed within a window.
102 | """
103 |
104 | remaining: int
105 | """
106 | How many requests the user has left within the current window.
107 | """
108 |
109 | reset: float
110 | """
111 | Unix timestamp in seconds when the limits are reset
112 | """
113 | ```
114 |
115 | ## Block until ready
116 |
117 | You can also wait for a request to be allowed, up to a given timeout.
118 |
119 | It is very similar to the `limit` method and takes an identifier and returns the same
120 | response. However, if the current limit has already been exceeded, it will automatically
121 | wait until the next window starts and will try again. Setting the timeout parameter (in seconds) will cause the method to block a finite amount of time.
122 |
123 | ```python
124 | from upstash_ratelimit import Ratelimit, SlidingWindow
125 | from upstash_redis import Redis
126 |
127 | # Create a new ratelimiter, that allows 10 requests per 10 seconds
128 | ratelimit = Ratelimit(
129 | redis=Redis.from_env(),
130 | limiter=SlidingWindow(max_requests=10, window=10),
131 | )
132 |
133 | response = ratelimit.block_until_ready("id", timeout=30)
134 |
135 | if not response.allowed:
136 | print("Unable to process, even after 30 seconds")
137 | else:
138 | do_expensive_calculation()
139 | print("Here you go!")
140 | ```
141 |
142 |
143 | ## Using multiple limits
144 | Sometimes you might want to apply different limits to different users. For example, you might want to allow 10 requests per 10 seconds for free users, but 60 requests per 10 seconds for paid users.
145 |
146 | Here's how you could do that:
147 |
148 | ```python
149 | from upstash_ratelimit import Ratelimit, SlidingWindow
150 | from upstash_redis import Redis
151 |
152 | class MultiRL:
153 | def __init__(self) -> None:
154 | redis = Redis.from_env()
155 | self.free = Ratelimit(
156 | redis=redis,
157 | limiter=SlidingWindow(max_requests=10, window=10),
158 | prefix="ratelimit:free",
159 | )
160 |
161 | self.paid = Ratelimit(
162 | redis=redis,
163 | limiter=SlidingWindow(max_requests=60, window=10),
164 | prefix="ratelimit:paid",
165 | )
166 |
167 | # Free users get 10 requests per 10 seconds, paid users get 60 requests per 10 seconds
168 | ratelimit = MultiRL()
169 |
170 | ratelimit.free.limit("userIP")
171 | ratelimit.paid.limit("userIP")
172 | ```
173 |
174 | # Ratelimiting algorithms
175 |
176 | ## Fixed Window
177 |
178 | This algorithm divides time into fixed durations/windows. For example, each window is 10 seconds long. When a new request comes in, the current time is used to determine the window and a counter is incremented. If the counter is larger than the set limit, the request is rejected.
179 |
180 | ### Pros
181 | - Very cheap in terms of data size and computation
182 | - Newer requests are not starved due to a high burst in the past
183 |
184 | ### Cons
185 | - Can cause high bursts at the window boundaries to leak through
186 | - Causes request stampedes if many users are trying to access your server, whenever a new window begins
187 |
188 | ### Usage
189 |
190 | ```python
191 | from upstash_ratelimit import Ratelimit, FixedWindow
192 | from upstash_redis import Redis
193 |
194 | ratelimit = Ratelimit(
195 | redis=Redis.from_env(),
196 | limiter=FixedWindow(max_requests=10, window=10),
197 | )
198 | ```
199 |
200 | ## Sliding Window
201 |
202 | Builds on top of fixed window but instead of a fixed window, we use a rolling window. Take this example: We have a rate limit of 10 requests per 1 minute. We divide time into 1 minute slices, just like in the fixed window algorithm. Window 1 will be from 00:00:00 to 00:01:00 (HH:MM:SS). Let's assume it is currently 00:01:15 and we have received 4 requests in the first window and 5 requests so far in the current window. The approximation to determine if the request should pass works like this:
203 |
204 | ```python
205 | limit = 10
206 |
207 | # 4 requests from the old window (weighted) + 5 requests in the current window
208 | rate = 4 * ((60 - 15) / 60) + 5 = 8
209 |
210 | return rate < limit # True means we should allow the request
211 | ```
212 |
213 | ### Pros
214 | - Solves the burst-at-the-boundary issue of the fixed window algorithm.
215 |
216 | ### Cons
217 | - More expensive in terms of storage and computation
218 | - It's only an approximation because it assumes a uniform request flow in the previous window
219 |
220 | ### Usage
221 |
222 | ```python
223 | from upstash_ratelimit import Ratelimit, SlidingWindow
224 | from upstash_redis import Redis
225 |
226 | ratelimit = Ratelimit(
227 | redis=Redis.from_env(),
228 | limiter=SlidingWindow(max_requests=10, window=10),
229 | )
230 | ```
231 |
232 | ## Token Bucket
233 |
234 | Consider a bucket filled with a maximum number of tokens that is refilled at a constant rate per interval. Every request removes one token from the bucket; if there is no token to take, the request is rejected.
235 |
236 | ### Pros
237 | - Bursts of requests are smoothed out and you can process them at a constant rate.
238 | - Allows setting a higher initial burst limit by setting the maximum number of tokens higher than the refill rate
239 |
240 | ### Cons
241 | - Expensive in terms of computation
242 |
243 | ### Usage
244 |
245 | ```python
246 | from upstash_ratelimit import Ratelimit, TokenBucket
247 | from upstash_redis import Redis
248 |
249 | ratelimit = Ratelimit(
250 | redis=Redis.from_env(),
251 | limiter=TokenBucket(max_tokens=10, refill_rate=5, interval=10),
252 | )
253 | ```
254 |
255 | # Custom Rates
256 |
257 | When rate limiting, you may want different requests to consume different amounts of tokens.
258 | This could be useful when processing batches of requests where you want to rate limit based
259 | on items in the batch or when you want to rate limit based on the number of tokens.
260 |
261 | To achieve this, pass the `rate` parameter when calling the `limit` method:
262 |
263 | ```python
264 |
265 | from upstash_ratelimit import Ratelimit, FixedWindow
266 | from upstash_redis import Redis
267 |
268 | ratelimit = Ratelimit(
269 | redis=Redis.from_env(),
270 | limiter=FixedWindow(max_requests=10, window=10),
271 | )
272 |
273 | # pass rate as 5 to subtract 5 from the number of
274 | # allowed requests in the window:
275 | identifier = "api"
276 | response = ratelimit.limit(identifier, rate=5)
277 | ```
278 |
279 | # Contributing
280 |
281 | ## Preparing the environment
282 | This project uses [Poetry](https://python-poetry.org) for packaging and dependency management. Make sure you can create a Poetry environment with the project's dependencies (for example with `poetry install`).
283 |
284 | You will also need a database on [Upstash](https://console.upstash.com/).
285 |
286 | ## Running tests
287 | To run all the tests, make sure the Poetry virtual environment is activated with all
288 | the necessary dependencies. Set the `UPSTASH_REDIS_REST_URL` and `UPSTASH_REDIS_REST_TOKEN` environment variables and run:
289 |
290 | ```bash
291 | poetry run pytest
292 | ```
293 |
--------------------------------------------------------------------------------
/examples/aws-lambda/README.md:
--------------------------------------------------------------------------------
1 | ### AWS Lambda Example
2 |
3 | #### Run
4 | ```bash
5 | ./script.sh
6 | ```
7 | This script installs the necessary dependencies and generates a `lambda.zip` file ready for upload.
8 | #### Test
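9 | Upload the `lambda.zip` file to AWS Lambda, set the `UPSTASH_REDIS_REST_URL` and `UPSTASH_REDIS_REST_TOKEN` environment variables on the function, and test by calling the endpoint.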
--------------------------------------------------------------------------------
/examples/aws-lambda/main.py:
--------------------------------------------------------------------------------
1 | from upstash_redis import Redis
2 |
3 | from upstash_ratelimit import FixedWindow, Ratelimit
4 |
# Shared limiter created once per Lambda container: 1 request per 10-unit
# fixed window. Credentials are read from the environment by Redis.from_env.
ratelimit = Ratelimit(
    redis=Redis.from_env(allow_telemetry=False),
    limiter=FixedWindow(max_requests=1, window=10),
)


def lambda_handler(event, context):
    """AWS Lambda entry point that rate limits every invocation.

    All invocations share the identifier ``"id"``, so the limit is global
    across callers. The ``event`` and ``context`` arguments are unused.

    Returns:
        A dict with ``statusCode``, ``headers`` and ``body``. This is the
        shape used by Lambda proxy / function-URL integrations (assumed to be
        how this function is exposed, per the example README). The status is
        200 when the request is allowed and 429 when it is rate limited.
    """
    response = ratelimit.limit("id")
    print(response)  # keep the full response visible in CloudWatch logs

    allowed = response.allowed
    return {
        "statusCode": 200 if allowed else 429,
        "headers": {
            "Content-Type": "text/plain",
            # Header values must be strings in the proxy response format.
            "X-Ratelimit-Limit": str(response.limit),
            "X-Ratelimit-Remaining": str(response.remaining),
            "X-Ratelimit-Reset": str(response.reset),
        },
        "body": "Hello!" if allowed else "Come back later!",
    }
14 |
--------------------------------------------------------------------------------
/examples/aws-lambda/script.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Build lambda.zip: the upstash-ratelimit package plus main.py, renamed to
# lambda_function.py (presumably to match the default Lambda handler
# module name lambda_function.lambda_handler).
set -euo pipefail

# Work relative to this script so it can be run from any directory.
cd "$(dirname "$0")"

# -f: don't fail when there is no previous build (e.g. the first run).
rm -rf dist lambda.zip

pip3 install --target ./dist upstash-ratelimit
cp main.py ./dist/lambda_function.py

# Zip the contents of dist (not the directory itself) into ./lambda.zip.
(cd dist && zip -r ../lambda.zip .)
--------------------------------------------------------------------------------
/examples/flask-server/README.md:
--------------------------------------------------------------------------------
1 | ### Simple Flask Server
2 |
3 | #### Run
4 | ```bash
5 | pip3 install flask upstash-ratelimit
6 | flask --app hello run
7 | ```
8 |
9 | #### Test
10 | With `UPSTASH_REDIS_REST_URL` and `UPSTASH_REDIS_REST_TOKEN` set in the environment, send requests to http://localhost:5000/request to see the rate limiter in action.
--------------------------------------------------------------------------------
/examples/flask-server/hello.py:
--------------------------------------------------------------------------------
1 | from flask import Flask # type: ignore
2 | from upstash_redis import Redis
3 |
4 | from upstash_ratelimit import FixedWindow, Ratelimit
5 |
# Allow 2 requests per 40-unit fixed window; credentials come from the
# environment via Redis.from_env.
ratelimit = Ratelimit(
    redis=Redis.from_env(allow_telemetry=False),
    limiter=FixedWindow(max_requests=2, window=40),
)

app = Flask(__name__)


@app.route("/")
def hello_world():
    """Static landing page; not rate limited."""
    return "<p>Hello, World!</p>"


# NOTE: this view is named ``request``, which would shadow ``flask.request``
# if that were ever imported into this module.
@app.route("/request")
def request():
    """Rate-limited endpoint.

    All callers share the identifier ``"timeout_1"``. When the limit is
    exhausted, ``block_until_ready`` waits up to 10 seconds for capacity
    before returning; ``response.allowed`` says whether the request passed.
    """
    response = ratelimit.block_until_ready("timeout_1", 10)
    return f"<p>{response}</p>"
23 |
--------------------------------------------------------------------------------
/examples/vercel/.gitignore:
--------------------------------------------------------------------------------
1 | .vercel
2 |
--------------------------------------------------------------------------------
/examples/vercel/Pipfile:
--------------------------------------------------------------------------------
1 | [[source]]
2 | url = "https://pypi.org/simple"
3 | verify_ssl = true
4 | name = "pypi"
5 |
6 | [packages]
7 | flask = "*"
8 | upstash-ratelimit = "*"
9 |
10 | [requires]
11 | python_version = "3.9"
12 |
--------------------------------------------------------------------------------
/examples/vercel/api/index.py:
--------------------------------------------------------------------------------
1 | from http.server import BaseHTTPRequestHandler
2 |
3 | from upstash_redis import Redis
4 |
5 | from upstash_ratelimit import FixedWindow, Ratelimit
6 |
# Allow 5 requests per 5-unit fixed window, shared by all callers.
ratelimit = Ratelimit(
    redis=Redis.from_env(allow_telemetry=False),
    limiter=FixedWindow(max_requests=5, window=5),
)


class handler(BaseHTTPRequestHandler):
    """Request handler for /api/index (vercel.json rewrites every path here)."""

    def do_GET(self):
        """Apply the global rate limit and reply with 200 or 429.

        The X-Ratelimit-* headers expose the limiter's limit, remaining count
        and reset timestamp.
        """
        response = ratelimit.limit("global")

        if response.allowed:
            status, body = 200, "Hello!"
        else:
            status, body = 429, "Come back later!"

        # send_response() queues the status line in the same header buffer
        # that send_header() appends to, and end_headers() flushes it, so
        # the status must come first and the body only after end_headers().
        self.send_response(status)
        self.send_header("Content-type", "text/plain")
        self.send_header("X-Ratelimit-Limit", str(response.limit))
        self.send_header("X-Ratelimit-Remaining", str(response.remaining))
        self.send_header("X-Ratelimit-Reset", str(response.reset))
        self.end_headers()

        self.wfile.write(body.encode("utf-8"))
28 |
--------------------------------------------------------------------------------
/examples/vercel/vercel.json:
--------------------------------------------------------------------------------
1 | {
2 | "rewrites": [
3 | {
4 | "source": "/(.*)",
5 | "destination": "/api/index"
6 | }
7 | ]
8 | }
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "upstash-ratelimit"
3 | version = "1.1.0"
4 | description = "Serverless ratelimiting package from Upstash"
5 | authors = ["Upstash ", "Zgîmbău Tudor "]
6 | maintainers = ["Upstash "]
7 | readme = "README.md"
8 | repository = "https://github.com/upstash/ratelimit-python"
9 | keywords = ["ratelimit", "rate limit", "Upstash rate limit", "Redis rate limit"]
10 | classifiers = [
11 | "Development Status :: 5 - Production/Stable",
12 | "Intended Audience :: Developers",
13 | "License :: OSI Approved :: MIT License",
14 | "Operating System :: OS Independent",
15 | "Programming Language :: Python",
16 | "Programming Language :: Python :: 3",
17 | "Programming Language :: Python :: 3 :: Only",
18 | "Programming Language :: Python :: 3.8",
19 | "Programming Language :: Python :: 3.9",
20 | "Programming Language :: Python :: 3.10",
21 | "Programming Language :: Python :: 3.11",
22 | "Programming Language :: Python :: 3.12",
23 | "Programming Language :: Python :: Implementation :: CPython",
24 | "Topic :: Software Development :: Libraries",
25 | ]
26 | packages = [{include = "upstash_ratelimit"}]
27 |
28 | [tool.poetry.dependencies]
29 | python = "^3.8"
30 | upstash-redis = "^1.0.0"
31 |
32 | [tool.poetry.group.dev.dependencies]
33 | pytest = "^7.3.0"
34 | pytest-asyncio = "^0.21.0"
35 | mypy = "^1.4.1"
36 |
37 | [build-system]
38 | requires = ["poetry-core"]
39 | build-backend = "poetry.core.masonry.api"
40 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/upstash/ratelimit-py/bd70dfdb93ccfc0fea7a3d854c7942f8dce21f0a/tests/__init__.py
--------------------------------------------------------------------------------
/tests/asyncio/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/upstash/ratelimit-py/bd70dfdb93ccfc0fea7a3d854c7942f8dce21f0a/tests/asyncio/__init__.py
--------------------------------------------------------------------------------
/tests/asyncio/test_block_until_ready.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | from pytest import mark, raises
4 | from upstash_redis.asyncio import Redis
5 |
6 | from tests.utils import random_id
7 | from upstash_ratelimit.asyncio import FixedWindow, Ratelimit
8 |
9 |
@mark.asyncio()
async def test_invalid_timeout(async_redis: Redis) -> None:
    """A negative timeout is rejected with ValueError."""
    limiter = FixedWindow(max_requests=1, window=1)
    ratelimit = Ratelimit(redis=async_redis, limiter=limiter)

    with raises(ValueError):
        await ratelimit.block_until_ready(random_id(), -1)
19 |
20 |
@mark.asyncio()
async def test_resolve_before_timeout(async_redis: Redis) -> None:
    """With capacity left, block_until_ready returns allowed before the timeout."""
    limiter = FixedWindow(max_requests=5, window=100)
    ratelimit = Ratelimit(redis=async_redis, limiter=limiter)
    timeout = 50

    started_at = time.time()
    response = await ratelimit.block_until_ready(random_id(), timeout)
    elapsed = time.time() - started_at

    assert elapsed < timeout
    assert response.allowed is True
36 |
37 |
@mark.asyncio()
async def test_resolve_before_timeout_when_window_resets(
    async_redis: Redis,
) -> None:
    """After exhausting the window, waiting succeeds once a new window starts."""
    ratelimit = Ratelimit(
        redis=async_redis,
        limiter=FixedWindow(max_requests=1, window=3),
    )
    identifier, timeout = random_id(), 100

    # Use up the single request allowed in the current window.
    await ratelimit.limit(identifier)

    started_at = time.time()
    response = await ratelimit.block_until_ready(identifier, timeout)
    elapsed = time.time() - started_at

    assert elapsed < timeout
    assert response.allowed is True
58 |
59 |
@mark.asyncio()
async def test_reaching_timeout(async_redis: Redis) -> None:
    """If the window cannot reset in time, the call gives up after ``timeout``."""
    # A window far longer than the timeout (unit "d", presumably days).
    ratelimit = Ratelimit(
        redis=async_redis,
        limiter=FixedWindow(max_requests=1, window=1, unit="d"),
    )
    identifier = random_id()
    timeout = 2

    await ratelimit.limit(identifier)

    started_at = time.time()
    response = await ratelimit.block_until_ready(identifier, timeout)
    elapsed = time.time() - started_at

    assert elapsed >= timeout
    assert response.allowed is False
78 |
--------------------------------------------------------------------------------
/tests/asyncio/test_fixed_window.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from unittest.mock import patch
3 |
4 | from pytest import approx, mark
5 | from upstash_redis.asyncio import Redis
6 |
7 | from tests.utils import random_id
8 | from upstash_ratelimit.asyncio import FixedWindow, Ratelimit
9 | from upstash_ratelimit.utils import now_s
10 |
11 |
@mark.asyncio()
async def test_max_requests_are_not_reached(async_redis: Redis) -> None:
    """A first request is allowed and consumes exactly one slot."""
    ratelimit = Ratelimit(
        redis=async_redis,
        limiter=FixedWindow(max_requests=5, window=10),
    )

    before = now_s()
    response = await ratelimit.limit(random_id())

    assert response.allowed is True
    assert (response.limit, response.remaining) == (5, 4)
    assert response.reset >= before
26 |
27 |
@mark.asyncio()
async def test_max_requests_are_reached(async_redis: Redis) -> None:
    """Once the only slot is used, the next request is rejected."""
    ratelimit = Ratelimit(
        redis=async_redis,
        limiter=FixedWindow(max_requests=1, window=1, unit="d"),
    )
    identifier = random_id()

    await ratelimit.limit(identifier)  # consume the single allowed request

    before = now_s()
    response = await ratelimit.limit(identifier)

    assert response.allowed is False
    assert (response.limit, response.remaining) == (1, 0)
    assert response.reset >= before
46 |
47 |
@mark.asyncio()
async def test_window_reset(async_redis: Redis) -> None:
    """After sleeping through the window, requests are allowed again."""
    ratelimit = Ratelimit(
        redis=async_redis,
        limiter=FixedWindow(max_requests=1, window=3),
    )
    identifier = random_id()

    await ratelimit.limit(identifier)
    await asyncio.sleep(3)  # wait out the whole window

    before = now_s()
    response = await ratelimit.limit(identifier)

    assert response.allowed is True
    assert (response.limit, response.remaining) == (1, 0)
    assert response.reset >= before
68 |
69 |
@mark.asyncio()
async def test_get_remaining(async_redis: Redis) -> None:
    """get_remaining reflects the requests consumed so far."""
    ratelimit = Ratelimit(
        redis=async_redis,
        limiter=FixedWindow(max_requests=10, window=1, unit="d"),
    )
    identifier = random_id()

    remaining_before = await ratelimit.get_remaining(identifier)
    assert remaining_before == 10

    await ratelimit.limit(identifier)

    remaining_after = await ratelimit.get_remaining(identifier)
    assert remaining_after == 9
81 |
82 |
@mark.asyncio()
async def test_get_reset(async_redis: Redis) -> None:
    """get_reset is computed from the (patched) current time."""
    ratelimit = Ratelimit(
        redis=async_redis,
        limiter=FixedWindow(max_requests=10, window=5),
    )
    frozen_now = 1688910786.167

    # Expected reset is the next multiple of the 5-unit window after frozen_now.
    with patch("time.time", return_value=frozen_now):
        reset = await ratelimit.get_reset(random_id())
        assert reset == approx(1688910790.0)
92 |
--------------------------------------------------------------------------------
/tests/asyncio/test_sliding_window.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from typing import List
3 | from unittest.mock import patch
4 |
5 | from pytest import approx, mark
6 | from upstash_redis.asyncio import Redis
7 |
8 | from tests.utils import random_id
9 | from upstash_ratelimit.asyncio import Ratelimit, Response, SlidingWindow
10 | from upstash_ratelimit.utils import now_s
11 |
12 |
@mark.asyncio()
async def test_max_requests_are_not_reached(async_redis: Redis) -> None:
    """A first request is allowed and consumes exactly one slot."""
    ratelimit = Ratelimit(
        redis=async_redis,
        limiter=SlidingWindow(max_requests=5, window=10),
    )

    before = now_s()
    response = await ratelimit.limit(random_id())

    assert response.allowed is True
    assert (response.limit, response.remaining) == (5, 4)
    assert response.reset >= before
27 |
28 |
@mark.asyncio()
async def test_max_requests_are_reached(async_redis: Redis) -> None:
    """Once the only slot is used, the next request is rejected."""
    ratelimit = Ratelimit(
        redis=async_redis,
        limiter=SlidingWindow(max_requests=1, window=1, unit="d"),
    )
    identifier = random_id()

    await ratelimit.limit(identifier)  # consume the single allowed request

    before = now_s()
    response = await ratelimit.limit(identifier)

    assert response.allowed is False
    assert (response.limit, response.remaining) == (1, 0)
    assert response.reset >= before
47 |
48 |
@mark.asyncio()
async def test_window_reset(async_redis: Redis) -> None:
    """After sleeping through the window, requests are allowed again."""
    ratelimit = Ratelimit(
        redis=async_redis,
        limiter=SlidingWindow(max_requests=1, window=3),
    )
    identifier = random_id()

    await ratelimit.limit(identifier)
    await asyncio.sleep(3)  # wait out the whole window

    before = now_s()
    response = await ratelimit.limit(identifier)

    assert response.allowed is True
    assert (response.limit, response.remaining) == (1, 0)
    assert response.reset >= before
69 |
70 |
@mark.asyncio()
async def test_sliding(async_redis: Redis) -> None:
    """Requests from the previous window still count after a boundary."""
    limiter = SlidingWindow(max_requests=5, window=3)
    ratelimit = Ratelimit(redis=async_redis, limiter=limiter)
    identifier = random_id()

    responses: List[Response] = []

    # Keep firing until, after at least six requests, the reported reset
    # changes -- i.e. a window boundary has just been crossed.
    while True:
        latest = await ratelimit.limit(identifier)
        responses.append(latest)

        if len(responses) <= 5:
            continue
        if responses[-2].reset != latest.reset:
            break

    # Some items from the previous window are still considered, so the
    # remaining count must differ from max_requests - 1.
    assert responses[-1].remaining != 4
94 |
95 |
@mark.asyncio()
async def test_get_remaining(async_redis: Redis) -> None:
    """get_remaining reports the full quota, then one less after a request."""
    limiter = SlidingWindow(max_requests=10, window=1, unit="d")
    ratelimit = Ratelimit(redis=async_redis, limiter=limiter)
    identifier = random_id()

    before = await ratelimit.get_remaining(identifier)
    assert before == 10

    await ratelimit.limit(identifier)

    after = await ratelimit.get_remaining(identifier)
    assert after == 9
107 |
108 |
@mark.asyncio()
async def test_get_remaining_with_sliding(async_redis: Redis) -> None:
    """
    get_remaining is consistent with the last limit() response after a slide.

    The comparison is ``>=`` (matching the sync version of this test) rather
    than ``==``: time keeps passing between the final ``limit`` call and the
    ``get_remaining`` call, and as it does the previous window's contribution
    can only decrease, freeing up requests rather than consuming them.
    """
    ratelimit = Ratelimit(
        redis=async_redis,
        limiter=SlidingWindow(max_requests=5, window=3),
    )

    identifier = random_id()

    responses: List[Response] = []

    # Fire requests until, after at least six, the reset changes (a window
    # boundary was crossed).
    while True:
        response = await ratelimit.limit(identifier)
        responses.append(response)

        if len(responses) > 5 and responses[-2].reset != response.reset:
            break

    last_response = responses[-1]
    assert await ratelimit.get_remaining(identifier) >= last_response.remaining
129 |
130 |
@mark.asyncio()
async def test_get_reset(async_redis: Redis) -> None:
    """With the clock frozen, the reset lands at the end of the 5 s window."""
    limiter = SlidingWindow(max_requests=10, window=5)
    ratelimit = Ratelimit(redis=async_redis, limiter=limiter)

    frozen_now = 1688910786.167
    with patch("time.time", return_value=frozen_now):
        reset = await ratelimit.get_reset(random_id())
        assert reset == approx(1688910790.0)
140 |
--------------------------------------------------------------------------------
/tests/asyncio/test_token_bucket.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 | from pytest import mark
4 | from upstash_redis.asyncio import Redis
5 |
6 | from tests.utils import random_id
7 | from upstash_ratelimit.asyncio import Ratelimit, TokenBucket
8 | from upstash_ratelimit.utils import now_s
9 |
10 |
@mark.asyncio()
async def test_max_tokens_are_not_reached(async_redis: Redis) -> None:
    """A full bucket lets the first request through and spends one token."""
    limiter = TokenBucket(max_tokens=5, refill_rate=5, interval=1, unit="d")
    ratelimit = Ratelimit(redis=async_redis, limiter=limiter)

    started_at = now_s()
    response = await ratelimit.limit(random_id())

    assert response.allowed is True
    assert (response.limit, response.remaining) == (5, 4)
    assert response.reset >= started_at
25 |
26 |
@mark.asyncio()
async def test_max_tokens_are_reached(async_redis: Redis) -> None:
    """With a one-token bucket, the second request is rejected."""
    limiter = TokenBucket(max_tokens=1, refill_rate=1, interval=1, unit="d")
    ratelimit = Ratelimit(redis=async_redis, limiter=limiter)
    identifier = random_id()

    # Drain the single token.
    await ratelimit.limit(identifier)

    started_at = now_s()
    response = await ratelimit.limit(identifier)

    assert response.allowed is False
    assert (response.limit, response.remaining) == (1, 0)
    assert response.reset >= started_at
45 |
46 |
@mark.asyncio()
async def test_refill(async_redis: Redis) -> None:
    """After one refill interval, a drained bucket admits a request again."""
    limiter = TokenBucket(max_tokens=1, refill_rate=1, interval=3)
    ratelimit = Ratelimit(redis=async_redis, limiter=limiter)
    identifier = random_id()

    await ratelimit.limit(identifier)
    await asyncio.sleep(3)  # one full refill interval

    started_at = now_s()
    response = await ratelimit.limit(identifier)

    assert response.allowed is True
    assert (response.limit, response.remaining) == (1, 0)
    assert response.reset >= started_at
67 |
68 |
@mark.asyncio()
async def test_refill_multiple_times(async_redis: Redis) -> None:
    """Several elapsed refill intervals add several tokens back."""
    limiter = TokenBucket(max_tokens=1000, refill_rate=1, interval=1)
    ratelimit = Ratelimit(redis=async_redis, limiter=limiter)
    identifier = random_id()

    # Ten sequential requests; only the last response is inspected.
    responses = [await ratelimit.limit(identifier) for _ in range(10)]
    remaining_before = responses[-1].remaining

    await asyncio.sleep(3)  # roughly three one-second refill intervals

    response = await ratelimit.limit(identifier)
    assert response.remaining >= remaining_before + 2
89 |
90 |
@mark.asyncio()
async def test_get_remaining(async_redis: Redis) -> None:
    """get_remaining reports a full bucket, then one token fewer."""
    limiter = TokenBucket(max_tokens=10, refill_rate=10, interval=1, unit="d")
    ratelimit = Ratelimit(redis=async_redis, limiter=limiter)
    identifier = random_id()

    assert await ratelimit.get_remaining(identifier) == 10

    await ratelimit.limit(identifier)
    tokens_left = await ratelimit.get_remaining(identifier)
    assert tokens_left == 9
102 |
103 |
@mark.asyncio()
async def test_get_remaining_with_refills_that_should_be_made(
    async_redis: Redis,
) -> None:
    """get_remaining includes refills that became due since the last request."""
    limiter = TokenBucket(max_tokens=1000, refill_rate=1, interval=1)
    ratelimit = Ratelimit(redis=async_redis, limiter=limiter)
    identifier = random_id()

    responses = [await ratelimit.limit(identifier) for _ in range(10)]
    remaining_before = responses[-1].remaining

    await asyncio.sleep(3)

    assert await ratelimit.get_remaining(identifier) >= remaining_before + 2
125 |
126 |
@mark.asyncio()
async def test_get_reset(async_redis: Redis) -> None:
    """After spending a token, the next refill is about one interval away."""
    limiter = TokenBucket(max_tokens=1000, refill_rate=1, interval=1)
    ratelimit = Ratelimit(redis=async_redis, limiter=limiter)
    identifier = random_id()

    started_at = now_s()
    await ratelimit.limit(identifier)

    reset = await ratelimit.get_reset(identifier)
    assert reset >= started_at + 0.9
139 |
140 |
@mark.asyncio()
async def test_get_reset_with_refills_that_should_be_made(
    async_redis: Redis,
) -> None:
    """get_reset moves forward as refill intervals elapse."""
    limiter = TokenBucket(max_tokens=1000, refill_rate=1, interval=1)
    ratelimit = Ratelimit(redis=async_redis, limiter=limiter)
    identifier = random_id()

    responses = [await ratelimit.limit(identifier) for _ in range(10)]
    reset_before = responses[-1].reset

    await asyncio.sleep(3)

    assert await ratelimit.get_reset(identifier) >= reset_before + 2
162 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest_asyncio
2 | from pytest import fixture
3 | from upstash_redis import Redis
4 | from upstash_redis.asyncio import Redis as AsyncRedis
5 |
6 |
@fixture
def redis():
    """Provide a synchronous Upstash Redis client configured from the env."""
    with Redis.from_env() as client:
        yield client
11 |
12 |
@pytest_asyncio.fixture
async def async_redis():
    """Provide an asyncio Upstash Redis client configured from the env."""
    async with AsyncRedis.from_env() as client:
        yield client
17 |
18 |
@fixture(scope="session", autouse=True)
def setup_cleanup():
    """
    Flush the database once before the test session and once after it.

    NOTE: ``flushdb`` wipes every key in the database the environment points
    at -- never run the suite against a shared instance.
    """
    with Redis.from_env() as client:
        client.flushdb()  # start from an empty database
        yield
        client.flushdb()  # leave nothing behind
25 |
--------------------------------------------------------------------------------
/tests/test_block_until_ready.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | from pytest import raises
4 | from upstash_redis import Redis
5 |
6 | from tests.utils import random_id
7 | from upstash_ratelimit import FixedWindow, Ratelimit
8 |
9 |
def test_invalid_timeout(redis: Redis) -> None:
    """A negative timeout is rejected with ValueError."""
    limiter = FixedWindow(max_requests=1, window=1)
    ratelimit = Ratelimit(redis=redis, limiter=limiter)

    with raises(ValueError):
        ratelimit.block_until_ready(random_id(), -1)
18 |
19 |
def test_resolve_before_timeout(redis: Redis) -> None:
    """With quota available, block_until_ready returns before the timeout."""
    limiter = FixedWindow(max_requests=5, window=100)
    ratelimit = Ratelimit(redis=redis, limiter=limiter)
    timeout = 50

    began = time.time()
    response = ratelimit.block_until_ready(random_id(), timeout)
    waited = time.time() - began

    assert waited < timeout
    assert response.allowed is True
34 |
35 |
def test_resolve_before_timeout_when_window_resets(redis: Redis) -> None:
    """A blocked identifier is released once its 3 s window resets."""
    limiter = FixedWindow(max_requests=1, window=3)
    ratelimit = Ratelimit(redis=redis, limiter=limiter)
    identifier = random_id()
    timeout = 100

    ratelimit.limit(identifier)  # exhaust the current window

    began = time.time()
    response = ratelimit.block_until_ready(identifier, timeout)
    waited = time.time() - began

    assert waited < timeout
    assert response.allowed is True
53 |
54 |
def test_reaching_timeout(redis: Redis) -> None:
    """With the daily quota spent, block_until_ready gives up at the timeout."""
    limiter = FixedWindow(max_requests=1, window=1, unit="d")
    ratelimit = Ratelimit(redis=redis, limiter=limiter)
    identifier = random_id()
    timeout = 2

    ratelimit.limit(identifier)  # exhaust the daily quota

    began = time.time()
    response = ratelimit.block_until_ready(identifier, timeout)
    waited = time.time() - began

    assert waited >= timeout
    assert response.allowed is False
72 |
--------------------------------------------------------------------------------
/tests/test_fixed_window.py:
--------------------------------------------------------------------------------
1 | import time
2 | from unittest.mock import patch
3 |
4 | from pytest import approx
5 | from upstash_redis import Redis
6 |
7 | from tests.utils import random_id
8 | from upstash_ratelimit import FixedWindow, Ratelimit
9 | from upstash_ratelimit.utils import now_s
10 |
11 |
def test_max_requests_are_not_reached(redis: Redis) -> None:
    """The first request for a fresh identifier passes and uses one slot."""
    limiter = FixedWindow(max_requests=5, window=10)
    ratelimit = Ratelimit(redis=redis, limiter=limiter)

    started_at = now_s()
    response = ratelimit.limit(random_id())

    assert response.allowed is True
    assert (response.limit, response.remaining) == (5, 4)
    assert response.reset >= started_at
25 |
26 |
def test_max_requests_are_reached(redis: Redis) -> None:
    """A second request inside a one-request-per-day window is rejected."""
    limiter = FixedWindow(max_requests=1, window=1, unit="d")
    ratelimit = Ratelimit(redis=redis, limiter=limiter)
    identifier = random_id()

    # Spend the only request the window allows.
    ratelimit.limit(identifier)

    started_at = now_s()
    response = ratelimit.limit(identifier)

    assert response.allowed is False
    assert (response.limit, response.remaining) == (1, 0)
    assert response.reset >= started_at
44 |
45 |
def test_window_reset(redis: Redis) -> None:
    """Once the 3 s window has elapsed, the identifier may pass again."""
    limiter = FixedWindow(max_requests=1, window=3)
    ratelimit = Ratelimit(redis=redis, limiter=limiter)
    identifier = random_id()

    ratelimit.limit(identifier)
    time.sleep(3)  # let the whole window elapse

    started_at = now_s()
    response = ratelimit.limit(identifier)

    assert response.allowed is True
    assert (response.limit, response.remaining) == (1, 0)
    assert response.reset >= started_at
65 |
66 |
def test_get_remaining(redis: Redis) -> None:
    """get_remaining reports the full quota, then one less after a request."""
    limiter = FixedWindow(max_requests=10, window=1, unit="d")
    ratelimit = Ratelimit(redis=redis, limiter=limiter)
    identifier = random_id()

    assert ratelimit.get_remaining(identifier) == 10

    ratelimit.limit(identifier)
    left = ratelimit.get_remaining(identifier)
    assert left == 9
77 |
78 |
def test_get_reset(redis: Redis) -> None:
    """With the clock frozen, the reset lands at the end of the 5 s window."""
    limiter = FixedWindow(max_requests=10, window=5)
    ratelimit = Ratelimit(redis=redis, limiter=limiter)

    frozen_now = 1688910786.167
    with patch("time.time", return_value=frozen_now):
        reset = ratelimit.get_reset(random_id())
        assert reset == approx(1688910790.0)
87 |
88 |
def test_custom_rate(redis: Redis) -> None:
    """A request made with rate=N spends N units of the quota."""
    limiter = FixedWindow(max_requests=10, window=1, unit="d")
    ratelimit = Ratelimit(redis=redis, limiter=limiter)
    rate = 2
    identifier = random_id()

    ratelimit.limit(identifier)  # 1 unit
    ratelimit.limit(identifier, rate)  # +2 units
    assert ratelimit.get_remaining(identifier) == 10 - 1 - rate

    ratelimit.limit(identifier, rate)  # +2 units
    assert ratelimit.get_remaining(identifier) == 10 - 1 - 2 * rate
104 |
--------------------------------------------------------------------------------
/tests/test_sliding_window.py:
--------------------------------------------------------------------------------
1 | import time
2 | from typing import List
3 | from unittest.mock import patch
4 |
5 | from pytest import approx
6 | from upstash_redis import Redis
7 |
8 | from tests.utils import random_id
9 | from upstash_ratelimit import Ratelimit, Response, SlidingWindow
10 | from upstash_ratelimit.utils import now_s
11 |
12 |
def test_max_requests_are_not_reached(redis: Redis) -> None:
    """The first request for a fresh identifier passes and uses one slot."""
    limiter = SlidingWindow(max_requests=5, window=10)
    ratelimit = Ratelimit(redis=redis, limiter=limiter)

    started_at = now_s()
    response = ratelimit.limit(random_id())

    assert response.allowed is True
    assert (response.limit, response.remaining) == (5, 4)
    assert response.reset >= started_at
26 |
27 |
def test_max_requests_are_reached(redis: Redis) -> None:
    """A second request inside a one-request-per-day window is rejected."""
    limiter = SlidingWindow(max_requests=1, window=1, unit="d")
    ratelimit = Ratelimit(redis=redis, limiter=limiter)
    identifier = random_id()

    # Spend the only request the window allows.
    ratelimit.limit(identifier)

    started_at = now_s()
    response = ratelimit.limit(identifier)

    assert response.allowed is False
    assert (response.limit, response.remaining) == (1, 0)
    assert response.reset >= started_at
45 |
46 |
def test_window_reset(redis: Redis) -> None:
    """Once the 3 s window has elapsed, the identifier may pass again."""
    limiter = SlidingWindow(max_requests=1, window=3)
    ratelimit = Ratelimit(redis=redis, limiter=limiter)
    identifier = random_id()

    ratelimit.limit(identifier)
    time.sleep(3)  # let the whole window elapse

    started_at = now_s()
    response = ratelimit.limit(identifier)

    assert response.allowed is True
    assert (response.limit, response.remaining) == (1, 0)
    assert response.reset >= started_at
66 |
67 |
def test_sliding(redis: Redis) -> None:
    """Requests from the previous window still count after a boundary."""
    limiter = SlidingWindow(max_requests=5, window=3)
    ratelimit = Ratelimit(redis=redis, limiter=limiter)
    identifier = random_id()

    responses: List[Response] = []

    # Keep firing until, after at least six requests, the reported reset
    # changes -- i.e. a window boundary has just been crossed.
    while True:
        latest = ratelimit.limit(identifier)
        responses.append(latest)

        if len(responses) <= 5:
            continue
        if responses[-2].reset != latest.reset:
            break

    # Some items from the previous window are still considered, so the
    # remaining count must differ from max_requests - 1.
    assert responses[-1].remaining != 4
90 |
91 |
def test_get_remaining(redis: Redis) -> None:
    """get_remaining reports the full quota, then one less after a request."""
    limiter = SlidingWindow(max_requests=10, window=1, unit="d")
    ratelimit = Ratelimit(redis=redis, limiter=limiter)
    identifier = random_id()

    assert ratelimit.get_remaining(identifier) == 10

    ratelimit.limit(identifier)
    left = ratelimit.get_remaining(identifier)
    assert left == 9
102 |
103 |
def test_get_remaining_with_sliding(redis: Redis) -> None:
    """get_remaining never reports fewer requests than the last limit() did."""
    limiter = SlidingWindow(max_requests=5, window=3)
    ratelimit = Ratelimit(redis=redis, limiter=limiter)
    identifier = random_id()

    responses: List[Response] = []

    # Fire until, after at least six requests, a window boundary is crossed.
    while True:
        latest = ratelimit.limit(identifier)
        responses.append(latest)

        if len(responses) <= 5:
            continue
        if responses[-2].reset != latest.reset:
            break

    assert ratelimit.get_remaining(identifier) >= responses[-1].remaining
123 |
124 |
def test_get_reset(redis: Redis) -> None:
    """With the clock frozen, the reset lands at the end of the 5 s window."""
    limiter = SlidingWindow(max_requests=10, window=5)
    ratelimit = Ratelimit(redis=redis, limiter=limiter)

    frozen_now = 1688910786.167
    with patch("time.time", return_value=frozen_now):
        reset = ratelimit.get_reset(random_id())
        assert reset == approx(1688910790.0)
133 |
134 |
def test_custom_rate(redis: Redis) -> None:
    """
    A request made with rate=N spends N units of the quota.

    Uses a one-day window (as ``test_get_remaining`` does) rather than a
    5 s one: with a short window, a boundary crossed between these HTTP calls
    moves the spent units into the previous window. That window is only
    partially weighed in, which can make the exact counts below flaky.
    """
    ratelimit = Ratelimit(
        redis=redis,
        limiter=SlidingWindow(max_requests=10, window=1, unit="d"),
    )
    rate = 2

    identifier = random_id()

    ratelimit.limit(identifier)  # 1 unit
    ratelimit.limit(identifier, rate)  # +2 units
    assert ratelimit.get_remaining(identifier) == 7

    ratelimit.limit(identifier, rate)  # +2 units
    assert ratelimit.get_remaining(identifier) == 5
150 |
--------------------------------------------------------------------------------
/tests/test_token_bucket.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | from upstash_redis import Redis
4 |
5 | from tests.utils import random_id
6 | from upstash_ratelimit import Ratelimit, TokenBucket
7 | from upstash_ratelimit.utils import now_s
8 |
9 |
def test_max_tokens_are_not_reached(redis: Redis) -> None:
    """A full bucket lets the first request through and spends one token."""
    limiter = TokenBucket(max_tokens=5, refill_rate=5, interval=1, unit="d")
    ratelimit = Ratelimit(redis=redis, limiter=limiter)

    started_at = now_s()
    response = ratelimit.limit(random_id())

    assert response.allowed is True
    assert (response.limit, response.remaining) == (5, 4)
    assert response.reset >= started_at
23 |
24 |
def test_max_tokens_are_reached(redis: Redis) -> None:
    """With a one-token bucket, the second request is rejected."""
    limiter = TokenBucket(max_tokens=1, refill_rate=1, interval=1, unit="d")
    ratelimit = Ratelimit(redis=redis, limiter=limiter)
    identifier = random_id()

    # Drain the single token.
    ratelimit.limit(identifier)

    started_at = now_s()
    response = ratelimit.limit(identifier)

    assert response.allowed is False
    assert (response.limit, response.remaining) == (1, 0)
    assert response.reset >= started_at
42 |
43 |
def test_refill(redis: Redis) -> None:
    """After one refill interval, a drained bucket admits a request again."""
    limiter = TokenBucket(max_tokens=1, refill_rate=1, interval=3)
    ratelimit = Ratelimit(redis=redis, limiter=limiter)
    identifier = random_id()

    ratelimit.limit(identifier)
    time.sleep(3)  # one full refill interval

    started_at = now_s()
    response = ratelimit.limit(identifier)

    assert response.allowed is True
    assert (response.limit, response.remaining) == (1, 0)
    assert response.reset >= started_at
63 |
64 |
def test_refill_multiple_times(redis: Redis) -> None:
    """Several elapsed refill intervals add several tokens back."""
    limiter = TokenBucket(max_tokens=1000, refill_rate=1, interval=1)
    ratelimit = Ratelimit(redis=redis, limiter=limiter)
    identifier = random_id()

    # Ten sequential requests; only the last response is inspected.
    responses = [ratelimit.limit(identifier) for _ in range(10)]
    remaining_before = responses[-1].remaining

    time.sleep(3)  # roughly three one-second refill intervals

    response = ratelimit.limit(identifier)
    assert response.remaining >= remaining_before + 2
84 |
85 |
def test_get_remaining(redis: Redis) -> None:
    """get_remaining reports a full bucket, then one token fewer."""
    limiter = TokenBucket(max_tokens=10, refill_rate=10, interval=1, unit="d")
    ratelimit = Ratelimit(redis=redis, limiter=limiter)
    identifier = random_id()

    assert ratelimit.get_remaining(identifier) == 10

    ratelimit.limit(identifier)
    tokens_left = ratelimit.get_remaining(identifier)
    assert tokens_left == 9
96 |
97 |
def test_get_remaining_with_refills_that_should_be_made(redis: Redis) -> None:
    """get_remaining includes refills that became due since the last request."""
    limiter = TokenBucket(max_tokens=1000, refill_rate=1, interval=1)
    ratelimit = Ratelimit(redis=redis, limiter=limiter)
    identifier = random_id()

    responses = [ratelimit.limit(identifier) for _ in range(10)]
    remaining_before = responses[-1].remaining

    time.sleep(3)

    assert ratelimit.get_remaining(identifier) >= remaining_before + 2
116 |
117 |
def test_get_reset(redis: Redis) -> None:
    """After spending a token, the next refill is about one interval away."""
    limiter = TokenBucket(max_tokens=1000, refill_rate=1, interval=1)
    ratelimit = Ratelimit(redis=redis, limiter=limiter)
    identifier = random_id()

    started_at = now_s()
    ratelimit.limit(identifier)

    reset = ratelimit.get_reset(identifier)
    assert reset >= started_at + 0.9
129 |
130 |
def test_get_reset_with_refills_that_should_be_made(redis: Redis) -> None:
    """get_reset moves forward as refill intervals elapse."""
    limiter = TokenBucket(max_tokens=1000, refill_rate=1, interval=1)
    ratelimit = Ratelimit(redis=redis, limiter=limiter)
    identifier = random_id()

    responses = [ratelimit.limit(identifier) for _ in range(10)]
    reset_before = responses[-1].reset

    time.sleep(3)

    assert ratelimit.get_reset(identifier) >= reset_before + 2
149 |
150 |
def test_custom_rate(redis: Redis) -> None:
    """
    A request made with rate=N spends N tokens.

    The refill interval is one day (as in ``test_get_remaining``) rather than
    one second: with ``interval=1`` in seconds, a token can be refilled
    between these HTTP round-trips (see ``test_refill_multiple_times``),
    which would make the exact remaining counts below flaky.
    """
    ratelimit = Ratelimit(
        redis=redis,
        limiter=TokenBucket(max_tokens=10, refill_rate=1, interval=1, unit="d"),
    )
    rate = 2

    identifier = random_id()

    ratelimit.limit(identifier)  # 1 token
    ratelimit.limit(identifier, rate)  # +2 tokens
    assert ratelimit.get_remaining(identifier) == 7

    ratelimit.limit(identifier, rate)  # +2 tokens
    assert ratelimit.get_remaining(identifier) == 5
166 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import patch
2 |
3 | from pytest import approx, mark, raises
4 |
5 | from upstash_ratelimit.typing import UnitT
6 | from upstash_ratelimit.utils import ms_to_s, now_ms, now_s, s_to_ms, to_ms
7 |
8 |
@mark.parametrize(
    ("value", "unit", "expected"),
    [
        (1, "d", 24 * 60 * 60 * 1000),
        (4, "h", 4 * 60 * 60 * 1000),
        (2, "m", 2 * 60 * 1000),
        (33, "s", 33 * 1000),
        (42, "ms", 42),
    ],
)
def test_to_ms(value: int, unit: UnitT, expected: int) -> None:
    """to_ms converts each supported unit to milliseconds."""
    assert to_ms(value, unit) == expected
21 |
22 |
def test_to_ms_invalid_unit() -> None:
    """An unknown unit string raises ValueError."""
    bogus_unit = "invalid"
    with raises(ValueError):
        to_ms(42, bogus_unit)  # type: ignore[arg-type]
26 |
27 |
def test_now_ms() -> None:
    """now_ms converts the current time.time() reading to milliseconds."""
    with patch("time.time", return_value=42.5):
        result = now_ms()
    assert result == 42_500
31 |
32 |
def test_now_s() -> None:
    """now_s returns the current time.time() reading in seconds."""
    with patch("time.time", return_value=42.5):
        result = now_s()
    assert result == 42.5
36 |
37 |
def test_s_to_ms() -> None:
    """Fractional seconds convert to milliseconds."""
    seconds = 12.5
    assert s_to_ms(seconds) == 12_500
40 |
41 |
def test_ms_to_s() -> None:
    """Milliseconds convert to (fractional) seconds."""
    millis = 44_123
    assert ms_to_s(millis) == approx(44.123)
44 |
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
1 | import uuid
2 |
3 |
def random_id() -> str:
    """Return a fresh random identifier: a version-4 UUID in string form."""
    return f"{uuid.uuid4()}"
6 |
--------------------------------------------------------------------------------
/upstash_ratelimit/__init__.py:
--------------------------------------------------------------------------------
# Package version string. NOTE(review): the release workflow tags releases
# from `poetry version`, so this presumably must match pyproject.toml — confirm.
__version__ = "1.1.0"
2 |
3 | from upstash_ratelimit.limiter import FixedWindow, Response, SlidingWindow, TokenBucket
4 | from upstash_ratelimit.ratelimit import Ratelimit
5 |
# Public API re-exported at the package root (same order as before).
__all__ = ["Ratelimit", "FixedWindow", "SlidingWindow", "TokenBucket", "Response"]
13 |
--------------------------------------------------------------------------------
/upstash_ratelimit/asyncio/__init__.py:
--------------------------------------------------------------------------------
1 | from upstash_ratelimit.asyncio.ratelimit import Ratelimit
2 | from upstash_ratelimit.limiter import FixedWindow, Response, SlidingWindow, TokenBucket
3 |
# Async Ratelimit plus the limiter classes, which are shared with the sync API.
__all__ = ["Ratelimit", "FixedWindow", "SlidingWindow", "TokenBucket", "Response"]
11 |
--------------------------------------------------------------------------------
/upstash_ratelimit/asyncio/ratelimit.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from typing import Optional
3 |
4 | from upstash_redis.asyncio import Redis
5 |
6 | from upstash_ratelimit.limiter import Limiter, Response
7 | from upstash_ratelimit.utils import merge_telemetry, now_s
8 |
9 |
class Ratelimit:
    """
    Provides means of rate limiting over the HTTP-based
    Upstash Redis client, using its asyncio API.
    """

    def __init__(
        self, redis: Redis, limiter: Limiter, prefix: str = "@upstash/ratelimit"
    ) -> None:
        """
        :param redis: Upstash Redis instance to use.
        :param limiter: Ratelimiter to use. Available limiters are \
        `FixedWindow`, `SlidingWindow`, and `TokenBucket` which are provided \
        in the `limiter` module.
        :param prefix: Prefix to distinguish the keys used in the ratelimit \
        logic from others, in case the same Redis instance is reused between \
        different applications.
        """

        self._redis = redis
        # NOTE(review): see `upstash_ratelimit.utils.merge_telemetry` for what
        # this records on the client; its effect is not visible here.
        merge_telemetry(redis)

        self._limiter = limiter
        self._prefix = prefix

    def _key(self, identifier: str) -> str:
        """Return the Redis key used for ``identifier`` under this prefix."""
        return f"{self._prefix}:{identifier}"

    async def limit(self, identifier: str, rate: int = 1) -> Response:
        """
        Determines if a request should pass or be rejected based on the identifier
        and previously chosen ratelimit.

        Use this if you want to reject all requests that you can not handle
        right now.

        .. code-block:: python

            from upstash_ratelimit.asyncio import Ratelimit, SlidingWindow
            from upstash_redis.asyncio import Redis

            ratelimit = Ratelimit(
                redis=Redis.from_env(),
                limiter=SlidingWindow(max_requests=10, window=10, unit="s"),
            )

            async def main() -> None:
                response = await ratelimit.limit("some-id")
                if not response.allowed:
                    print("Rate limited!")
                    return

                print("Good to go!")

        :param identifier: Identifier to ratelimit. Use a constant string to \
        limit all requests, or user ids, API keys, or IP addresses for \
        individual limits.
        :param rate: Rate with which to subtract from the limit of the \
        identifier.
        :returns: The limiter's `Response` for this request.
        """

        key = self._key(identifier)
        return await self._limiter.limit_async(self._redis, key, rate)

    async def block_until_ready(
        self, identifier: str, timeout: float, rate: int = 1
    ) -> Response:
        """
        Blocks until the request may pass or timeout is reached.

        This method blocks until the request may be processed or the timeout
        has been reached.

        Use this if you want to delay the request until it is ready to get
        processed.

        .. code-block:: python

            from upstash_ratelimit.asyncio import Ratelimit, SlidingWindow
            from upstash_redis.asyncio import Redis

            ratelimit = Ratelimit(
                redis=Redis.from_env(),
                limiter=SlidingWindow(max_requests=10, window=10, unit="s"),
            )

            async def main() -> None:
                response = await ratelimit.block_until_ready("some-id", 60)
                if not response.allowed:
                    print("Rate limited!")
                    return

                print("Good to go!")

        :param identifier: Identifier to ratelimit. Use a constant string to \
        limit all requests, or user ids, API keys, or IP addresses for \
        individual limits.
        :param timeout: Maximum time in seconds to wait until the request \
        may pass.
        :param rate: Rate with which to subtract from the limit of the \
        identifier.
        :returns: The first allowed `Response`, or the last rejected one if \
        the timeout was reached.
        :raises ValueError: If ``timeout`` is not positive.
        """

        if timeout <= 0:
            raise ValueError("Timeout must be positive")

        response: Optional[Response] = None
        deadline = now_s() + timeout

        while True:
            response = await self.limit(identifier, rate)
            if response.allowed:
                break

            # Sleep until the limit resets, but never past the deadline.
            wait = max(0, min(response.reset, deadline) - now_s())
            await asyncio.sleep(wait)

            # Stop once the deadline is reached. Using `>=` avoids spending
            # one more request when the sleep ends exactly at the deadline.
            if now_s() >= deadline:
                break

        return response

    async def get_remaining(self, identifier: str) -> int:
        """
        Returns the number of requests left for the given identifier.
        """

        return await self._limiter.get_remaining_async(
            self._redis, self._key(identifier)
        )

    async def get_reset(self, identifier: str) -> float:
        """
        Returns the UNIX timestamp in seconds when the remaining
        requests will be reset or replenished.
        """

        return await self._limiter.get_reset_async(self._redis, self._key(identifier))
141 |
--------------------------------------------------------------------------------
/upstash_ratelimit/limiter.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import dataclasses
3 | from collections.abc import Generator
4 | from typing import Any, Callable
5 |
6 | from upstash_redis import Redis
7 | from upstash_redis.asyncio import Redis as AsyncRedis
8 |
9 | from upstash_ratelimit.typing import UnitT
10 | from upstash_ratelimit.utils import ms_to_s, now_ms, to_ms
11 |
12 |
@dataclasses.dataclass
class Response:
    """Outcome of a single ratelimit check."""

    #: Whether the request may pass (`True`) or exceeded the limit (`False`).
    allowed: bool

    #: Maximum number of requests allowed within a window.
    limit: int

    #: How many requests the user has left within the current window.
    remaining: int

    #: Unix timestamp in seconds when the limits are reset.
    reset: float
34 |
35 |
class Limiter(abc.ABC):
    """
    Interface implemented by every ratelimiting algorithm.

    Each operation comes in a synchronous flavour (taking a ``Redis``
    client) and an ``_async`` flavour (taking an asyncio ``Redis`` client).
    """

    @abc.abstractmethod
    def limit(self, redis: Redis, identifier: str, rate: int = 1) -> Response:
        """Decide whether a request may pass, subtracting ``rate`` from the limit."""

    @abc.abstractmethod
    async def limit_async(self, redis: AsyncRedis, identifier: str, rate: int = 1) -> Response:
        """Async counterpart of ``limit``."""

    @abc.abstractmethod
    def get_remaining(self, redis: Redis, identifier: str) -> int:
        """Return the number of requests left for ``identifier``."""

    @abc.abstractmethod
    async def get_remaining_async(self, redis: AsyncRedis, identifier: str) -> int:
        """Async counterpart of ``get_remaining``."""

    @abc.abstractmethod
    def get_reset(self, redis: Redis, identifier: str) -> float:
        """Return the UNIX timestamp (seconds) when the limit resets or refills."""

    @abc.abstractmethod
    async def get_reset_async(self, redis: AsyncRedis, identifier: str) -> float:
        """Async counterpart of ``get_reset``."""
60 |
61 |
def _with_at_most_one_request(redis: Redis, generator: Generator) -> Any:
    """
    Drive ``generator`` to completion, issuing at most one Redis command.

    The generator first yields a ``(command_name, command_args)`` pair.

    * Falsy ``command_name``: no round-trip is needed; the generator's next
      yielded value is returned as-is.
    * Otherwise: the method named ``command_name`` is looked up on ``redis``
      and called with ``command_args``. Its result is sent back into the
      generator, and the value the generator yields in reply is returned.
    """

    name, args = next(generator)

    if name:
        redis_command: Callable = getattr(redis, name)
        return generator.send(redis_command(*args))

    # Nothing to execute; the generator produces the result on its own.
    return next(generator)
88 |
89 |
async def _with_at_most_one_request_async(
    redis: AsyncRedis, generator: Generator
) -> Any:
    """
    Async variant of `_with_at_most_one_request` defined above.

    Follows the same protocol: the generator first yields a
    ``(command_name, command_args)`` pair. A falsy ``command_name`` means no
    command is needed, and the generator's next yielded value is returned.
    Otherwise the method named ``command_name`` is called on ``redis``. The
    result is awaited and sent back into the generator, and the value the
    generator yields in reply is returned. At most one command is executed.
    """

    command_name, command_args = next(generator)
    if not command_name:
        # No need to execute a command
        return next(generator)

    command: Callable = getattr(redis, command_name)
    command_response = await command(*command_args)
    return generator.send(command_response)
107 |
108 |
class AbstractLimiter(Limiter):
    """
    Base class for limiters whose algorithm is written as generators.

    Subclasses implement ``_limit``, ``_get_remaining`` and ``_get_reset`` as
    generators. Each first yields a ``(command_name, command_args)`` pair,
    where a falsy name means no Redis call is needed. When a command was
    named, the generator receives its result via ``send``. It then yields
    the final value.

    The public sync and async methods drive those generators through
    ``_with_at_most_one_request`` and ``_with_at_most_one_request_async``,
    so a single implementation serves both client flavours.
    """

    @abc.abstractmethod
    def _limit(self, identifier: str, rate: int = 1) -> Generator:
        """Generator implementing ``limit`` (see the class docstring)."""

    def limit(self, redis: Redis, identifier: str, rate: int = 1) -> Response:
        steps = self._limit(identifier, rate)
        outcome: Response = _with_at_most_one_request(redis, steps)
        return outcome

    async def limit_async(self, redis: AsyncRedis, identifier: str, rate: int = 1) -> Response:
        steps = self._limit(identifier, rate)
        outcome: Response = await _with_at_most_one_request_async(redis, steps)
        return outcome

    @abc.abstractmethod
    def _get_remaining(self, identifier: str) -> Generator:
        """Generator implementing ``get_remaining`` (see the class docstring)."""

    def get_remaining(self, redis: Redis, identifier: str) -> int:
        steps = self._get_remaining(identifier)
        left: int = _with_at_most_one_request(redis, steps)
        return left

    async def get_remaining_async(self, redis: AsyncRedis, identifier: str) -> int:
        steps = self._get_remaining(identifier)
        left: int = await _with_at_most_one_request_async(redis, steps)
        return left

    @abc.abstractmethod
    def _get_reset(self, identifier: str) -> Generator:
        """Generator implementing ``get_reset`` (see the class docstring)."""

    def get_reset(self, redis: Redis, identifier: str) -> float:
        steps = self._get_reset(identifier)
        reset_at: float = _with_at_most_one_request(redis, steps)
        return reset_at

    async def get_reset_async(self, redis: AsyncRedis, identifier: str) -> float:
        steps = self._get_reset(identifier)
        reset_at: float = await _with_at_most_one_request_async(redis, steps)
        return reset_at
153 |
154 |
class FixedWindow(AbstractLimiter):
    """
    The time is divided into windows of fixed length, and each request inside
    a window increases a counter.

    Once the counter exceeds the maximum allowed number, the request that
    pushed it over and all further requests in that window are rejected.
    Rejected requests are still added to the counter.

    Pros:

    - Newer requests are not starved by old ones.
    - Low storage cost.

    Cons:

    - A burst of requests near the boundary of a window can result in twice the
      rate of requests being processed because two windows will be filled with
      requests quickly.
    """

    # KEYS[1]: counter key of the current window.
    # ARGV[1]: window length in milliseconds, ARGV[2]: amount to add.
    # Returns the counter value after the increment.
    SCRIPT = """
        local key = KEYS[1]
        local window = ARGV[1]
        local increment_by = ARGV[2] -- increment rate per request at a given value, default is 1

        local r = redis.call("INCRBY", key, increment_by)
        if r == tonumber(increment_by) then
            -- The first time this key is set, the value will be equal to increment_by.
            -- So we only need the expire command once
            redis.call("PEXPIRE", key, window)
        end

        return r
    """

    def __init__(self, max_requests: int, window: int, unit: UnitT = "s") -> None:
        """
        :param max_requests: Maximum number of requests allowed within a window
        :param window: The number of time units in a window
        :param unit: The unit of time
        """

        assert max_requests > 0
        assert window > 0

        self._max_requests = max_requests
        self._window = to_ms(window, unit)

    def _limit(self, identifier: str, rate: int = 1) -> Generator:
        """
        Yield the EVAL command that adds ``rate`` to the current window's
        counter, then the resulting ``Response``.
        """
        curr_window = now_ms() // self._window
        key = f"{identifier}:{curr_window}"

        num_requests = yield (
            "eval",
            (FixedWindow.SCRIPT, [key], [self._window, rate]),
        )

        yield Response(
            allowed=num_requests <= self._max_requests,
            limit=self._max_requests,
            remaining=max(0, self._max_requests - num_requests),
            reset=ms_to_s((curr_window + 1) * self._window),
        )

    def _get_remaining(self, identifier: str) -> Generator:
        """
        Yield the GET command for the current window's counter, then the
        number of requests still allowed in this window (never negative).
        """
        curr_window = now_ms() // self._window
        key = f"{identifier}:{curr_window}"

        num_requests = yield (
            "get",
            (key,),
        )

        # Exactly one result is yielded on each path, so the generator never
        # falls through to ``int(None)`` when the key does not exist.
        if num_requests is None:
            yield self._max_requests
        else:
            yield max(0, self._max_requests - int(num_requests))

    def _get_reset(self, _: str) -> Generator:
        """Yield ``(None, None)`` and then the end of the current window in seconds."""
        yield (None, None)  # Signal that we don't need to make a remote call

        curr_window = now_ms() // self._window
        yield ms_to_s((curr_window + 1) * self._window)
237 |
class SlidingWindow(AbstractLimiter):
    """
    Combined approach of sliding logs and fixed window with lower storage
    costs than sliding logs and improved boundary behavior by calculating a
    weighted score between two windows.

    Pros:

    - Good performance allows this to scale to very high loads.
    """

    # KEYS: current window key, previous window key.
    # ARGV: max requests, current time (ms), window length (ms), increment.
    # Returns the remaining capacity after the request; a negative value
    # (-1 when rejected before incrementing) means the request is not allowed.
    SCRIPT = """
        local current_key = KEYS[1] -- identifier including prefixes
        local previous_key = KEYS[2] -- key of the previous bucket
        local tokens = tonumber(ARGV[1]) -- tokens per window
        local now = ARGV[2] -- current timestamp in milliseconds
        local window = ARGV[3] -- interval in milliseconds
        local increment_by = ARGV[4] -- increment rate per request at a given value, default is 1

        local requests_in_current_window = redis.call("GET", current_key)
        if requests_in_current_window == false then
            requests_in_current_window = 0
        end

        local requests_in_previous_window = redis.call("GET", previous_key)
        if requests_in_previous_window == false then
            requests_in_previous_window = 0
        end
        local percentage_in_current = ( now % window ) / window
        -- weighted requests to consider from the previous window
        requests_in_previous_window = math.floor(( 1 - percentage_in_current ) * requests_in_previous_window)
        if requests_in_previous_window + requests_in_current_window >= tokens then
            return -1
        end

        local new_value = redis.call("INCRBY", current_key, increment_by)
        if new_value == tonumber(increment_by) then
            -- The first time this key is set, the value will be equal to increment_by.
            -- So we only need the expire command once
            redis.call("PEXPIRE", current_key, window * 2 + 1000) -- Enough time to overlap with a new window + 1 second
        end
        return tokens - ( new_value + requests_in_previous_window )
    """

    def __init__(self, max_requests: int, window: int, unit: UnitT = "s") -> None:
        """
        :param max_requests: Upper bound on weighted requests per window
        :param window: Length of a window, expressed in ``unit``
        :param unit: Time unit used for ``window``
        """

        assert max_requests > 0
        assert window > 0

        self._max_requests = max_requests
        self._window = to_ms(window, unit)

    def _limit(self, identifier: str, rate: int = 1) -> Generator:
        """Yield the EVAL command for this request, then its ``Response``."""
        timestamp = now_ms()
        window_index = timestamp // self._window
        current_key = f"{identifier}:{window_index}"
        previous_key = f"{identifier}:{window_index - 1}"

        script_keys = [current_key, previous_key]
        script_args = [self._max_requests, timestamp, self._window, rate]
        remaining = yield ("eval", (SlidingWindow.SCRIPT, script_keys, script_args))

        window_end_ms = (window_index + 1) * self._window
        yield Response(
            allowed=remaining >= 0,
            limit=self._max_requests,
            remaining=max(0, remaining),
            reset=ms_to_s(window_end_ms),
        )

    def _get_remaining(self, identifier: str) -> Generator:
        """
        Yield an MGET for both window counters, then the weighted number of
        requests still allowed (never negative).
        """
        timestamp = now_ms()
        window_index = timestamp // self._window
        current_key = f"{identifier}:{window_index}"
        previous_key = f"{identifier}:{window_index - 1}"

        raw_current, raw_previous = yield ("mget", (current_key, previous_key))

        current_count = int(raw_current or 0)
        # The previous window counts less the further we are into the current one.
        previous_weight = 1 - ((timestamp % self._window) / self._window)
        previous_count = int(int(raw_previous or 0) * previous_weight)

        yield max(0, self._max_requests - (previous_count + current_count))

    def _get_reset(self, _: str) -> Generator:
        """Yield ``(None, None)`` and then the end of the current window in seconds."""
        # The reset time depends only on the clock, so no Redis call is needed.
        yield (None, None)

        next_window_start_ms = (now_ms() // self._window + 1) * self._window
        yield ms_to_s(next_window_start_ms)
346 |
347 |
class TokenBucket(AbstractLimiter):
    """
    A bucket is filled with maximum number of tokens that refill at a given
    rate per interval.

    Each request tries to consume ``rate`` tokens (one by default). If the
    bucket does not hold enough tokens, the request is rejected and no tokens
    are consumed.

    Pros:

    - Bursts of requests are smoothed out so that they can be processed at
      a constant rate.
    - Allows to set a higher initial burst limit by setting maximum number
      of tokens higher than the refill rate.
    """

    # KEYS[1]: bucket hash key (fields "refilled_at" and "tokens").
    # ARGV: max tokens, interval (ms), refill rate, current time (ms), tokens
    # to consume. Returns {remaining, next refill timestamp in ms}; remaining
    # is -1 when the request is rejected.
    SCRIPT = """
        local key = KEYS[1] -- identifier including prefixes
        local max_tokens = tonumber(ARGV[1]) -- maximum number of tokens
        local interval = tonumber(ARGV[2]) -- size of the window in milliseconds
        local refill_rate = tonumber(ARGV[3]) -- how many tokens are refilled after each interval
        local now = tonumber(ARGV[4]) -- current timestamp in milliseconds
        local increment_by = tonumber(ARGV[5]) -- how many tokens to consume, default is 1

        local bucket = redis.call("HMGET", key, "refilled_at", "tokens")

        local refilled_at
        local tokens

        if bucket[1] == false then
            refilled_at = now
            tokens = max_tokens
        else
            refilled_at = tonumber(bucket[1])
            tokens = tonumber(bucket[2])
        end

        if now >= refilled_at + interval then
            local num_refills = math.floor((now - refilled_at) / interval)
            tokens = math.min(max_tokens, tokens + num_refills * refill_rate)

            refilled_at = refilled_at + num_refills * interval
        end

        if tokens < increment_by then
            -- Not enough tokens for this request: reject it without touching
            -- the stored bucket, so a rejected request never consumes tokens
            -- or drives the stored count below zero.
            return {-1, refilled_at + interval}
        end

        local remaining = tokens - increment_by
        local expire_at = math.ceil(((max_tokens - remaining) / refill_rate)) * interval

        redis.call("HSET", key, "refilled_at", refilled_at, "tokens", remaining)
        redis.call("PEXPIRE", key, expire_at)
        return {remaining, refilled_at + interval}
    """

    def __init__(
        self, max_tokens: int, refill_rate: int, interval: int, unit: UnitT = "s"
    ) -> None:
        """
        :param max_tokens: Maximum number of tokens in a bucket. Since a newly
            created bucket starts with this many tokens, it can be used to
            allow higher burst limits.
        :param refill_rate: The number of tokens that are refilled per interval
        :param interval: The number of time units between each refill
        :param unit: The unit of time
        """
        assert max_tokens > 0
        assert refill_rate > 0
        assert interval > 0

        self._max_tokens = max_tokens
        self._refill_rate = refill_rate
        self._interval = to_ms(interval, unit)

    def _limit(self, identifier: str, rate: int = 1) -> Generator:
        """Yield the EVAL command consuming ``rate`` tokens, then the ``Response``."""
        remaining, refill_at = yield (
            "eval",
            (
                TokenBucket.SCRIPT,
                [identifier],
                [self._max_tokens, self._interval, self._refill_rate, now_ms(), rate],
            ),
        )

        yield Response(
            allowed=remaining >= 0,
            limit=self._max_tokens,
            remaining=max(0, remaining),
            reset=ms_to_s(refill_at),
        )

    def _get_remaining(self, identifier: str) -> Generator:
        """
        Yield an HMGET for the bucket, then the number of tokens available
        now (including pending refills), never negative.
        """
        now = now_ms()

        refilled_at_, tokens_ = yield (
            "hmget",
            (identifier, "refilled_at", "tokens"),
        )

        if refilled_at_ is None:
            # No bucket stored yet: a new bucket starts full.
            yield self._max_tokens
            return

        refilled_at = int(refilled_at_)
        tokens = int(tokens_)  # type: ignore[arg-type]

        # Apply the refills that the script would apply on the next request.
        if now >= refilled_at + self._interval:
            num_refills = (now - refilled_at) // self._interval
            tokens = min(self._max_tokens, tokens + num_refills * self._refill_rate)

        # Clamp like the other limiters; older data may hold a negative count.
        yield max(0, tokens)

    def _get_reset(self, identifier: str) -> Generator:
        """Yield an HGET for the last refill time, then the next refill in seconds."""
        now = now_ms()

        refilled_at_ = yield (
            "hget",
            (identifier, "refilled_at"),
        )

        if refilled_at_ is None:
            # No bucket stored yet: report the current time.
            yield ms_to_s(now)
            return

        refilled_at = int(refilled_at_)
        if now >= refilled_at + self._interval:
            num_refills = (now - refilled_at) // self._interval
            refilled_at = refilled_at + num_refills * self._interval

        yield ms_to_s(refilled_at + self._interval)
476 |
--------------------------------------------------------------------------------
/upstash_ratelimit/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/upstash/ratelimit-py/bd70dfdb93ccfc0fea7a3d854c7942f8dce21f0a/upstash_ratelimit/py.typed
--------------------------------------------------------------------------------
/upstash_ratelimit/ratelimit.py:
--------------------------------------------------------------------------------
1 | import time
2 | from typing import Optional
3 |
4 | from upstash_redis import Redis
5 |
6 | from upstash_ratelimit.limiter import Limiter, Response
7 | from upstash_ratelimit.utils import merge_telemetry, now_s
8 |
9 |
class Ratelimit:
    """
    Provides means of rate limiting over the HTTP-based
    Upstash Redis client.
    """

    def __init__(
        self, redis: Redis, limiter: Limiter, prefix: str = "@upstash/ratelimit"
    ) -> None:
        """
        :param redis: Upstash Redis instance to use.
        :param limiter: Ratelimiter to use. Available limiters are \
        `FixedWindow`, `SlidingWindow`, and `TokenBucket` which are provided \
        in the `limiter` module.
        :param prefix: Prefix to distinguish the keys used in the ratelimit \
        logic from others, in case the same Redis instance is reused between \
        different applications.
        """

        self._redis = redis
        # Adds this SDK to the client's telemetry header when the client allows it.
        merge_telemetry(redis)

        self._limiter = limiter
        self._prefix = prefix

    def _key(self, identifier: str) -> str:
        """Return the Redis key for ``identifier``, namespaced by the prefix."""
        return f"{self._prefix}:{identifier}"

    def limit(self, identifier: str, rate: int = 1) -> Response:
        """
        Determines if a request should pass or be rejected based on the identifier
        and previously chosen ratelimit.

        Use this if you want to reject all requests that you can not handle
        right now.

        .. code-block:: python

            from upstash_redis import Redis
            from upstash_ratelimit import Ratelimit, SlidingWindow

            ratelimit = Ratelimit(
                redis=Redis.from_env(),
                limiter=SlidingWindow(max_requests=10, window=10, unit="s"),
            )

            response = ratelimit.limit("some-id")
            if not response.allowed:
                print("Rate limited!")
            else:
                print("Good to go!")

        :param identifier: Identifier to ratelimit. Use a constant string to \
        limit all requests, or user ids, API keys, or IP addresses for \
        individual limits.
        :param rate: Amount to deduct from the identifier's limit for this \
        request. Defaults to 1.
        """

        return self._limiter.limit(self._redis, self._key(identifier), rate)

    def block_until_ready(self, identifier: str, timeout: float, rate: int = 1) -> Response:
        """
        Blocks until the request may pass or timeout is reached.

        The request is retried after sleeping until the limiter's reported
        reset time (or the deadline, whichever comes first). Once the deadline
        has passed, the last (rejected) response is returned.

        Use this if you want to delay the request until it is ready to get
        processed.

        .. code-block:: python

            from upstash_redis import Redis
            from upstash_ratelimit import Ratelimit, SlidingWindow

            ratelimit = Ratelimit(
                redis=Redis.from_env(),
                limiter=SlidingWindow(max_requests=10, window=10, unit="s"),
            )

            response = ratelimit.block_until_ready("some-id", 60)
            if not response.allowed:
                print("Still rate limited after 60 seconds!")
            else:
                print("Good to go!")

        :param identifier: Identifier to ratelimit. Use a constant string to \
        limit all requests, or user ids, API keys, or IP addresses for \
        individual limits.
        :param timeout: Maximum time in seconds to wait until the request \
        may pass.
        :param rate: Amount to deduct from the identifier's limit for this \
        request. Defaults to 1.
        :raises ValueError: If ``timeout`` is not positive.
        """

        if timeout <= 0:
            raise ValueError("Timeout must be positive")

        response: Optional[Response] = None
        deadline = now_s() + timeout

        while True:
            response = self.limit(identifier, rate)
            if response.allowed:
                break

            # Sleep until the limit resets, but never past the deadline.
            wait = max(0, min(response.reset, deadline) - now_s())
            time.sleep(wait)

            if now_s() > deadline:
                break

        return response

    def get_remaining(self, identifier: str) -> int:
        """
        Returns the number of requests left for the given identifier.
        """

        return self._limiter.get_remaining(self._redis, self._key(identifier))

    def get_reset(self, identifier: str) -> float:
        """
        Returns the UNIX timestamp in seconds when the remaining
        requests will be reset or replenished.
        """

        return self._limiter.get_reset(self._redis, self._key(identifier))
139 |
--------------------------------------------------------------------------------
/upstash_ratelimit/typing.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
UnitT = Literal[
    "ms",  # milliseconds
    "s",  # seconds
    "m",  # minutes
    "h",  # hours
    "d",  # days
]
"""
Time unit accepted by the limiters' constructors.

ms: milliseconds
s: seconds
m: minutes
h: hours
d: days
"""
11 |
--------------------------------------------------------------------------------
/upstash_ratelimit/utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | from typing import Any
3 |
4 | from upstash_ratelimit import __version__
5 | from upstash_ratelimit.typing import UnitT
6 |
7 |
def merge_telemetry(redis: Any) -> None:
    """
    Append this SDK's identifier to a Redis client's telemetry header.

    The client is only modified when it has a truthy ``_allow_telemetry``
    attribute and a ``_headers`` mapping whose ``Upstash-Telemetry-Sdk``
    entry is already non-empty. The call is idempotent: calling it again on
    the same client does not append the identifier a second time.

    :param redis: Redis client whose ``_headers`` may be updated in place.
    """

    if (
        not hasattr(redis, "_allow_telemetry")
        or not redis._allow_telemetry
        or not hasattr(redis, "_headers")
    ):
        return

    sdk = redis._headers.get("Upstash-Telemetry-Sdk")
    if not sdk:
        return

    entry = f"py-upstash-ratelimit@v{__version__}"
    # The header is a comma-separated list of SDK identifiers; don't add ours
    # again if an earlier call (e.g. another Ratelimit on the same client)
    # already merged it.
    if entry in (part.strip() for part in str(sdk).split(",")):
        return

    redis._headers["Upstash-Telemetry-Sdk"] = f"{sdk}, {entry}"
22 |
23 |
def ms_to_s(value: int) -> float:
    """Convert milliseconds to (fractional) seconds."""
    milliseconds_per_second = 1_000
    return value / milliseconds_per_second
26 |
27 |
def s_to_ms(value: float) -> int:
    """Convert seconds to whole milliseconds, truncating any fractional part."""
    scaled = value * 1_000
    return int(scaled)
30 |
31 |
def now_s() -> float:
    """Return the current UNIX time in seconds, as reported by ``time.time``."""
    current = time.time()
    return current
34 |
35 |
def now_ms() -> int:
    """Return the current UNIX time in whole milliseconds (truncated)."""
    seconds = time.time()
    return int(seconds * 1_000)
38 |
39 |
def to_ms(value: int, unit: UnitT) -> int:
    """
    Convert ``value`` expressed in ``unit`` into milliseconds.

    :param value: Amount of time in the given unit.
    :param unit: One of ``"ms"``, ``"s"``, ``"m"``, ``"h"`` or ``"d"``.
    :raises ValueError: If ``unit`` is not one of the supported units.
    """

    # Each unit maps to the multipliers that turn it into milliseconds,
    # applied left to right.
    conversions = (
        ("ms", ()),
        ("s", (1_000,)),
        ("m", (60, 1_000)),
        ("h", (60, 60, 1_000)),
        ("d", (24, 60, 60, 1_000)),
    )

    for name, multipliers in conversions:
        if unit != name:
            continue
        result = value
        for multiplier in multipliers:
            result = result * multiplier
        return result

    raise ValueError("Unexpected unit")
53 |
--------------------------------------------------------------------------------