├── .github
│   └── workflows
│       ├── release.yml
│       └── test.yml
├── .gitignore
├── Dockerfile.test
├── LICENSE
├── README.md
├── justredis
│   ├── __init__.py
│   ├── decoder.py
│   ├── encoder.py
│   ├── errors.py
│   ├── nonsync
│   │   ├── __init__.py
│   │   ├── cluster.py
│   │   ├── connection.py
│   │   ├── connectionpool.py
│   │   ├── environment.py
│   │   ├── environments
│   │   │   ├── __init__.py
│   │   │   └── anyio.py
│   │   └── redis.py
│   ├── sync
│   │   ├── __init__.py
│   │   ├── cluster.py
│   │   ├── connection.py
│   │   ├── connectionpool.py
│   │   ├── environment.py
│   │   ├── environments
│   │   │   ├── __init__.py
│   │   │   └── threaded.py
│   │   └── redis.py
│   └── utils.py
├── setup.cfg
├── setup.py
├── tests
│   ├── __init__.py
│   ├── async
│   │   ├── __init__.py
│   │   ├── conftest.py
│   │   └── test_async.py
│   ├── conftest.py
│   ├── redis_server.py
│   ├── test_basic.py
│   └── test_example.py
└── tox.ini

--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Release a version
2 |
3 | on:
4 |   push:
5 |     tags:
6 |       - 'v*'
7 |
8 | jobs:
9 |   build:
10 |     runs-on: ubuntu-latest
11 |     steps:
12 |     - uses: actions/checkout@v2
13 |     - name: Set up Python
14 |       uses: actions/setup-python@v2
15 |       with:
16 |         python-version: 3.x
17 |     - name: Install dependencies
18 |       run: pip install -U pip setuptools wheel twine
19 |     - name: Create packages
20 |       run: python setup.py sdist bdist_wheel
21 |     - name: Check packages
22 |       run: twine check dist/*
23 |     - name: Publish a Python distribution to PyPI
24 |       uses: pypa/gh-action-pypi-publish@master
25 |       with:
26 |         user: __token__
27 |         password: ${{ secrets.pypi_password }}
28 |     - name: Create Release
29 |       id: create_release
30 |       uses: actions/create-release@v1
31 |       env:
32 |         GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
33 |       with:
34 |         tag_name: ${{ github.ref }}
35 |         release_name: ${{ github.ref }}
36 |         draft: false
37 |         prerelease: false
38 |     - uses: olegtarasov/get-tag@v2
39 |       id: tagName
40 |       with:
41 |         tagRegex: "v(.*)"
42 |         tagRegexGroup: 1
43 |     - name: Upload Release Asset
44 |       id: upload-release-asset
45 |       uses: actions/upload-release-asset@v1
46 |       env:
47 |         GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
48 |       with:
49 |         upload_url: ${{ steps.create_release.outputs.upload_url }}
50 |         asset_path: dist/justredis-${{ steps.tagName.outputs.tag }}-py3-none-any.whl
51 |         asset_name: justredis-${{ steps.tagName.outputs.tag }}-py3-none-any.whl
52 |         asset_content_type: application/zip
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Docker Image CI
2 |
3 | on: [push]
4 |
5 | jobs:
6 |   linux:
7 |     runs-on: ubuntu-latest
8 |     steps:
9 |     - uses: actions/checkout@v2
10 |     - name: Build the testing Docker image
11 |       run: docker build .
--file Dockerfile.test --network host 12 | mac: 13 | runs-on: macOS-latest 14 | steps: 15 | - uses: actions/checkout@v2 16 | - name: Checking via homebrew 17 | run: | 18 | CI=1 /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install.sh)" > /dev/null 2>&1 19 | brew install python redis > /dev/null 2>&1 20 | python3 -m pip install tox > /dev/null 2>&1 21 | python3 -m tox -e py38 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | /venv*/ 3 | /.vscode/ 4 | .coverage 5 | /*.egg-info/ 6 | /.tox/ 7 | /htmlcov/ 8 | /dist/ 9 | /build/ 10 | -------------------------------------------------------------------------------- /Dockerfile.test: -------------------------------------------------------------------------------- 1 | ARG BASE_IMAGE=quay.io/pypa/manylinux2014_x86_64 2 | FROM $BASE_IMAGE 3 | 4 | ARG REDIS5_VERSION=5.0.9 5 | ARG REDIS6_VERSION=6.0.8 6 | RUN cd /tmp && mkdir redis5 && curl -s https://download.redis.io/releases/redis-${REDIS5_VERSION}.tar.gz | tar -xvzo -C redis5 --strip-components=1 > /dev/null 2>&1 && cd redis5 && make > /dev/null 2>&1 7 | RUN cd /tmp && mkdir redis6 && curl -s https://download.redis.io/releases/redis-${REDIS6_VERSION}.tar.gz | tar -xvzo -C redis6 --strip-components=1 > /dev/null 2>&1 && cd redis6 && make > /dev/null 2>&1 8 | 9 | ARG PYPY3_VERSION=7.3.2 10 | RUN cd /opt/python && mkdir pypy3 && curl -s -L https://downloads.python.org/pypy/pypy3.6-v${PYPY3_VERSION}-linux64.tar.bz2 | tar -xvjo -C pypy3 --strip-components=1 > /dev/null 2>&1 11 | 12 | #WORKDIR /opt 13 | #ADD https://api.github.com/repos/tzickel/justredis/git/refs/heads/master version.json 14 | #RUN git clone https://github.com/tzickel/justredis 15 | 16 | WORKDIR /opt/justredis 17 | 18 | ADD . . 19 | 20 | RUN /opt/python/cp38-cp38/bin/pip install -U tox pip setuptools wheel > /dev/null 2>&1 21 | RUN REDIS_6_PATH=/tmp/redis6/src REDIS_5_PATH=/tmp/redis5/src PATH=$PATH:/opt/python/cp39-cp39/bin:/opt/python/cp38-cp38/bin:/opt/python/cp37-cp37m/bin:/opt/python/cp36-cp36m/bin:/opt/python/cp35-cp35m/bin:/opt/python/pypy3/bin /opt/python/cp38-cp38/bin/tox 22 | RUN /opt/python/cp38-cp38/bin/python setup.py bdist_wheel -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 tzickel 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## What ?
2 |
3 | A Redis client for Python supporting many Redis features, with both synchronous (Python 3.5+) and asynchronous (Python 3.6+) communication.
4 |
5 | ## [Why](https://xkcd.com/927/) ?
6 |
7 | - Transparent API (just call the Redis commands, and the library will figure out cluster routing, script caching, etc...)
8 | - Per context and command properties (database #, decoding, RESP3 attributes)
9 | - Asynchronous I/O support with the exact same API (just add the await keyword), targeting asyncio, trio and curio (via [AnyIO](https://github.com/agronholm/anyio), which needs to be installed as well if you want async I/O)
10 | - Modular API allowing easy support for multiple synchronous and asynchronous event loops and disabling of unneeded features
11 | - CI testing for CPython 3.5, 3.6, 3.7, 3.8, 3.9 and PyPy3 with Redis 5 and Redis 6
12 | - No legacy support for old language features
13 |
14 | ## Redis features support table
15 |
16 | | Feature | Supported | Notes |
17 | | --- | --- | --- |
18 | | [Transactions](https://redis.io/topics/transactions) | V | See [examples](#examples) and [Transaction section](#redis-command-replacements) |
19 | | [Pub/Sub](https://redis.io/topics/pubsub) | V | See [examples](#examples) and [Pub/Sub and monitor section](#redis-command-replacements) |
20 | | [Pipelining](https://redis.io/topics/pipelining) | V | See [examples](#examples) and [Pipelining section](#pipelining) |
21 | | [Cluster](https://redis.io/topics/cluster-spec) | V | See [Cluster commands](#cluster-commands) |
22 | | [RESP3 support](https://github.com/antirez/RESP3/blob/master/spec.md) | V | See [RESP2 and RESP3 difference section](#resp2-and-resp3-difference) |
23 | | [SSL](https://redis.io/topics/encryption) | V | See the [SSL connection parameters](#settings-options) |
24 | | [Script caching](https://redis.io/commands/evalsha) | X | |
25 | | [Client side caching](https://redis.io/topics/client-side-caching) | X | |
26 | | [Sentinel](https://redis.io/topics/sentinel) | X | |
27 |
28 | ## Roadmap
29 |
30 | Getting it out of alpha:
31 | - [ ] API finalization (your feedback is appreciated)
32 |   - [ ] Should calling Redis use \_\_call__() or a special method such as "command", for easier refactoring ?
33 |   - [ ] Is the modify() API flexible enough ?
34 |   - [ ] Is it easy to extend the module ?
35 |   - [ ] Should we add helper functions for SCAN iterators and other APIs ?
36 | - [ ] Resolving some of the TODOs in the code
37 |
38 | General:
39 | - More features in the support table
40 | - Better test coverage
41 | - Resolve all TODOs in the code
42 | - Move documentation to topics + docstrings
43 | - Automate code conversion between sync and async
44 | - More CI checks such as flake8, pylint, etc...
45 |
46 | ## Not on roadmap (for now?)
47 |
48 | - High level features which are not part of the Redis specification (such as locks, retrying transactions, etc...)
49 | - Manual command interface (maybe for special stuff like bit operations ?)
50 | - Python 2 support (can be added, but only after removing some misc syntax features)
51 |
52 | ## Installing
53 |
54 | The project can be found in PyPI as justredis. Install it via pip or a requirements.txt file:
55 |
56 | ```bash
57 | pip install justredis
58 | ```
59 |
60 | If you want to use the asynchronous I/O frameworks asyncio, trio or curio with this library, you need to install the AnyIO library as well:
61 |
62 | ```bash
63 | pip install anyio
64 | ```
65 |
66 | ## Examples
67 |
68 | ```python
69 | from justredis import Redis
70 |
71 |
72 | # Let's connect to localhost:6379 and decode the string results as utf-8 strings
73 | r = Redis(decoder="utf8")
74 | assert r("set", "a", "b") == "OK"
75 | assert r("get", "a") == "b"
76 | assert r("get", "a", decoder=None) == b"b"  # But this can be changed per command
77 |
78 |
79 | # We can even run commands on a different database number
80 | with r.modify(database=1) as r1:
81 |     assert r1("get", "a") is None  # In this database, a was not set to b
82 |
83 |
84 | # Here we can use a transactional set of commands
85 | # Notice that when we take a connection, if we plan on cluster support, we need
86 | # to tell it a key we plan on using inside, or a specific server address
87 | with r.connection(key="a") as c:
88 |     c("multi")
89 |     c("set", "a", "b")
90 |     c("get", "a")
91 |     assert c("exec") == ["OK", "b"]
92 |
93 |
94 | # Or we can just pipeline the commands from before
95 | with r.connection(key="a") as c:
96 |     result = c(("multi", ), ("set", "a", "b"), ("get", "a"), ("exec", ))[-1]
97 |     assert result == ["OK", "b"]
98 |
99 |
100 | # Here is the famous increment example
101 | # Notice we take the connection inside the loop,
102 | # this is to make sure that if the cluster moved the keys, it will still be ok
103 | while True:
104 |     with r.connection(key="counter") as c:
105 |         c("watch", "counter")
106 |         value = int(c("get", "counter") or 0)
107 |         c("multi")
108 |         c("set", "counter", value + 1)
109 |         if c("exec") is None:  # Redis returns None for the EXEC command when the transaction failed
110 |             continue
111 |         value += 1  # The value is updated if we got here
112 |         break
113 |
114 |
115 | # Let's show some publish and subscribe commands,
116 | # here we use a push connection (where commands have no direct response)
117 | with r.connection(push=True) as p:
118 |     p("subscribe", "hello")
119 |     assert p.next_message() == ["subscribe", "hello", 1]
120 |     assert p.next_message(timeout=0.1) is None  # Let's wait 0.1 seconds for another message
121 |     r("publish", "hello", ", World !")
122 |     assert p.next_message() == ["message", "hello", ", World !"]
123 | ```
124 |
125 | ## API
126 |
127 | ```python
128 | Redis(**kwargs)
129 |     @classmethod
130 |     from_url(url, **kwargs)
131 |     __enter__() / __exit__()
132 |     close()
133 |     # kwargs options = endpoint, decoder, attributes, database
134 |     __call__(*cmd, **kwargs)
135 |     endpoints()
136 |     # kwargs options = decoder, attributes, database
137 |     modify(**kwargs)  # Returns a modified settings instance (while sharing the pool)
138 |     # kwargs options = key, endpoint, decoder, attributes, database
139 |     connection(push=False, **kwargs)
140 |         __enter__() / __exit__()
141 |         close()
142 |         # kwargs options = decoder, attributes, database
143 |         __call__(*cmd, **kwargs)  # On a push connection, calls have no result
144 |         # kwargs options = decoder, attributes, database
145 |         modify(**kwargs)  # Returns a modified settings instance (while sharing the connection)
146 |
147 |         # Push connection only commands
148 |         # kwargs options = decoder, attributes
149 |         next_message(timeout=None, **kwargs)
150 |         __iter__()
151 |         __next__()
152 | ```
153 |
154 | ### URI connection options
155 |
156 | ```Redis.from_url()``` options are:
157 |
158 | Regular TCP connection:
159 | ```
160 | redis://[[username:]password@]host[:port][/database][[?option1=value1][&option2=value2]]
161 | ```
162 |
163 | SSL TCP connection (you can use ssl instead of rediss):
164 | ```
165 | rediss://[[username:]password@]host[:port][/database][[?option1=value1][&option2=value2]]
166 | ```
167 |
168 | Unix domain connection (you can use unix instead of redis-socket):
169 | ```
170 | redis-socket://[[username:]password@]path[[?option1=value1][&option2=value2]]
171 | ```
172 |
173 | For a cluster, you can replace host:port with a list host1:port1,host2:port2,... if you want fallback addresses for backup.
174 |
175 | You can append options at the end, taken from the ```Redis()``` constructor options listed below.
176 |
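For example (a minimal sketch; the host, credentials and database number are placeholder values):

```python
from justredis import Redis

# Plain TCP connection to database 0
r = Redis.from_url("redis://localhost:6379/0")

# SSL connection with ACL credentials (placeholder values)
rs = Redis.from_url("rediss://myuser:mypassword@redis.example.com:6379/0")
```
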
177 | ### Settings options
178 |
179 | These are the ```Redis()``` constructor options:
180 | ```
181 | pool_factory ("auto")
182 |     "auto" / "cluster" - Try to automatically figure out the Redis server type (currently cluster / no-cluster)
183 |     "pool" - Force a non cluster aware connection pool (simpler code)
184 | address (None)
185 |     An (address, port) tuple for TCP sockets, the default is ("localhost", 6379)
186 |     A path string for unix domain sockets, the default is "/tmp/redis.sock"
187 | username (None)
188 |     If you have an ACL username, specify it here
189 | password (None)
190 |     If you have an AUTH / ACL password, specify it here
191 | client_name (None)
192 |     If you want your client to be named on connection, specify it here
193 | resp_version (2)
194 |     Specifies which RESP protocol version to use for connections
195 |     -1 = auto detect
196 |     2 = RESP2
197 |     3 = RESP3
198 | socket_factory ("tcp")
199 |     Specifies which socket type to use to connect to the Redis server
200 |     "tcp" - TCP socket
201 |     "unix" - unix domain socket
202 |     "ssl" - TCP SSL socket
203 | connect_retry (2)
204 |     How many times to retry connecting when establishing a new connection
205 | max_connections (None)
206 |     The maximum number of concurrent connections to keep to a server in the connection pool, the default is unlimited
207 | wait_timeout (None)
208 |     How long (float seconds) to wait for a connection when the connection pool is full before raising a timeout error, the default is unlimited
209 | cutoff_size (6000)
210 |     The maximum amount of bytes that will be appended together (instead of sent separately) before sending data to the socket, 0 to disable this feature
211 | custom_command_class (None)
212 |     Register a custom class to extend Redis server command handling
213 | encoder (None)
214 |     Specifies how to encode strings to bytes; it can be a string, list or dictionary that is passed directly as the parameters to str.encode, the default is "utf8"
215 | connect_timeout (None)
216 |     How long (float seconds) to wait for a connection to a server to be established, the default is unlimited
217 | socket_timeout (None)
218 |     How long (float seconds) to wait for a socket operation (read/write) with a server, the default is unlimited
219 | ```
220 |
221 | These parameters can be passed to the ```Redis()``` constructor, to the ```modify()``` method, or per ```__call__()```:
222 | ```
223 | decoder (None)
224 |     Specifies how to decode the string results from the server; it can be a string, list or dictionary that is passed directly as the parameters to bytes.decode, the default is normal bytes conversion
225 | attributes (False)
226 |     Specifies if you want to handle the attribute fields of the RESP3 protocol (read the special section about this in the readme)
227 | database (None)
228 |     Sets which database to operate on, the default is 0
229 | ```
230 |
231 | This can be provided to the ```Redis()``` constructor if you are using the cluster pool_factory:
232 | ```
233 | addresses (None)
234 |     Multiple (address, port) tuples of cluster IPs for fallback. The default is (("localhost", 6379), )
235 | ```
236 |
237 | This can be provided to the ```Redis()``` constructor for the tcp and ssl socket_factory:
238 | ```
239 | tcp_keepalive (None)
240 |     How many seconds between TCP connection liveness checks, the default is disabled
241 | tcp_nodelay (True)
242 |     Enable or disable the TCP nodelay algorithm
243 | ```
244 |
245 | This can be provided to the ```Redis()``` constructor for the ssl socket_factory:
246 | ```
247 | ssl_context (None)
248 |     A Python SSL context object, the default is Python's ssl.create_default_context()
249 | ssl_cafile (None)
250 |     A path to the CA certificate file on disk, works only if ssl_context is None
251 | ssl_certfile (None)
252 |     A path to the server certificate file on disk, works only if ssl_context is None
253 | ssl_keyfile (None)
254 |     A path to the server key file on disk, works only if ssl_context is None
255 | ```
256 |
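As an illustration, here is how several of the options above might be combined (a sketch with placeholder values, not a recommended configuration):

```python
from justredis import Redis

r = Redis(
    address=("redis.example.com", 6380),  # placeholder host
    password="secret",                    # placeholder credential
    resp_version=-1,                      # auto negotiate RESP2 / RESP3
    connect_retry=3,
    max_connections=16,
    connect_timeout=5.0,
    socket_timeout=10.0,
)
r("ping")
r.close()
```
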
257 | Read the cluster and connection documentation below for the options of the ```connection()``` and ```__call__()``` API.
258 |
259 | ### Exceptions
260 |
261 | ```
262 | ValueError - Will be thrown when invalid input was given. Nothing will be sent to the server.
263 | Error - Will be thrown when the server returned an error to a request.
264 | PipelinedExceptions - Will be thrown when some of the commands in a pipeline failed.
265 | RedisError - Will be thrown when an internal logic error has happened.
266 | CommunicationError - An I/O error has occurred.
267 | ConnectionPoolError - The connection pool could not get a new connection.
268 | ProtocolError - Invalid input from the server.
269 | ```
270 |
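A sketch of how these map onto error handling in practice (assuming a reachable local server; the key name is arbitrary):

```python
from justredis import Redis, Error, CommunicationError

r = Redis(decoder="utf8")
try:
    r("set", "mykey", "not a number")
    r("incr", "mykey")  # The server will reject this with an error reply
except Error as e:
    # Server-side error reply, e.g. "ERR value is not an integer or out of range"
    print("Redis replied with an error:", e.args[0])
except CommunicationError:
    # I/O failure, e.g. the server went away mid-request
    print("Could not talk to the server")
```
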
271 | ## Redis command replacements
272 |
273 | The following Redis commands should not be called directly, but via the library API:
274 |
275 | ### Username and Password (AUTH / ACL)
276 |
277 | If you have a username and/or password you want to use, pass them to the connection constructor, such as:
278 |
279 | ```python
280 | r = Redis(username="your_username", password="your_password")
281 | ```
282 |
283 | ### Database selection (SELECT)
284 |
285 | You can specify the default database you want to use in the constructor:
286 |
287 | ```python
288 | r = Redis(database=1)
289 | ```
290 |
291 | If you want to modify it afterwards for a specific set of commands, you can use a modify context for it:
292 |
293 | ```python
294 | with r.modify(database=2) as r1:
295 |     r1("set", "a", "b")
296 | ```
297 |
298 | ### Transaction (WATCH / MULTI / EXEC / DISCARD)
299 |
300 | To use the transaction commands, you must take a connection and use all of them inside it. Please read the [Connection section](#connection-commands) below for more details.
301 |
302 | ### Pub/Sub and monitor (SUBSCRIBE / PSUBSCRIBE / UNSUBSCRIBE / PUNSUBSCRIBE / MONITOR)
303 |
304 | To use push commands, you must take a push connection and use all of them inside it. Please read the [Connection section](#connection-commands) below for more details.
305 |
306 | ## Usage
307 |
308 | ### Connection commands
309 |
310 | The ```connection()``` method is required for sending multiple commands to the same server (such as transactions) or for talking to the server in push mode (pub/sub and monitor).
311 |
312 | You can pass the method ```push=True``` for push mode, where commands have no direct response (otherwise it defaults to a normal connection).
313 |
314 | While you do not have to pass a ```key=```, it's better to provide one you are about to use inside, in case you want to talk to a cluster later on.
315 |
316 | In some instances you might want to talk to a specific server in a cluster (like getting keyspace notifications from it), so you can pass ```endpoint=``` instead of ```key=``` with that server's address.
317 |
318 | Check the [transaction or pubsub examples](#examples) above for syntax usage.
319 |
320 | ### Pipelining
321 |
322 | You can pipeline multiple commands together by passing a list of commands to be sent together. This usually improves latency.
323 |
324 | Notice that if you are talking to a cluster, the pipeline must contain commands whose keys all belong to keyslots of the same server.
325 |
326 | If some of the commands failed, a PipelinedExceptions exception will be thrown, with its args pointing to the result of each command.
327 |
328 | Check the [pipeline example](#examples) above for syntax usage; a standalone sketch also follows below.
329 |
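A minimal standalone pipelining sketch (assuming a single, non-clustered server, so both keys trivially live on the same server):

```python
from justredis import Redis

r = Redis(decoder="utf8")

# One round trip for all three commands; the results come back as a list
with r.connection(key="a") as c:
    results = c(("set", "a", "1"), ("append", "a", "2"), ("get", "a"))
    assert results[-1] == "12"
```
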
330 | ### Cluster commands
331 |
332 | Currently the library supports talking to Redis master servers only. It automatically detects when you are connected to a cluster (unless you explicitly disabled that feature in the constructor settings).
333 |
334 | If you want to specify multiple addresses for redundancy, you can do so:
335 |
336 | ```python
337 | r = Redis(addresses=(('host1', port1), ('host2', port2)))
338 | ```
339 |
340 | You can get the list of servers with the ```endpoints()``` method.
341 |
342 | You can also send a command to all the masters by adding ```endpoint='masters'``` to the ```__call__()```:
343 |
344 | ```python
345 | r("cluster", "info", endpoint="masters")
346 | ```
347 |
348 | You can also open a connection to a specific instance, for example to get keyspace notifications from it or to monitor it, by adding ```endpoint=``` to the ```connection()``` method, as sketched below.
349 |
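A sketch of both (the addresses are whatever the pool has discovered; using MONITOR here is just an example of an instance-specific push command):

```python
from justredis import Redis

r = Redis(decoder="utf8")

# List the known servers; each entry is an (address, metadata) pair
for address, info in r.endpoints():
    print(address, info)

# Talk to one specific instance, e.g. to MONITOR it (push mode)
first_address = r.endpoints()[0][0]
with r.connection(push=True, endpoint=first_address) as m:
    m("monitor")
    print(m.next_message(timeout=1))
```
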
350 | ### RESP2 and RESP3 difference
351 |
352 | The library supports talking in both RESP2 and RESP3. By default it will use RESP2, because that way you'll get the same responses whether or not you are talking to a RESP3 supporting server (Redis server version 6 and above).
353 |
354 | You can still tell it to use RESP3, or to auto negotiate the highest version with the specific server:
355 |
356 | ```python
357 | r = Redis(resp_version=2)  # Talk RESP2 only
358 | r = Redis(resp_version=3)  # Talk RESP3 only (will throw an exception if the server does not support it)
359 | r = Redis(resp_version=-1)  # Talk in the highest version possible
360 | ```
361 |
362 | You can read about the RESP3 protocol and responses in the [Redis documentation](https://github.com/antirez/RESP3/blob/master/spec.md).
363 |
364 | RESP3 lets clients know the response type (such as strings, lists, dictionaries, sets...), and Justredis supports all of the response types.
365 |
366 | RESP3 provides an option to get extra attributes with the results. Since Python's type system cannot add such attributes easily, another configuration value was added, ```attributes```, which specifies whether you care about getting this information or not; the default is False:
367 |
368 | ```python
369 | r = Redis(attributes=True)
370 | ```
371 |
372 | If attributes is disabled, you will get the direct Python mapping of the results (set, list, dict, string, numbers, etc...) and if enabled, you will get a special object which holds the raw data in the ```data``` attribute, and the attributes in the ```attr``` attribute. Notice that this feature is orthogonal to choosing RESP2 / RESP3 (but in RESP2 the attr will always be empty), for ease of development.
373 |
374 | Here is an example of the difference in Redis version 6, with and without attributes:
375 | ```python
376 | >>> import justredis
377 | >>> r = justredis.Redis()  # By default it connects via RESP2
378 | >>> r("hgetall", "aaa")
379 | [b'bbb', b'ccc', b'ccc', b'ddd']
380 | >>> r("hgetall", "aaa", attributes=True)  # This is RESP2 with attributes, it has .data and .attr
381 | Array: [String: b'bbb' , String: b'ccc' , String: b'ccc' , String: b'ddd' ]
382 | >>> r = justredis.Redis(resp_version=-1)  # This will connect to Redis 6 with RESP3
383 | >>> r("hgetall", "aaa")
384 | OrderedDict([(b'bbb', b'ccc'), (b'ccc', b'ddd')])  # This is Python's OrderedDict
385 | >>> r("hgetall", "aaa", attributes=True)
386 | Map: OrderedDict([(String: b'bbb' , String: b'ccc' ), (String: b'ccc' , String: b'ddd' )])
387 | ```
388 |
389 | ### Thread and async safety
390 |
391 | The library is thread safe and async safe. Do not pass Connection objects between different threads or coroutines.
392 |
393 | ### Modify
394 |
395 | You can change some of the settings per ```__call__()```, or if you want multiple calls to have different settings, you can use the ```modify()``` method on a Connection or Redis object.
396 |
397 | Currently you can change the string decoder used, the database number and the attributes flag. Check the [examples](#examples) above to see how it's done.
398 |
399 | ### Serialization and deserialization
400 |
401 | The library supports only these input types: bytes, bytearray, memoryview, str, int and float. If you pass a string, it will be encoded to bytes by the given encoder option (the default is utf-8). Passing anything else will result in a ValueError.
402 |
403 | The library will return the data types that the RESP protocol returns, as described in the RESP section. Exceptions will always be utf-8 decoded strings, and for other string results you can decide to keep them as bytes or decode them to a string. A short sketch of the input rules follows below.
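
A sketch of the input rules (the key names are arbitrary; the behavior follows the encoder described above):

```python
from justredis import Redis

r = Redis()

r("set", "k1", b"raw bytes")   # bytes / bytearray / memoryview pass through as-is
r("set", "k2", "text")         # str is encoded with the configured encoder (utf-8 by default)
r("set", "k3", 42)             # int is sent as b"42"
r("set", "k4", 3.14)           # float is sent via its repr

try:
    r("set", "k5", ["lists", "are", "not", "valid"])
except ValueError:
    pass  # unsupported input types are rejected before anything is sent

# Note: bool is rejected as well, even though it subclasses int
```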
404 |
405 | ### Async support
406 |
407 | The API for the asynchronous commands is exactly the same, just add "await" where it's needed:
408 |
409 | ```python
410 | AsyncRedis(**kwargs)
411 |     @classmethod
412 |     from_url(url, **kwargs)
413 |     async __aenter__() / __aexit__()
414 |     async aclose()
415 |     # kwargs options = endpoint, decoder, attributes, database
416 |     async __call__(*cmd, **kwargs)
417 |     async endpoints()
418 |     # kwargs options = decoder, attributes, database
419 |     modify(**kwargs)  # Returns a modified settings instance (while sharing the pool)
420 |     # kwargs options = key, endpoint, decoder, attributes, database
421 |     async connection(push=False, **kwargs)
422 |         async __aenter__() / async __aexit__()
423 |         async aclose()
424 |         # kwargs options = decoder, attributes, database
425 |         async __call__(*cmd, **kwargs)  # On a push connection, calls have no result
426 |         # kwargs options = decoder, attributes, database
427 |         modify(**kwargs)  # Returns a modified settings instance (while sharing the connection)
428 |
429 |         # Push connection only commands
430 |         # kwargs options = decoder, attributes
431 |         async next_message(timeout=None, **kwargs)
432 |         __iter__()
433 |         async __next__()
434 | ```
435 |
436 | Don't forget that there is no ```__del__()``` method in async code, so call ```aclose()``` or use async context managers when needed.
437 |
438 | ### Extending the library with more command support
439 |
440 | You can extend the Redis object to support concrete Redis commands, instead of just calling them raw. Here is an example:
441 |
442 | ```python
443 | from justredis import Redis
444 |
445 |
446 | class CustomCommands:
447 |     def __init__(self, base):
448 |         self._base = base
449 |
450 |     def get(self, key, **kwargs):
451 |         return self._base("get", key, **kwargs)
452 |
453 |     def set(self, key, value, **kwargs):
454 |         return self._base("set", key, value, **kwargs)
455 |
456 |
457 | r = Redis(custom_command_class=CustomCommands)
458 | r.set("hi", "there")
459 | assert r.get("hi", decoder="utf8") == "there"
460 | with r.modify(database=1) as r1:
461 |     assert r1.get("hi") is None
462 | ```
463 |
--------------------------------------------------------------------------------
/justredis/__init__.py:
--------------------------------------------------------------------------------
1 | from .sync.redis import Redis
2 | from .decoder import Error
3 |
4 | # TODO (misc) keep this in sync
5 | from .errors import *
6 |
7 |
8 | try:
9 |     from .nonsync.redis import Redis as AsyncRedis
10 | except ImportError:
11 |
12 |     class AsyncRedis:
13 |         def __init__(self, *args, **kwargs):
14 |             raise Exception("Using JustRedis asynchronously requires the AnyIO library to be installed.")
15 |
16 |
17 | except SyntaxError:
18 |
19 |     class AsyncRedis:
20 |         def __init__(self, *args, **kwargs):
21 |             raise Exception("Using JustRedis asynchronously requires Python 3.6 or above.")
22 |
23 |
24 | except AttributeError as e:
25 |
26 |     class AsyncRedis:
27 |         def __init__(self, e=e, *args, **kwargs):
28 |             raise Exception(e.args[0])
29 |
30 |
31 | __all__ = "AsyncRedis", "Redis", "RedisError", "CommunicationError", "ConnectionPoolError", "ProtocolError", "PipelinedExceptions", "Error"
32 |
--------------------------------------------------------------------------------
/justredis/decoder.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | from .errors import ProtocolError
4 |
5 |
6 | # TODO (misc) We can add helpful error messages on which part the parsing
failed, should we do that ? 7 | # TODO (misc) Should we make ProtocolError the catch all chain ? 8 | # TODO (misc) There is allot of code duplication here, we can merge most of it. 9 | 10 | 11 | # Python 2 bytearray implementation is less efficient, luckily it's EOL 12 | class Buffer: 13 | def __init__(self): 14 | self._buffer = bytearray() 15 | 16 | def append(self, data): 17 | self._buffer += data 18 | 19 | def __len__(self): 20 | return len(self._buffer) 21 | 22 | def skip_if_startswith(self, data): 23 | if self._buffer.startswith(data): 24 | del self._buffer[: len(data)] 25 | return True 26 | return False 27 | 28 | def takeline(self): 29 | idx = self._buffer.find(b"\r\n") 30 | if idx == -1: 31 | return None 32 | ret = self._buffer[:idx] 33 | del self._buffer[: idx + 2] 34 | return ret 35 | 36 | def take(self, nbytes): 37 | ret = self._buffer[:nbytes] 38 | del self._buffer[:nbytes] 39 | return ret 40 | 41 | def skip(self, nbytes): 42 | del self._buffer[:nbytes] 43 | 44 | 45 | class NeedMoreData: 46 | pass 47 | 48 | 49 | class Result: 50 | def __init__(self, data, attr=None): 51 | self.data = data 52 | self.attr = attr 53 | 54 | def __repr__(self): 55 | rep = "%s: %s " % (type(self).__name__, self.data) 56 | if self.attr: 57 | rep += "[%s]" % self.attr 58 | return rep 59 | 60 | 61 | class String(Result): 62 | def __bytes__(self): 63 | return bytes(self.data) 64 | 65 | 66 | # TODO (misc) rename this to ReplyError ? 67 | class Error(Result, Exception): 68 | pass 69 | 70 | 71 | class Number(Result): 72 | pass 73 | 74 | 75 | class Null(Result): 76 | pass 77 | 78 | 79 | class Double(Result): 80 | pass 81 | 82 | 83 | class Boolean(Result): 84 | pass 85 | 86 | 87 | class BigNumber(Result): 88 | pass 89 | 90 | 91 | class Array(Result): 92 | pass 93 | 94 | 95 | class Map(Result): 96 | pass 97 | 98 | 99 | class Set(Result): 100 | pass 101 | 102 | 103 | class Push(Result): 104 | pass 105 | 106 | 107 | need_more_data = NeedMoreData() 108 | 109 | 110 | def parse_decoding(decoding): 111 | if decoding is None: 112 | return bytes 113 | elif isinstance(decoding, str): 114 | return lambda x, decoding=decoding: x.decode(encoding=decoding) 115 | elif isinstance(decoding, (tuple, list)): 116 | return lambda x, decoding=decoding: x.decode(*decoding) 117 | elif isinstance(decoding, dict): 118 | return lambda x, decoding=decoding: x.decode(**decoding) 119 | else: 120 | raise ValueError("Invalid decoding: %r" % decoding) 121 | 122 | 123 | class RedisRespDecoder: 124 | # TODO (misc) maybe add decoder, and push_decoder ? 125 | def __init__(self, decoder=None, attributes=False, **kwargs): 126 | self._decoder = parse_decoding(decoder) 127 | self._attributes = attributes 128 | self._buffer = Buffer() 129 | self._result = iter(self._extract_generator()) 130 | 131 | def feed(self, data): 132 | self._buffer.append(data) 133 | 134 | def extract(self): 135 | # We don't do an try/finally here, since if an error occured, the connection should be closed anyhow... 
136 | return next(self._result) 137 | 138 | def _extract_generator(self): 139 | buffer = self._buffer 140 | array_stack = [] 141 | last_array = None 142 | last_attribute = None 143 | _need_more_data = need_more_data 144 | 145 | while True: 146 | # Handle aggregate data 147 | while last_array is not None: 148 | # Still more data needed to fill aggregate response 149 | if last_array[0] != len(last_array[1]): 150 | break 151 | # Result is an array 152 | if last_array[2] == 0: 153 | msg_type = Array 154 | msg = last_array[1] 155 | # Result is an map 156 | elif last_array[2] == 1: 157 | msg_type = Map 158 | i = iter(last_array[1]) 159 | # TODO (misc) is this the best way to deal with y ? 160 | msg = OrderedDict([(bytes(y) if y.__hash__ is None else y, next(i)) for y in i]) 161 | # Result is an set 162 | elif last_array[2] == 2: 163 | msg_type = Set 164 | msg = set(last_array[1]) 165 | # Result is an attribute 166 | elif last_array[2] == 3: 167 | i = iter(last_array[1]) 168 | # TODO (misc) is this the best way to deal with y ? 169 | last_attribute = OrderedDict([(bytes(y) if y.__hash__ is None else y, next(i)) for y in i]) 170 | # Result is an push 171 | elif last_array[2] == 4: 172 | msg_type = Push 173 | msg = last_array[1] 174 | else: 175 | raise ProtocolError("Unknown aggregate type") 176 | 177 | # If it's an attribute, nothing to do 178 | if last_array[2] == 3: 179 | if array_stack: 180 | last_array = array_stack.pop() 181 | else: 182 | last_array = None 183 | else: 184 | if array_stack: 185 | tmp = array_stack.pop() 186 | if self._attributes: 187 | msg = msg_type(msg, last_array[3]) 188 | # For now this isn't done, since we handle Push in unique connections 189 | # elif msg_type is Push: 190 | # msg = msg_type(msg) 191 | tmp[1].append(msg) 192 | last_array = tmp 193 | else: 194 | if self._attributes: 195 | msg = msg_type(msg, last_array[3]) 196 | # For now this isn't done, since we handle Push in unique connections 197 | # elif msg_type is Push: 198 | # msg = msg_type(msg) 199 | last_array = None 200 | yield msg 201 | 202 | # General RESP3 parsing 203 | while len(buffer) == 0: 204 | yield _need_more_data 205 | 206 | # Simple string 207 | if buffer.skip_if_startswith(b"+"): 208 | msg_type = String 209 | while True: 210 | msg = buffer.takeline() 211 | if msg is not None: 212 | break 213 | yield _need_more_data 214 | msg = self._decoder(msg) 215 | # Simple error 216 | elif buffer.skip_if_startswith(b"-"): 217 | msg_type = Error 218 | while True: 219 | msg = buffer.takeline() 220 | if msg is not None: 221 | break 222 | yield _need_more_data 223 | msg = bytes(msg).decode("utf-8", "replace") 224 | # Number 225 | elif buffer.skip_if_startswith(b":"): 226 | msg_type = Number 227 | while True: 228 | msg = buffer.takeline() 229 | if msg is not None: 230 | break 231 | yield _need_more_data 232 | msg = int(msg) 233 | # Blob string and Verbatim string 234 | elif buffer.skip_if_startswith(b"$") or buffer.skip_if_startswith(b"="): 235 | msg_type = String 236 | while True: 237 | length = buffer.takeline() 238 | if length is not None: 239 | break 240 | yield _need_more_data 241 | # Streamed string 242 | if length == b"?": 243 | chunks = [] 244 | while True: 245 | while True: 246 | chunk_size = buffer.takeline() 247 | if chunk_size is not None: 248 | break 249 | yield _need_more_data 250 | chunk_size = bytes(chunk_size) 251 | assert chunk_size[0] == 59 252 | chunk_size = int(chunk_size[1:]) 253 | if chunk_size == 0: 254 | break 255 | while True: 256 | if len(buffer) >= chunk_size + 2: 257 | break 258 | 
yield _need_more_data 259 | chunks.append(buffer.take(chunk_size)) 260 | buffer.skip(2) 261 | msg = self._decoder(b"".join(chunks)) 262 | chunks = None 263 | else: 264 | length = int(length) 265 | # Legacy RESP2 support 266 | if length == -1: 267 | msg = None 268 | msg_type = Null 269 | else: 270 | while True: 271 | if len(buffer) >= length + 2: 272 | break 273 | yield _need_more_data 274 | msg = self._decoder(buffer.take(length)) 275 | buffer.skip(2) 276 | # Array 277 | elif buffer.skip_if_startswith(b"*"): 278 | while True: 279 | length = buffer.takeline() 280 | if length is not None: 281 | break 282 | yield _need_more_data 283 | # Streamed array 284 | if length == b"?": 285 | length = None 286 | else: 287 | length = int(length) 288 | # Legacy RESP2 support 289 | if length == -1: 290 | msg = None 291 | msg_type = Null 292 | else: 293 | if last_array is not None: 294 | array_stack.append(last_array) 295 | last_array = [length, [], 0, last_attribute] 296 | last_attribute = None 297 | continue 298 | # Set 299 | elif buffer.skip_if_startswith(b"~"): 300 | while True: 301 | length = buffer.takeline() 302 | if length is not None: 303 | break 304 | yield _need_more_data 305 | # Streamed set 306 | if length == b"?": 307 | length = None 308 | else: 309 | length = int(length) 310 | # Legacy RESP2 support 311 | if length == -1: 312 | msg = None 313 | msg_type = Null 314 | else: 315 | if last_array is not None: 316 | array_stack.append(last_array) 317 | last_array = [length, [], 2, last_attribute] 318 | last_attribute = None 319 | continue 320 | # Null 321 | elif buffer.skip_if_startswith(b"_"): 322 | msg_type = Null 323 | while True: 324 | line = buffer.takeline() 325 | if line is not None: 326 | break 327 | yield _need_more_data 328 | assert len(line) == 0 329 | msg = None 330 | # Double 331 | elif buffer.skip_if_startswith(b","): 332 | msg_type = Double 333 | while True: 334 | msg = buffer.takeline() 335 | if msg is not None: 336 | break 337 | yield _need_more_data 338 | msg = float(msg) 339 | # Boolean 340 | elif buffer.skip_if_startswith(b"#"): 341 | msg_type = Boolean 342 | while True: 343 | msg = buffer.takeline() 344 | if msg is not None: 345 | break 346 | yield _need_more_data 347 | if msg == b"t": 348 | msg = True 349 | elif msg == b"f": 350 | msg = False 351 | # Blob error 352 | elif buffer.skip_if_startswith(b"!"): 353 | msg_type = Error 354 | while True: 355 | length = buffer.takeline() 356 | if length is not None: 357 | break 358 | yield _need_more_data 359 | length = int(length) 360 | # Legacy RESP2 support 361 | if length == -1: 362 | msg = None 363 | msg_type = Null 364 | else: 365 | while True: 366 | if len(buffer) >= length + 2: 367 | break 368 | yield _need_more_data 369 | msg = buffer.take(length) 370 | buffer.skip(2) 371 | msg = bytes(msg).decode("utf-8", "replace") 372 | # Big number 373 | elif buffer.skip_if_startswith(b"("): 374 | msg_type = BigNumber 375 | while True: 376 | msg = buffer.takeline() 377 | if msg is not None: 378 | break 379 | yield _need_more_data 380 | msg = int(msg) 381 | # Map 382 | elif buffer.skip_if_startswith(b"%"): 383 | while True: 384 | length = buffer.takeline() 385 | if length is not None: 386 | break 387 | yield _need_more_data 388 | # Streamed map 389 | if length == b"?": 390 | length = None 391 | else: 392 | length = int(length) * 2 393 | # Legacy RESP2 support 394 | if length == -1: 395 | msg = None 396 | msg_type = Null 397 | else: 398 | if last_array is not None: 399 | array_stack.append(last_array) 400 | last_array = [length, [], 1, 
last_attribute] 401 | last_attribute = None 402 | continue 403 | # Attribute 404 | elif buffer.skip_if_startswith(b"|"): 405 | while True: 406 | length = buffer.takeline() 407 | if length is not None: 408 | break 409 | yield _need_more_data 410 | length = int(length) * 2 411 | # Legacy RESP2 support 412 | if length == -1: 413 | msg = None 414 | msg_type = Null 415 | else: 416 | if last_array is not None: 417 | array_stack.append(last_array) 418 | last_array = [length, [], 3] 419 | continue 420 | # Push 421 | elif buffer.skip_if_startswith(b">"): 422 | while True: 423 | length = buffer.takeline() 424 | if length is not None: 425 | break 426 | yield _need_more_data 427 | length = int(length) 428 | # Legacy RESP2 support 429 | if length == -1: 430 | msg = None 431 | msg_type = Null 432 | else: 433 | if last_array is not None: 434 | array_stack.append(last_array) 435 | last_array = [length, [], 4, last_attribute] 436 | last_attribute = None 437 | continue 438 | # End of streaming aggregate type 439 | elif buffer.skip_if_startswith(b"."): 440 | while True: 441 | tmp = buffer.takeline() 442 | if tmp is not None: 443 | break 444 | yield _need_more_data 445 | assert tmp == b"" 446 | assert last_array[0] == None 447 | last_array[0] = len(last_array[1]) 448 | continue 449 | else: 450 | raise ProtocolError("Unknown type: %s" % bytes(buffer.take(1)).decode()) 451 | 452 | # Handle legacy RESP2 Null 453 | if msg is None and msg_type is not Null: 454 | msg_type = Null 455 | 456 | if self._attributes: 457 | msg = msg_type(msg, last_attribute) 458 | # We still enforce this types, because of ambiguity with other types 459 | elif msg_type == Error: 460 | msg = msg_type(msg) 461 | 462 | last_attribute = None 463 | 464 | if last_array: 465 | last_array[1].append(msg) 466 | else: 467 | yield msg 468 | -------------------------------------------------------------------------------- /justredis/encoder.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | 3 | 4 | def encode(encoding="utf-8", errors="strict"): 5 | def encode_with_encoding(inp, encoding=encoding, errors=errors): 6 | if isinstance(inp, (bytes, bytearray, memoryview)): 7 | return inp 8 | elif isinstance(inp, str): 9 | return inp.encode(encoding, errors) 10 | elif isinstance(inp, bool): 11 | raise ValueError("Invalid input for encoding") 12 | elif isinstance(inp, int): 13 | return b"%d" % inp 14 | elif isinstance(inp, float): 15 | return b"%r" % inp 16 | raise ValueError("Invalid input for encoding") 17 | 18 | return encode_with_encoding 19 | 20 | 21 | utf8_encode = encode() 22 | 23 | 24 | def parse_encoding(encoding): 25 | if encoding is None: 26 | return utf8_encode 27 | elif isinstance(encoding, str): 28 | return encode(encoding=encoding) 29 | elif isinstance(encoding, (tuple, list)): 30 | return encode(*encoding) 31 | elif isinstance(encoding, dict): 32 | return encode(**encoding) 33 | else: 34 | raise ValueError("Invalid encoding: %r" % encoding) 35 | 36 | 37 | # We add data to encode in 2 steps to avoid an invalid encoding causing to drop the connections 38 | class RedisRespEncoder: 39 | def __init__(self, encoder=None, cutoff_size=6000, **kwargs): 40 | self._encoder = parse_encoding(encoder) 41 | self._cutoff_size = cutoff_size 42 | self._chunks = deque() 43 | 44 | def encode(self, *cmd): 45 | data = [] 46 | add_data = data.append 47 | encoder = self._encoder 48 | add_data(b"*%d\r\n" % len(cmd)) 49 | for arg in cmd: 50 | arg = encoder(arg) 51 | if isinstance(arg, memoryview): 
52 | length = arg.nbytes 53 | else: 54 | length = len(arg) 55 | add_data(b"$%d\r\n" % length) 56 | add_data(arg) 57 | add_data(b"\r\n") 58 | self._chunks.extend(data) 59 | 60 | def encode_multiple(self, *cmds): 61 | data = [] 62 | add_data = data.append 63 | encoder = self._encoder 64 | for cmd in cmds: 65 | add_data(b"*%d\r\n" % len(cmd)) 66 | for arg in cmd: 67 | arg = encoder(arg) 68 | if isinstance(arg, memoryview): 69 | length = arg.nbytes 70 | else: 71 | length = len(arg) 72 | add_data(b"$%d\r\n" % length) 73 | add_data(arg) 74 | add_data(b"\r\n") 75 | self._chunks.extend(data) 76 | 77 | def extract(self): 78 | cutoff_size = self._cutoff_size 79 | chunks = self._chunks 80 | if not cutoff_size: 81 | ret = b"".join(chunks) 82 | chunks.clear() 83 | return ret 84 | length = 0 85 | ret = [] 86 | while True: 87 | if length > cutoff_size or not chunks: 88 | if not ret: 89 | return None 90 | elif len(ret) == 1: 91 | return ret[0] 92 | else: 93 | return b"".join(ret) 94 | item = chunks.popleft() 95 | item_len = len(item) 96 | if item_len > cutoff_size: 97 | if length == 0: 98 | return item 99 | else: 100 | chunks.appendleft(item) 101 | return b"".join(ret) 102 | else: 103 | ret.append(item) 104 | length += item_len 105 | -------------------------------------------------------------------------------- /justredis/errors.py: -------------------------------------------------------------------------------- 1 | class RedisError(Exception): 2 | pass 3 | 4 | 5 | class CommunicationError(RedisError): 6 | pass 7 | 8 | 9 | class ConnectionPoolError(RedisError): 10 | pass 11 | 12 | 13 | class ProtocolError(RedisError): 14 | pass 15 | 16 | 17 | class PipelinedExceptions(RedisError): 18 | pass 19 | -------------------------------------------------------------------------------- /justredis/nonsync/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzickel/justredis/12c63bba8a83f1c7934d34ea081ada6efbc563ad/justredis/nonsync/__init__.py -------------------------------------------------------------------------------- /justredis/nonsync/cluster.py: -------------------------------------------------------------------------------- 1 | from binascii import crc_hqx 2 | 3 | try: 4 | from contextlib import asynccontextmanager 5 | except: 6 | from async_generator import asynccontextmanager 7 | from random import choice 8 | 9 | from .environment import get_environment 10 | from .connectionpool import ConnectionPool 11 | from ..decoder import Error, Result 12 | from ..utils import is_multiple_commands 13 | from ..encoder import parse_encoding 14 | from ..errors import CommunicationError 15 | 16 | 17 | def calc_hashslot(key): 18 | s = key.find(b"{") 19 | if s != -1: 20 | e = key.find(b"}") 21 | if e > s + 1: 22 | key = key[s + 1 : e] 23 | return crc_hqx(key, 0) % 16384 24 | 25 | 26 | # TODO (correctness) I think I covered the multithreading sensetive parts, make sure 27 | # TODO (misc) should I lazely check if there is a cluster ? (i.e. upgrade from a default connectionpool first) 28 | # TODO (correctness) thinkg about how to ASK redirect in connection taking and pipelining 29 | # TODO (correctness) make sure I dont have an issue where if there is a connection pool limit, I can get into a deadlock here 30 | # TODO (misc) future optimization, if we can't take from last_connection beacuse of connection pool limit, choose another random one. 
31 | # TODO (correctness) disable encoding on all of conn commands 32 | # TODO (correctness) when invalidating last_connection, run slots update ? 33 | # TODO (misc) should _update_slots be called on each I/O error always ? 34 | 35 | 36 | class ClusterConnectionPool: 37 | def __init__(self, addresses=None, **kwargs): 38 | address = kwargs.pop("address", None) 39 | if addresses is None: 40 | if address: 41 | addresses = (address,) 42 | else: 43 | addresses = (("localhost", 6379),) 44 | elif address is not None: 45 | raise ValueError("Do not provide both addresses, and address") 46 | self._initial_addresses = addresses 47 | self._settings = kwargs 48 | self._connections = {} 49 | self._lock = get_environment(**kwargs).lock() 50 | self._last_connection = None 51 | self._last_connection_peername = None 52 | self._slots = [] 53 | # None = unknown, False = Nop, True = Yep 54 | self._clustered = None 55 | # Command info results 56 | self._command_cache = {} 57 | self._closed = False 58 | 59 | async def aclose(self): 60 | async with self._lock: 61 | if not self._closed: 62 | for pool in self._connections.values(): 63 | await pool.aclose() 64 | self._connections.clear() 65 | self._last_connection = None 66 | self._closed = True 67 | 68 | async def _update_slots(self): 69 | # TODO (misc) or if there is a hint from MOVED, recheck if clustered? 70 | if self._clustered == False: 71 | return 72 | conn = await self.take() 73 | try: 74 | slots = await conn(b"CLUSTER", b"SLOTS") 75 | self._clustered = True 76 | except Error: 77 | slots = [] 78 | self._clustered = False 79 | except Exception: 80 | # This is done to invalidate a potentially bad server, and pick up another one randomally next try 81 | self._last_connection = self._last_connection_peername = None 82 | raise 83 | finally: 84 | await self.release(conn) 85 | # TODO (misc) what todo about holes in hashslots in response ? 86 | if isinstance(slots, Result): 87 | slots = slots.data 88 | slots.sort(key=lambda x: x[0]) 89 | if slots and isinstance(slots[0][2][0], bytes): 90 | slots = [(x[1], (x[2][0].decode(), x[2][1])) for x in slots] 91 | else: 92 | slots = [(x[1], (x[2][0], x[2][1])) for x in slots] 93 | # We weren't in a cluster before, and we aren't now 94 | if self._clustered == False and not self._slots: 95 | return 96 | # Remove connections which are not a part of the cluster anymore 97 | async with self._lock: 98 | previous_connections = set(self._connections.keys()) 99 | new_connections = set([x[1] for x in slots]) 100 | connections_to_remove = previous_connections - new_connections 101 | for address in connections_to_remove: 102 | await self._connections[address].aclose() 103 | del self._connections[address] 104 | self._slots = slots 105 | 106 | if connections_to_remove: 107 | # TODO (misc) an optimization can only do this if it's not in new_connections 108 | self._last_connection = self._last_connection_peername = None 109 | 110 | async def _connection_by_hashslot(self, hashslot): 111 | if not self._slots: 112 | await self._update_slots() 113 | if self._clustered == False: 114 | return await self.take() 115 | if not self._slots: 116 | # TODO (misc) Is this the correct exception type? 117 | raise ValueError("Could not find any slots in the redis cluster") 118 | for slot in self._slots: 119 | if hashslot <= slot[0]: 120 | break 121 | address = slot[1] 122 | return await self.take(address) 123 | 124 | async def _get_info_for_command(self, *cmd): 125 | # commands are ascii, yes ? some commands can be larger than cmd[0] for index ? 
meh, let's be optimistic for now 126 | command = cmd[0] 127 | encode = getattr(command, "encode", None) 128 | if encode: 129 | command = encode("ascii") 130 | command = command.upper() 131 | info = self._command_cache.get(command) 132 | if info: 133 | return info 134 | conn = await self.take() 135 | try: 136 | command_info = await conn(b"COMMAND", b"INFO", command) 137 | except Exception: 138 | # This is done to invalidate a potentially bad server, and pick up another one randomally next try 139 | self._last_connection = self._last_connection_peername = None 140 | raise 141 | finally: 142 | await self.release(conn) 143 | if isinstance(command_info, Result): 144 | command_info = command_info.data 145 | command_info = command_info[0] 146 | # TODO (misc) on null, cache it too ? 147 | if command_info: 148 | self._command_cache[command] = command_info 149 | return command_info 150 | 151 | def _address_pool(self, address): 152 | pool = self._connections.get(address) 153 | if pool is None: 154 | # with self._lock: 155 | pool = self._connections.get(address) 156 | if pool is None: 157 | pool = ConnectionPool(address=address, **self._settings) 158 | self._connections[address] = pool 159 | return pool 160 | 161 | # TODO (misc) make sure the address got here from _slots (or risk stale data) 162 | async def take(self, address=None): 163 | # if address: 164 | # return self._address_pool(address).take() 165 | # TODO (misc) is this best ? 166 | async with self._lock: 167 | if address: 168 | return await self._address_pool(address).take() 169 | if self._last_connection: 170 | try: 171 | # TODO (misc) maybe do a health check here ? if there is an exception it will be invalidated anyhow for the next time... 172 | return await self._last_connection.take() 173 | except Exception: 174 | self._last_connection = None 175 | endpoints = [x[1] for x in self._slots.copy()] 176 | # TODO (correctness) maybe after I/O failure (repeated?) always go back to initial address ? or just remove an entry from the connection when it's invalid, till it's empty ? 177 | if not endpoints: 178 | endpoints = self._initial_addresses 179 | # TODO (correctness) should we pick up randomally, or go each one in the list on each failure ? 
180 | address = choice(endpoints) 181 | pool = self._address_pool(address) 182 | self._last_connection = pool 183 | # TODO (corectness) on error, try the next one immidiatly 184 | conn = await self._last_connection.take() 185 | self._last_connection_peername = conn.peername() 186 | return conn 187 | 188 | async def take_by_key(self, key, **kwargs): 189 | if not isinstance(key, (bytes, bytearray)): 190 | encode = kwargs.get("encoder", self._settings.get("encoder")) 191 | key = parse_encoding(encode)(key) 192 | hashslot = calc_hashslot(key) 193 | return await self._connection_by_hashslot(hashslot) 194 | 195 | async def take_by_cmd(self, *cmd, **kwargs): 196 | if is_multiple_commands(*cmd): 197 | found = False 198 | # Some commands have no key information (like MULTI) so scan for one which does 199 | for command in cmd: 200 | info = await self._get_info_for_command(*command) 201 | if info is None: 202 | continue 203 | index = info[3] 204 | if index != 0: 205 | found = True 206 | cmd = command 207 | break 208 | if not found: 209 | cmd = cmd[0] 210 | else: 211 | info = await self._get_info_for_command(*cmd) 212 | # This should happen only if command doesn't exist 213 | if info is None: 214 | return await self.take() 215 | # TODO (misc) maybe see if command is movablekeys, and only do this then, for optimizations 216 | index = info[3] 217 | if index == 0: 218 | conn = await self.take() 219 | try: 220 | command = [b"COMMAND", b"GETKEYS"] 221 | command.extend(cmd) 222 | # TODO (misc) what do we want to do if an exception happened here ? 223 | command_info = await conn(*command, attributes=False, decoder=None) 224 | key = command_info[0] 225 | # This happens if the command has no key info, so any connection is good 226 | except Error: 227 | return await self.take() 228 | finally: 229 | await self.release(conn) 230 | else: 231 | # The command does not contain the key, usually this will result in a usage error ? 232 | if len(cmd) - 1 < index: 233 | return await self.take() 234 | else: 235 | key = cmd[index] 236 | return await self.take_by_key(key, **kwargs) 237 | 238 | async def release(self, conn): 239 | if self._clustered: 240 | # TODO (correctness) is the peername always 100% the same as the slot address ? to be on the safe side we can store both @ metadata 241 | address = conn.peername() 242 | pool = self._connections.get(address) 243 | else: 244 | # TODO (correctness) risky, if last_connection somehow changed (multiple fallback address?), we might be returning to the wrong one, does it matter then ?!? 245 | pool = self._last_connection 246 | # The connection might have been discharged 247 | if pool is None: 248 | await conn.aclose() 249 | return 250 | await pool.release(conn) 251 | 252 | async def __call__(self, *cmd, endpoint=False, **kwargs): 253 | if not cmd: 254 | raise ValueError("No command provided") 255 | if endpoint == "masters": 256 | return await self._on_all(*cmd) 257 | if self._clustered == False: 258 | conn = await self.take() 259 | elif endpoint: 260 | conn = await self.take(endpoint) 261 | else: 262 | conn = await self.take_by_cmd(*cmd) 263 | try: 264 | return await conn(*cmd, **kwargs) 265 | except CommunicationError: 266 | await self._update_slots() 267 | raise 268 | finally: 269 | seen_moved = conn.seen_moved() 270 | seen_asked = conn.seen_asked() 271 | await self.release(conn) 272 | if seen_moved: 273 | await self._update_slots() 274 | # If the user specified he wants a specific endpoint, we won't force the issue on him. 
275 | # Also if ths cmd is multiple commands, we won't know which one failed and which didn't, so we don't try as well. 276 | if endpoint == False and not is_multiple_commands(*cmd): 277 | return await self(*cmd, **kwargs) 278 | elif seen_asked: 279 | if endpoint == False and not is_multiple_commands(*cmd): 280 | return await self(*cmd, **kwargs, endpoint=seen_asked, asking=True) 281 | 282 | @asynccontextmanager 283 | async def connection(self, key=None, endpoint=None, **kwargs): 284 | if key and endpoint: 285 | raise ValueError("Cannot specify both key and endpoint when taking a connection") 286 | if endpoint: 287 | conn = await self.take(endpoint) 288 | elif self._clustered == False or key is None: 289 | conn = await self.take() 290 | else: 291 | conn = await self.take_by_key(key, **kwargs) 292 | try: 293 | conn.allow_multi(True) 294 | yield conn 295 | except CommunicationError: 296 | await self._update_slots() 297 | raise 298 | finally: 299 | # We need to clean up the connection back to a normal state. 300 | try: 301 | if not conn.closed(): 302 | await conn._command(b"DISCARD") 303 | except Error: 304 | pass 305 | finally: 306 | # We need to handle the option where there was an moved error, to not have a recursion of a connection always trying the wrong server 307 | seen_moved = conn.seen_moved() 308 | conn.allow_multi(False) 309 | await self.release(conn) 310 | if seen_moved: 311 | await self._update_slots() 312 | 313 | async def _on_all(self, *cmd, filter="master", **kwargs): 314 | if self._clustered is None: 315 | await self._update_slots() 316 | if self._clustered == False: 317 | # This will be always filled by the _update_slots (atleast) 318 | return {self._last_connection_peername: await self(*cmd, **kwargs)} 319 | # TODO (api) on an error here, raise an exception ? 320 | res = {} 321 | for address in await self.endpoints(): 322 | if address[1]["type"] != filter: 323 | continue 324 | address = address[0] 325 | try: 326 | res[address] = await self(*cmd, endpoint=address, **kwargs) 327 | except Exception as e: 328 | res[address] = e 329 | return res 330 | 331 | async def endpoints(self): 332 | if self._clustered is None: 333 | await self._update_slots() 334 | if self._clustered: 335 | return [(x[1], {"type": "master"}) for x in self._slots.copy()] 336 | else: 337 | # This will be always filled by the _update_slots (atleast) 338 | return [(self._last_connection_peername, {"type": "regular"})] 339 | -------------------------------------------------------------------------------- /justredis/nonsync/connection.py: -------------------------------------------------------------------------------- 1 | from .environment import get_environment 2 | from ..decoder import RedisRespDecoder, need_more_data, Error 3 | from ..encoder import RedisRespEncoder 4 | from ..errors import CommunicationError, PipelinedExceptions 5 | from ..utils import get_command_name, is_multiple_commands 6 | 7 | # TODO (correctness) watch for manual SELECT and set_database ! 
8 | 
9 | 
10 | not_allowed_push_commands = set([b"MONITOR", b"SUBSCRIBE", b"PSUBSCRIBE", b"UNSUBSCRIBE", b"PUNSUBSCRIBE"])
11 | 
12 | 
13 | class TimeoutError(Exception):
14 |     pass
15 | 
16 | 
17 | timeout_error = TimeoutError()
18 | 
19 | 
20 | class Connection:
21 |     @classmethod
22 |     async def create(cls, username=None, password=None, client_name=None, resp_version=2, socket_factory="tcp", connect_retry=2, database=0, **kwargs):
23 |         ret = cls()
24 |         await ret._init(username, password, client_name, resp_version, socket_factory, connect_retry, database, **kwargs)
25 |         return ret
26 | 
27 |     def __init__(self):
28 |         self._socket = None
29 | 
30 |     # TODO (api) client_name with connection pool (?)
31 |     # TODO (documentation) the username/password/client_name need the decoding of whatever **kwargs is passed
32 |     async def _init(self, username=None, password=None, client_name=None, resp_version=2, socket_factory="tcp", connect_retry=2, database=0, **kwargs):
33 |         if resp_version not in (-1, 2, 3):
34 |             raise ValueError("Unsupported RESP protocol version %s" % resp_version)
35 | 
36 |         self._settings = kwargs
37 | 
38 |         environment = get_environment(**kwargs)
39 |         connect_retry += 1
40 |         while connect_retry:
41 |             try:
42 |                 self._socket = await environment.socket(socket_factory, **kwargs)
43 |                 break
44 |             except Exception as e:
45 |                 connect_retry -= 1
46 |                 if not connect_retry:
47 |                     raise CommunicationError() from e
48 |         self._encoder = RedisRespEncoder(**kwargs)
49 |         self._decoder = RedisRespDecoder(**kwargs)
50 |         self._seen_eof = False
51 |         self._peername = self._socket.peername()
52 |         self._seen_moved = False
53 |         self._seen_ask = False
54 |         self._allow_multi = False
55 |         self._default_database = self._last_database = database
56 |         self._cancel_class = environment.cancelledclass()
57 | 
58 |         connected = False
59 |         # Try to negotiate RESP3 first if RESP2 is not forced
60 |         if resp_version != 2:
61 |             args = [b"HELLO", b"3"]
62 |             if password is not None:
63 |                 if username is not None:
64 |                     args.extend((b"AUTH", username, password))
65 |                 else:
66 |                     args.extend((b"AUTH", b"default", password))
67 |             if client_name:
68 |                 args.extend((b"SETNAME", client_name))
69 |             try:
70 |                 # TODO (misc) do something with the result ?
71 |                 await self._command(*args)
72 |                 connected = True
73 |             except Error as e:
74 |                 # This is to separate a login error from the server not supporting RESP3
75 |                 if e.args[0].startswith("ERR "):
76 |                     if resp_version == 3:
77 |                         # TODO (misc) this won't have a __cause__, is that ok ? what exception to throw here ?
78 |                         raise Exception("Server does not support RESP3 protocol")
79 |                 else:
80 |                     raise
81 |         if not connected:
82 |             if password:
83 |                 if username:
84 |                     await self._command(b"AUTH", username, password)
85 |                 else:
86 |                     await self._command(b"AUTH", password)
87 |             if client_name:
88 |                 await self._command(b"CLIENT", b"SETNAME", client_name)
89 |             if database != 0:
90 |                 await self._command(b"SELECT", database)
91 | 
92 |     async def aclose(self, force=False):
93 |         if self._socket:
94 |             try:
95 |                 await self._socket.aclose(force)
96 |             except Exception:
97 |                 pass
98 |             self._socket = None
99 |         self._encoder = None
100 |         self._decoder = None
101 | 
102 |     # TODO (misc) better check ? (maybe it's closed, but the socket doesn't know it yet..., will be known the next time though)
103 |     def closed(self):
104 |         return self._socket is None
105 | 
106 |     def peername(self):
107 |         return self._peername
108 | 
109 |     async def _send(self, *cmd):
110 |         try:
111 |             if is_multiple_commands(*cmd):
112 |                 self._encoder.encode_multiple(*cmd)
113 |             else:
114 |                 self._encoder.encode(*cmd)
115 |             while True:
116 |                 data = self._encoder.extract()
117 |                 if data is None:
118 |                     break
119 |                 await self._socket.send(data)
120 |         except ValueError as e:
121 |             raise
122 |         except self._cancel_class:
123 |             await self.aclose(True)
124 |             raise
125 |         except Exception as e:
126 |             await self.aclose(True)
127 |             raise CommunicationError("I/O error while trying to send a command") from e
128 | 
129 |     # TODO (misc) should a decoding error be considered a CommunicationError ?
130 |     async def _recv(self, timeout=False):
131 |         try:
132 |             while True:
133 |                 res = self._decoder.extract()
134 |                 if res == need_more_data:
135 |                     if self._seen_eof:
136 |                         await self.aclose()
137 |                         raise EOFError("Connection reached EOF")
138 |                     else:
139 |                         data = await self._socket.recv(timeout)
140 |                         if data == b"":
141 |                             self._seen_eof = True
142 |                         elif data is None:
143 |                             return timeout_error
144 |                         else:
145 |                             # TODO This check is because another context can close us while we were reading (we can instead simply not remove self._decoder on close)
146 |                             if not self._decoder:
147 |                                 raise Exception("Connection already closed")
148 |                             self._decoder.feed(data)
149 |                     continue
150 |                 return res
151 |         except self._cancel_class:
152 |             await self.aclose(True)
153 |             raise
154 |         except Exception as e:
155 |             await self.aclose(True)
156 |             raise CommunicationError("Error while trying to read a reply") from e
157 | 
158 |     async def pushed_message(self, timeout=False, decoder=False, attributes=None):
159 |         orig_decoder = None
160 |         if decoder != False or attributes is not None:
161 |             orig_decoder = self._decoder
162 |             kwargs = self._settings.copy()
163 |             if decoder != False:
164 |                 kwargs["decoder"] = decoder
165 |             if attributes is not None:
166 |                 kwargs["attributes"] = attributes
167 |             self._decoder = RedisRespDecoder(**kwargs)
168 |         try:
169 |             res = await self._recv(timeout)
170 |             if res == timeout_error:
171 |                 return None
172 |             return res
173 |         finally:
174 |             if orig_decoder is not None:
175 |                 self._decoder = orig_decoder
176 | 
177 |     async def push_command(self, *cmd):
178 |         await self._send(*cmd)
179 | 
180 |     async def set_database(self, database):
181 |         if database is None:
182 |             if self._default_database != self._last_database:
183 |                 await self._command(b"SELECT", self._default_database)
184 |                 self._last_database = self._default_database
185 |         else:
186 |             if database != self._last_database:
187 |                 await self._command(b"SELECT", database)
188 |                 self._last_database = database
189 | 
190 |     async def __call__(self, *cmd, decoder=False, attributes=None, database=None):
191 |         if not cmd:
192 |             raise ValueError("No command provided")
193 |         orig_decoder = None
194 |         if decoder != False or attributes is not None:
195 |             orig_decoder = self._decoder
196 |             kwargs = self._settings.copy()
197 |             if decoder != False:
198 |                 kwargs["decoder"] = decoder
199 |             if attributes is not None:
200 |                 kwargs["attributes"] = attributes
201 |             self._decoder = RedisRespDecoder(**kwargs)
202 |         try:
203 |             await self.set_database(database)
204 |             if is_multiple_commands(*cmd):
205 |                 return await self._commands(*cmd)
206 |             else:
207 |                 return await self._command(*cmd)
208 |         finally:
209 |             if orig_decoder is not None:
210 |                 self._decoder = orig_decoder
211 | 
212 |     async def _command(self, *cmd):
213 |         command_name = get_command_name(cmd)
214 |         if command_name in not_allowed_push_commands:
215 |             raise ValueError("Command %s is not allowed to be called directly, use the appropriate API instead" % cmd)
216 |         if command_name == b"MULTI" and not self._allow_multi:
217 |             raise ValueError("Take a connection if you want to use the MULTI command.")
218 |         await self._send(*cmd)
219 |         res = await self._recv()
220 |         if isinstance(res, Error):
221 |             if res.args[0].startswith("MOVED "):
222 |                 self._seen_moved = True
223 |             if res.args[0].startswith("ASK "):
224 |                 _, _, address = res.args[0].split(" ")
225 |                 self._seen_ask = address
226 |             raise res
227 |         if res == timeout_error:
228 |             await self.aclose(True)
229 |             raise timeout_error
230 |         return res
231 | 
232 |     async def _commands(self, *cmds):
233 |         for cmd in cmds:
234 |             command_name = get_command_name(cmd)
235 |             if command_name in not_allowed_push_commands:
236 |                 raise ValueError("Command %s is not allowed to be called directly, use the appropriate API instead" % cmd)
237 |             if command_name == b"MULTI" and not self._allow_multi:
238 |                 raise ValueError("Take a connection if you want to use the MULTI command.")
239 |         await self._send(*cmds)
240 |         res = []
241 |         found_errors = False
242 |         for _ in cmds:
243 |             try:
244 |                 result = await self._recv()
245 |                 if isinstance(result, Error):
246 |                     if result.args[0].startswith("MOVED "):
247 |                         self._seen_moved = True
248 |                     found_errors = True
249 |                 if result == timeout_error:
250 |                     await self.aclose(True)
251 |             except Exception as e:
252 |                 result = e
253 |                 found_errors = True
254 |             res.append(result)
255 |         if found_errors:
256 |             raise PipelinedExceptions(res)
257 |         return res
258 | 
259 |     def seen_moved(self):
260 |         if self._seen_moved:
261 |             self._seen_moved = False
262 |             return True
263 |         return False
264 | 
265 |     def seen_asked(self):
266 |         if self._seen_ask:
267 |             ret = self._seen_ask
268 |             self._seen_ask = False
269 |             return ret
270 |         return False
271 | 
272 |     def allow_multi(self, allow):
273 |         self._allow_multi = allow
274 | 
--------------------------------------------------------------------------------
/justredis/nonsync/connectionpool.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | 
3 | try:
4 |     from contextlib import asynccontextmanager
5 | except:
6 |     from async_generator import asynccontextmanager
7 | 
8 | 
9 | from .connection import Connection
10 | from ..errors import ConnectionPoolError
11 | from ..decoder import Error
12 | from .environment import get_environment
13 | 
14 | 
15 | # TODO (misc) can we relax the _lock ?
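# An illustrative sketch (not part of the library) of the pool contract
# implemented below: take() hands out a connection (creating one lazily,
# bounded by max_connections), and release() must always be called to return
# it. The address and max_connections values are assumptions for the example.
#
#     async def example():
#         pool = ConnectionPool(address=("localhost", 6379), max_connections=10)
#         conn = await pool.take()
#         try:
#             await conn(b"SET", b"hello", b"world")
#         finally:
#             await pool.release(conn)
#         await pool.aclose()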
16 | 
17 | 
18 | class ConnectionPool:
19 |     def __init__(self, max_connections=None, wait_timeout=None, **kwargs):
20 |         self._max_connections = max_connections
21 |         self._wait_timeout = wait_timeout
22 |         self._connection_settings = kwargs
23 | 
24 |         self._lock = get_environment(**kwargs).lock()
25 |         self._shield = get_environment(**kwargs).shield
26 |         self._limit = get_environment(**kwargs).semaphore(max_connections) if max_connections else None
27 |         self._connections_available = deque()
28 |         self._connections_in_use = set()
29 |         self._closed = False
30 | 
31 |     async def aclose(self):
32 |         async with self._lock:
33 |             if not self._closed:
34 |                 # We do this first, so if another thread calls release it won't get back to the pool
35 |                 for connection in self._connections_available:
36 |                     await connection.aclose()
37 |                 for connection in self._connections_in_use:
38 |                     await connection.aclose()
39 |                 self._connections_available.clear()
40 |                 self._connections_in_use.clear()
41 |                 self._limit = get_environment(**self._connection_settings).semaphore(self._max_connections) if self._max_connections else None
42 |                 self._closed = True
43 | 
44 |     async def take(self):
45 |         if self._closed:
46 |             raise ConnectionPoolError("Pool already closed")
47 |         # TODO (correctness) cluster depends on this failing if closed !
48 |         try:
49 |             while True:
50 |                 conn = self._connections_available.popleft()
51 |                 if not conn.closed():
52 |                     break
53 |                 if self._limit is not None:
54 |                     await self._limit.release()
55 |         except IndexError:
56 |             if self._limit is not None and not await self._limit.acquire(self._wait_timeout):
57 |                 raise ConnectionPoolError("Could not acquire a connection from the pool")
58 |             try:
59 |                 conn = await Connection.create(**self._connection_settings)
60 |             except Exception:
61 |                 if self._limit is not None:
62 |                     await self._limit.release()
63 |                 raise
64 |         self._connections_in_use.add(conn)
65 |         return conn
66 | 
67 |     async def release(self, conn):
68 |         async with self._shield():
69 |             async with self._lock:
70 |                 try:
71 |                     self._connections_in_use.remove(conn)
72 |                 # TODO (correctness) should we release the self._limit here as well ? (or just make close forever)
73 |                 # If this fails, it's a connection from a previous cycle, don't reuse it
74 |                 except KeyError:
75 |                     await conn.aclose()
76 |                     return
77 |                 if not conn.closed():
78 |                     self._connections_available.append(conn)
79 |                 elif self._limit is not None:
80 |                     await self._limit.release()
81 | 
82 |     async def __call__(self, *cmd, **kwargs):
83 |         if not cmd:
84 |             raise ValueError("No command provided")
85 |         conn = await self.take()
86 |         try:
87 |             return await conn(*cmd, **kwargs)
88 |         finally:
89 |             await self.release(conn)
90 | 
91 |     @asynccontextmanager
92 |     async def connection(self, **kwargs):
93 |         conn = await self.take()
94 |         try:
95 |             conn.allow_multi(True)
96 |             yield conn
97 |         finally:
98 |             # We need to clean up the connection back to a normal state.
99 |             try:
100 |                 await conn._command(b"DISCARD")
101 |             except Exception:
102 |                 pass
103 |             conn.allow_multi(False)
104 |             await self.release(conn)
105 | 
106 |     async def endpoints(self):
107 |         conn = await self.take()
108 |         try:
109 |             return [(conn.peername(), {"type": "regular"})]
110 |         finally:
111 |             await self.release(conn)
112 | 
--------------------------------------------------------------------------------
/justredis/nonsync/environment.py:
--------------------------------------------------------------------------------
1 | from .environments.anyio import AnyIOEnvironment
2 | 
3 | 
4 | def get_environment(environment=AnyIOEnvironment, **kargs):
5 |     if environment == "anyio":
6 |         environment = AnyIOEnvironment
7 |     return environment
8 | 
--------------------------------------------------------------------------------
/justredis/nonsync/environments/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tzickel/justredis/12c63bba8a83f1c7934d34ea081ada6efbc563ad/justredis/nonsync/environments/__init__.py
--------------------------------------------------------------------------------
/justredis/nonsync/environments/anyio.py:
--------------------------------------------------------------------------------
1 | import anyio
2 | 
3 | try:
4 |     anyio.create_tcp_listener
5 | except AttributeError:
6 |     raise AttributeError("You are using an old and incompatible AnyIO version, the minimum required version is AnyIO 2.0.0 .")
7 | import socket
8 | import sys
9 | import ssl
10 | 
11 | 
12 | platform = ""
13 | if sys.platform.startswith("linux"):
14 |     platform = "linux"
15 | elif sys.platform.startswith("darwin"):
16 |     platform = "darwin"
17 | elif sys.platform.startswith("win"):
18 |     platform = "windows"
19 | 
20 | 
21 | async def tcpsocket(address=None, connect_timeout=None, tcp_keepalive=None, tcp_nodelay=None, **kwargs):
22 |     if address is None:
23 |         address = ("localhost", 6379)
24 |     async with anyio.fail_after(connect_timeout):
25 |         sock = await anyio.connect_tcp(address[0], address[1])
26 |     if tcp_nodelay is not None:
27 |         if tcp_nodelay:
28 |             sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
29 |         else:
30 |             sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 0)
31 |     if tcp_keepalive:
32 |         sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
33 |         if platform == "linux":
34 |             sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, tcp_keepalive)
35 |             sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, tcp_keepalive // 3)
36 |             sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 3)
37 |         elif platform == "darwin":
38 |             sock.setsockopt(socket.IPPROTO_TCP, 0x10, tcp_keepalive // 3)
39 |         elif platform == "windows":
40 |             sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, tcp_keepalive * 1000, tcp_keepalive // 3 * 1000))
41 |     return sock
42 | 
43 | 
44 | async def unixsocket(address=None, connect_timeout=None, **kwargs):
45 |     if address is None:
46 |         address = "/tmp/redis.sock"
47 |     async with anyio.fail_after(connect_timeout):
48 |         sock = await anyio.connect_unix(address)
49 |     return sock
50 | 
51 | 
52 | # TODO (misc) should we enable server hostname enforcement ? give it as an option ? what about cluster ?
53 | async def sslsocket(address=None, ssl_context=None, **kwargs):
54 |     if address is None:
55 |         address = ("localhost", 6379)
56 |     if ssl_context is None:
57 |         ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
58 |         cafile = kwargs.get("ssl_cafile")
59 |         if cafile:
60 |             ssl_context.load_verify_locations(cafile)
61 |         certfile = kwargs.get("ssl_certfile")
62 |         keyfile = kwargs.get("ssl_keyfile")
63 |         if certfile:
64 |             ssl_context.load_cert_chain(certfile, keyfile)
65 |     return await tcpsocket(ssl_context=ssl_context, tls=True, **kwargs)
66 | 
67 | 
68 | class SocketWrapper:
69 |     @classmethod
70 |     async def create(cls, socket_factory, buffersize=2 ** 16, socket_timeout=None, **kwargs):
71 |         ret = cls()
72 |         await ret._init(socket_factory, buffersize, socket_timeout, **kwargs)
73 |         return ret
74 | 
75 |     async def _init(self, socket_factory, buffersize=2 ** 16, socket_timeout=None, **kwargs):
76 |         self._buffersize = buffersize
77 |         self._socket_timeout = socket_timeout
78 |         self._socket = await socket_factory(**kwargs)
79 | 
80 |     async def aclose(self, force=False):
81 |         if force:
82 |             await anyio.aclose_forcefully(self._socket)
83 |         else:
84 |             await self._socket.aclose()
85 | 
86 |     async def send(self, data):
87 |         if self._socket_timeout:
88 |             async with anyio.fail_after(self._socket_timeout):
89 |                 await self._socket.send(data)
90 |         else:
91 |             await self._socket.send(data)
92 | 
93 |     # If you override this, make sure to return an empty bytes for EOF and a None for timeout !
94 |     async def recv(self, timeout=False):
95 |         if timeout is False:
96 |             timeout = self._socket_timeout
97 |         if timeout:
98 |             try:
99 |                 async with anyio.fail_after(timeout):
100 |                     return await self._socket.receive(self._buffersize)
101 |             except TimeoutError:
102 |                 return None
103 |         else:
104 |             return await self._socket.receive(self._buffersize)
105 | 
106 |     def peername(self):
107 |         peername = self._socket.extra(anyio.abc.SocketAttribute.remote_address)
108 |         if isinstance(peername, (list, tuple)):
109 |             peername = peername[:2]
110 |         return peername
111 | 
112 | 
113 | class OurSemaphore:
114 |     def __init__(self, value):
115 |         self._semaphore = anyio.create_capacity_limiter(value)
116 | 
117 |     async def release(self):
118 |         await self._semaphore.release()
119 | 
120 |     async def acquire(self, timeout=None):
121 |         if timeout:
122 |             async with anyio.fail_after(timeout):
123 |                 await self._semaphore.acquire()
124 |         else:
125 |             await self._semaphore.acquire()
126 | 
127 | 
128 | class AnyIOEnvironment:
129 |     @staticmethod
130 |     async def socket(socket_type="tcp", **kwargs):
131 |         if socket_type == "tcp":
132 |             socket_type = tcpsocket
133 |         elif socket_type == "unix":
134 |             socket_type = unixsocket
135 |         elif socket_type == "ssl":
136 |             socket_type = sslsocket
137 |         else:
138 |             raise NotImplementedError("Unknown socket type: %s" % socket_type)
139 |         return await SocketWrapper.create(socket_type, **kwargs)
140 | 
141 |     @staticmethod
142 |     def semaphore(limit):
143 |         return OurSemaphore(limit)
144 | 
145 |     @staticmethod
146 |     def lock():
147 |         return anyio.create_lock()
148 | 
149 |     # async only?
150 |     @staticmethod
151 |     def shield():
152 |         return anyio.open_cancel_scope(shield=True)
153 | 
154 |     # async only?
155 |     @staticmethod
156 |     def cancelledclass():
157 |         return anyio.get_cancelled_exc_class()
158 | 
--------------------------------------------------------------------------------
/justredis/nonsync/redis.py:
--------------------------------------------------------------------------------
1 | from .connectionpool import ConnectionPool
2 | from .cluster import ClusterConnectionPool
3 | from ..decoder import Error
4 | from ..utils import parse_url, merge_dicts
5 | 
6 | 
7 | # TODO (misc) document all the kwargs everywhere
8 | # TODO (api) internal remove from connectionpool the __enter__/__exit__ and use take(**kwargs)/release
9 | 
10 | 
11 | # We do this separation to allow changing per-command and per-connection settings easily
12 | class ModifiedRedis:
13 |     def __init__(self, connection_pool, custom_command_class=None, **kwargs):
14 |         self._connection_pool = connection_pool
15 |         self._custom_command_class = custom_command_class
16 |         self._custom_command = custom_command_class(self) if self._custom_command_class else None
17 |         self._settings = kwargs
18 | 
19 |     async def aclose(self):
20 |         self._connection_pool = self._settings = None
21 | 
22 |     async def __aenter__(self):
23 |         return self
24 | 
25 |     async def __aexit__(self, *args):
26 |         await self.aclose()
27 | 
28 |     async def __call__(self, *cmd, **kwargs):
29 |         settings = merge_dicts(self._settings, kwargs)
30 |         if settings is None:
31 |             return await self._connection_pool(*cmd)
32 |         else:
33 |             return await self._connection_pool(*cmd, **settings)
34 | 
35 |     async def connection(self, *args, push=False, **kwargs):
36 |         if args:
37 |             raise ValueError("Please specify the connection arguments as named arguments (i.e. push=..., key=...)")
38 |         wrapper = PushConnection if push else Connection
39 |         # TODO (async) do I need await on the connection itself ?
40 |         return await wrapper.create(self._connection_pool.connection(**kwargs), **self._settings)
41 | 
42 |     async def endpoints(self):
43 |         return await self._connection_pool.endpoints()
44 | 
45 |     def modify(self, **kwargs):
46 |         settings = self._settings.copy()
47 |         settings.update(kwargs)
48 |         return ModifiedRedis(self._connection_pool, custom_command_class=self._custom_command_class, **settings)
49 | 
50 |     def __getattr__(self, attribute):
51 |         if not self._custom_command:
52 |             raise AttributeError("No such attribute: %s" % attribute)
53 |         return getattr(self._custom_command, attribute)
54 | 
55 | 
56 | # TODO (api) should we implement a callback for when slots have changed ?
57 | class Redis(ModifiedRedis):
58 |     @classmethod
59 |     def from_url(cls, url, **kwargs):
60 |         res = parse_url(url)
61 |         res.update(kwargs)
62 |         return cls(**res)
63 | 
64 |     def __init__(self, pool_factory=ClusterConnectionPool, custom_command_class=None, **kwargs):
65 |         """
66 |         Currently documented in README.md
67 |         """
68 |         self._connection_pool = None
69 |         if pool_factory == "pool":
70 |             pool_factory = ConnectionPool
71 |         elif pool_factory in ("auto", "cluster"):
72 |             pool_factory = ClusterConnectionPool
73 |         if not hasattr(pool_factory, "__call__"):
74 |             raise AttributeError("A valid pool_factory is required; if you want to set an address, use .from_url() or address=(host, port)")
75 |         super(Redis, self).__init__(pool_factory(**kwargs), custom_command_class=custom_command_class)
76 | 
77 |     async def aclose(self):
78 |         if self._connection_pool:
79 |             await self._connection_pool.aclose()
80 |         self._connection_pool = None
81 | 
82 | 
83 | # TODO (api) add a modified_class here as well.
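# An illustrative sketch (not part of the library) of the high-level async API
# defined above: Redis.from_url() for construction, calling the instance to run
# a command, and connection() to pin a single connection (required for
# MULTI/EXEC). The URL and key names are assumptions for the example.
#
#     async def example():
#         async with Redis.from_url("redis://localhost:6379") as r:
#             await r(b"SET", b"hello", b"world")
#             async with await r.connection(key=b"counter") as conn:
#                 await conn(b"MULTI")
#                 await conn(b"INCR", b"counter")
#                 await conn(b"EXEC")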
84 | class ModifiedConnection:
85 |     def __init__(self, connection, **kwargs):
86 |         self._connection = connection
87 |         self._settings = kwargs
88 | 
89 |     async def aclose(self):
90 |         self._connection = self._settings = None
91 | 
92 |     async def __aenter__(self):
93 |         return self
94 | 
95 |     async def __aexit__(self, *args):
96 |         await self.aclose()
97 | 
98 |     async def __call__(self, *cmd, **kwargs):
99 |         settings = merge_dicts(self._settings, kwargs)
100 |         if settings is None:
101 |             return await self._connection(*cmd)
102 |         else:
103 |             return await self._connection(*cmd, **settings)
104 | 
105 |     def modify(self, **kwargs):
106 |         settings = self._settings.copy()
107 |         settings.update(kwargs)
108 |         return ModifiedConnection(self._connection, **settings)
109 | 
110 | 
111 | class Connection(ModifiedConnection):
112 |     @classmethod
113 |     async def create(cls, connection, **kwargs):
114 |         conn = await connection.__aenter__()
115 |         ret = cls(conn, **kwargs)
116 |         ret._connection_context = connection
117 |         return ret
118 | 
119 |     def __init__(self, connection, **kwargs):
120 |         super(Connection, self).__init__(connection, **kwargs)
121 | 
122 |     async def aclose(self):
123 |         if self._connection_context:
124 |             # TODO (correctness) is this correct?
125 |             await self._connection_context.__aexit__(None, None, None)
126 |         self._connection = None
127 |         self._connection_context = None
128 |         self._settings = None
129 | 
130 | 
131 | class PushConnection(Connection):
132 |     async def aclose(self):
133 |         if self._connection:
134 |             # We close the connection here, since it's both hard to reset the state of the connection, and this is usually not done / done at a low frequency.
135 |             await self._connection.aclose()
136 |         await super(PushConnection, self).aclose()
137 | 
138 |     async def __call__(self, *cmd, **kwargs):
139 |         settings = merge_dicts(self._settings, kwargs)
140 |         if settings is None:
141 |             return await self._connection.push_command(*cmd)
142 |         else:
143 |             return await self._connection.push_command(*cmd, **settings)
144 | 
145 |     async def next_message(self, *args, timeout=None, **kwargs):
146 |         if args:
147 |             raise ValueError("Please specify the next_message arguments as named arguments (i.e. timeout=...)")
148 |         settings = merge_dicts(self._settings, kwargs)
149 |         if settings is None:
150 |             return await self._connection.pushed_message(timeout=timeout)
151 |         else:
152 |             return await self._connection.pushed_message(timeout=timeout, **settings)
153 | 
154 |     def __aiter__(self):
155 |         return self
156 | 
157 |     # Async iteration uses __aiter__/__anext__ (the sync version uses __iter__/__next__)
158 |     async def __anext__(self):
159 |         return await self._connection.pushed_message()
160 | 
--------------------------------------------------------------------------------
/justredis/sync/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tzickel/justredis/12c63bba8a83f1c7934d34ea081ada6efbc563ad/justredis/sync/__init__.py
--------------------------------------------------------------------------------
/justredis/sync/cluster.py:
--------------------------------------------------------------------------------
1 | from binascii import crc_hqx
2 | from contextlib import contextmanager
3 | from random import choice
4 | 
5 | from .environment import get_environment
6 | from .connectionpool import ConnectionPool
7 | from ..decoder import Error, Result
8 | from ..utils import is_multiple_commands
9 | from ..encoder import parse_encoding
10 | from ..errors import CommunicationError
11 | 
12 | 
13 | def calc_hashslot(key):
14 |     s = key.find(b"{")
15 |     if s != -1:
16 |         e = key.find(b"}")
17 |         if e > s + 1:
18 |             key = key[s + 1 : e]
19 |     return crc_hqx(key, 0) % 16384
20 | 
21 | 
22 | # TODO (correctness) I think I covered the multithreading sensitive parts, make sure
23 | # TODO (misc) should I lazily check if there is a cluster ? (i.e. upgrade from a default connectionpool first)
24 | # TODO (correctness) think about how to ASK redirect in connection taking and pipelining
25 | # TODO (correctness) make sure I don't have an issue where if there is a connection pool limit, I can get into a deadlock here
26 | # TODO (misc) future optimization, if we can't take from last_connection because of connection pool limit, choose another random one.
27 | # TODO (correctness) disable encoding on all of conn commands
28 | # TODO (correctness) when invalidating last_connection, run slots update ?
29 | # TODO (misc) should _update_slots be called on each I/O error always ?
30 | 
31 | 
32 | class ClusterConnectionPool:
33 |     def __init__(self, addresses=None, **kwargs):
34 |         address = kwargs.pop("address", None)
35 |         if addresses is None:
36 |             if address:
37 |                 addresses = (address,)
38 |             else:
39 |                 addresses = (("localhost", 6379),)
40 |         elif address is not None:
41 |             raise ValueError("Do not provide both addresses and address")
42 |         self._initial_addresses = addresses
43 |         self._settings = kwargs
44 |         self._connections = {}
45 |         self._lock = get_environment(**kwargs).lock()
46 |         self._last_connection = None
47 |         self._last_connection_peername = None
48 |         self._slots = []
49 |         # None = unknown, False = Nope, True = Yep
50 |         self._clustered = None
51 |         # Command info results
52 |         self._command_cache = {}
53 |         self._closed = False
54 | 
55 |     def __del__(self):
56 |         self.close()
57 | 
58 |     def close(self):
59 |         with self._lock:
60 |             if not self._closed:
61 |                 for pool in self._connections.values():
62 |                     pool.close()
63 |                 self._connections.clear()
64 |                 self._last_connection = None
65 |                 self._closed = True
66 | 
67 |     def _update_slots(self):
68 |         # TODO (misc) or if there is a hint from MOVED, recheck if clustered ?
69 |         if self._clustered == False:
70 |             return
71 |         conn = self.take()
72 |         try:
73 |             slots = conn(b"CLUSTER", b"SLOTS")
74 |             self._clustered = True
75 |         except Error:
76 |             slots = []
77 |             self._clustered = False
78 |         except Exception:
79 |             # This is done to invalidate a potentially bad server, and pick up another one randomly next try
80 |             self._last_connection = self._last_connection_peername = None
81 |             raise
82 |         finally:
83 |             self.release(conn)
84 |         # TODO (misc) what to do about holes in the hashslots in the response ?
85 |         if isinstance(slots, Result):
86 |             slots = slots.data
87 |         slots.sort(key=lambda x: x[0])
88 |         if slots and isinstance(slots[0][2][0], bytes):
89 |             slots = [(x[1], (x[2][0].decode(), x[2][1])) for x in slots]
90 |         else:
91 |             slots = [(x[1], (x[2][0], x[2][1])) for x in slots]
92 |         # We weren't in a cluster before, and we aren't now
93 |         if self._clustered == False and not self._slots:
94 |             return
95 |         # Remove connections which are not a part of the cluster anymore
96 |         with self._lock:
97 |             previous_connections = set(self._connections.keys())
98 |             new_connections = set([x[1] for x in slots])
99 |             connections_to_remove = previous_connections - new_connections
100 |             for address in connections_to_remove:
101 |                 self._connections[address].close()
102 |                 del self._connections[address]
103 |             self._slots = slots
104 | 
105 |             if connections_to_remove:
106 |                 # TODO (misc) an optimization: only do this if it's not in new_connections
107 |                 self._last_connection = self._last_connection_peername = None
108 | 
109 |     def _connection_by_hashslot(self, hashslot):
110 |         if not self._slots:
111 |             self._update_slots()
112 |         if self._clustered == False:
113 |             return self.take()
114 |         if not self._slots:
115 |             # TODO (misc) Is this the correct exception type ?
116 |             raise ValueError("Could not find any slots in the redis cluster")
117 |         for slot in self._slots:
118 |             if hashslot <= slot[0]:
119 |                 break
120 |         address = slot[1]
121 |         return self.take(address)
122 | 
123 |     def _get_info_for_command(self, *cmd):
124 |         # commands are ascii, yes ? some commands can be larger than cmd[0] for index ? meh, let's be optimistic for now
125 |         command = cmd[0]
126 |         encode = getattr(command, "encode", None)
127 |         if encode:
128 |             command = encode("ascii")
129 |         command = command.upper()
130 |         info = self._command_cache.get(command)
131 |         if info:
132 |             return info
133 |         conn = self.take()
134 |         try:
135 |             command_info = conn(b"COMMAND", b"INFO", command)
136 |         except Exception:
137 |             # This is done to invalidate a potentially bad server, and pick up another one randomly next try
138 |             self._last_connection = self._last_connection_peername = None
139 |             raise
140 |         finally:
141 |             self.release(conn)
142 |         if isinstance(command_info, Result):
143 |             command_info = command_info.data
144 |         command_info = command_info[0]
145 |         # TODO (misc) on null, cache it too ?
146 |         if command_info:
147 |             self._command_cache[command] = command_info
148 |         return command_info
149 | 
150 |     def _address_pool(self, address):
151 |         pool = self._connections.get(address)
152 |         if pool is None:
153 |             # with self._lock:
154 |             pool = self._connections.get(address)
155 |             if pool is None:
156 |                 pool = ConnectionPool(address=address, **self._settings)
157 |                 self._connections[address] = pool
158 |         return pool
159 | 
160 |     # TODO (misc) make sure the address got here from _slots (or risk stale data)
161 |     def take(self, address=None):
162 |         # if address:
163 |         #     return self._address_pool(address).take()
164 |         # TODO (misc) is this best ?
165 |         with self._lock:
166 |             if address:
167 |                 return self._address_pool(address).take()
168 |             if self._last_connection:
169 |                 try:
170 |                     # TODO (misc) maybe do a health check here ? if there is an exception it will be invalidated anyhow for the next time...
171 |                     return self._last_connection.take()
172 |                 except Exception:
173 |                     self._last_connection = None
174 |             endpoints = [x[1] for x in self._slots.copy()]
175 |             # TODO (correctness) maybe after (repeated?) I/O failure always go back to the initial addresses ? or just remove an entry from the connections when it's invalid, till it's empty ?
176 |             if not endpoints:
177 |                 endpoints = self._initial_addresses
178 |             # TODO (correctness) should we pick randomly, or go through the list one by one on each failure ?
179 |             address = choice(endpoints)
180 |             pool = self._address_pool(address)
181 |             self._last_connection = pool
182 |             # TODO (correctness) on error, try the next one immediately
183 |             conn = self._last_connection.take()
184 |             self._last_connection_peername = conn.peername()
185 |             return conn
186 | 
187 |     def take_by_key(self, key, **kwargs):
188 |         if not isinstance(key, (bytes, bytearray)):
189 |             encode = kwargs.get("encoder", self._settings.get("encoder"))
190 |             key = parse_encoding(encode)(key)
191 |         hashslot = calc_hashslot(key)
192 |         return self._connection_by_hashslot(hashslot)
193 | 
194 |     def take_by_cmd(self, *cmd, **kwargs):
195 |         if is_multiple_commands(*cmd):
196 |             found = False
197 |             # Some commands have no key information (like MULTI) so scan for one which does
198 |             for command in cmd:
199 |                 info = self._get_info_for_command(*command)
200 |                 if info is None:
201 |                     continue
202 |                 index = info[3]
203 |                 if index != 0:
204 |                     found = True
205 |                     cmd = command
206 |                     break
207 |             if not found:
208 |                 cmd = cmd[0]
209 |         else:
210 |             info = self._get_info_for_command(*cmd)
211 |         # This should happen only if the command doesn't exist
212 |         if info is None:
213 |             return self.take()
214 |         # TODO (misc) maybe see if the command is movablekeys, and only do this then, for optimization
215 |         index = info[3]
216 |         if index == 0:
217 |             conn = self.take()
218 |             try:
219 |                 command = [b"COMMAND", b"GETKEYS"]
220 |                 command.extend(cmd)
221 |                 # TODO (misc) what do we want to do if an exception happened here ?
222 |                 command_info = conn(*command, attributes=False, decoder=None)
223 |                 key = command_info[0]
224 |             # This happens if the command has no key info, so any connection is good
225 |             except Error:
226 |                 return self.take()
227 |             finally:
228 |                 self.release(conn)
229 |         else:
230 |             # The command does not contain the key; usually this will result in a usage error
231 |             if len(cmd) - 1 < index:
232 |                 return self.take()
233 |             else:
234 |                 key = cmd[index]
235 |         return self.take_by_key(key, **kwargs)
236 | 
237 |     def release(self, conn):
238 |         if self._clustered:
239 |             # TODO (correctness) is the peername always 100% the same as the slot address ? to be on the safe side we can store both @ metadata
240 |             address = conn.peername()
241 |             pool = self._connections.get(address)
242 |         else:
243 |             # TODO (correctness) risky, if last_connection somehow changed (multiple fallback addresses?), we might be returning to the wrong one, does it matter then ?
244 |             pool = self._last_connection
245 |         # The connection might have been discharged
246 |         if pool is None:
247 |             conn.close()
248 |             return
249 |         pool.release(conn)
250 | 
251 |     def __call__(self, *cmd, endpoint=False, **kwargs):
252 |         if not cmd:
253 |             raise ValueError("No command provided")
254 |         if endpoint == "masters":
255 |             return self._on_all(*cmd)
256 |         if self._clustered == False:
257 |             conn = self.take()
258 |         elif endpoint:
259 |             conn = self.take(endpoint)
260 |         else:
261 |             conn = self.take_by_cmd(*cmd)
262 |         try:
263 |             return conn(*cmd, **kwargs)
264 |         except CommunicationError:
265 |             self._update_slots()
266 |             raise
267 |         finally:
268 |             seen_moved = conn.seen_moved()
269 |             seen_asked = conn.seen_asked()
270 |             self.release(conn)
271 |             if seen_moved:
272 |                 self._update_slots()
273 |                 # If the user specified a specific endpoint, we won't force the issue.
274 |                 # Also if the cmd is multiple commands, we won't know which one failed and which didn't, so we don't retry either.
275 |                 if endpoint == False and not is_multiple_commands(*cmd):
276 |                     return self(*cmd, **kwargs)
277 |             elif seen_asked:
278 |                 if endpoint == False and not is_multiple_commands(*cmd):
279 |                     return self(*cmd, **kwargs, endpoint=seen_asked, asking=True)
280 | 
281 |     @contextmanager
282 |     def connection(self, key=None, endpoint=None, **kwargs):
283 |         if key and endpoint:
284 |             raise ValueError("Cannot specify both key and endpoint when taking a connection")
285 |         if endpoint:
286 |             conn = self.take(endpoint)
287 |         elif self._clustered == False or key is None:
288 |             conn = self.take()
289 |         else:
290 |             conn = self.take_by_key(key, **kwargs)
291 |         try:
292 |             conn.allow_multi(True)
293 |             yield conn
294 |         except CommunicationError:
295 |             self._update_slots()
296 |             raise
297 |         finally:
298 |             # We need to clean up the connection back to a normal state.
299 |             try:
300 |                 if not conn.closed():
301 |                     conn._command(b"DISCARD")
302 |             except Error:
303 |                 pass
304 |             finally:
305 |                 # We need to handle the option where there was a moved error, to not have a recursion of a connection always trying the wrong server
306 |                 seen_moved = conn.seen_moved()
307 |                 conn.allow_multi(False)
308 |                 self.release(conn)
309 |                 if seen_moved:
310 |                     self._update_slots()
311 | 
312 |     def _on_all(self, *cmd, filter="master", **kwargs):
313 |         if self._clustered is None:
314 |             self._update_slots()
315 |         if self._clustered == False:
316 |             # This will always be filled by _update_slots (at least)
317 |             return {self._last_connection_peername: self(*cmd, **kwargs)}
318 |         # TODO (api) on an error here, raise an exception ?
319 |         res = {}
320 |         for address in self.endpoints():
321 |             if address[1]["type"] != filter:
322 |                 continue
323 |             address = address[0]
324 |             try:
325 |                 res[address] = self(*cmd, endpoint=address, **kwargs)
326 |             except Exception as e:
327 |                 res[address] = e
328 |         return res
329 | 
330 |     def endpoints(self):
331 |         if self._clustered is None:
332 |             self._update_slots()
333 |         if self._clustered:
334 |             return [(x[1], {"type": "master"}) for x in self._slots.copy()]
335 |         else:
336 |             # This will always be filled by _update_slots (at least)
337 |             return [(self._last_connection_peername, {"type": "regular"})]
338 | 
--------------------------------------------------------------------------------
/justredis/sync/connection.py:
--------------------------------------------------------------------------------
1 | from .environment import get_environment
2 | from ..decoder import RedisRespDecoder, need_more_data, Error
3 | from ..encoder import RedisRespEncoder
4 | from ..errors import CommunicationError, PipelinedExceptions
5 | from ..utils import get_command_name, is_multiple_commands
6 | 
7 | 
8 | # TODO (correctness) watch for manual SELECT and set_database !
9 | 
10 | 
11 | not_allowed_push_commands = set([b"MONITOR", b"SUBSCRIBE", b"PSUBSCRIBE", b"UNSUBSCRIBE", b"PUNSUBSCRIBE"])
12 | 
13 | 
14 | class TimeoutError(Exception):
15 |     pass
16 | 
17 | 
18 | timeout_error = TimeoutError()
19 | 
20 | 
21 | class Connection:
22 |     @classmethod
23 |     def create(cls, username=None, password=None, client_name=None, resp_version=2, socket_factory="tcp", connect_retry=2, database=0, **kwargs):
24 |         ret = cls()
25 |         ret._init(username, password, client_name, resp_version, socket_factory, connect_retry, database, **kwargs)
26 |         return ret
27 | 
28 |     def __init__(self):
29 |         self._socket = None
30 | 
31 |     # TODO (api) client_name with connection pool (?)
32 |     # TODO (documentation) the username/password/client_name need the decoding of whatever **kwargs is passed
33 |     def _init(self, username=None, password=None, client_name=None, resp_version=2, socket_factory="tcp", connect_retry=2, database=0, **kwargs):
34 |         resp_version = int(resp_version)
35 |         connect_retry = int(connect_retry)
36 |         database = int(database)
37 | 
38 |         if resp_version not in (-1, 2, 3):
39 |             raise ValueError("Unsupported RESP protocol version %s" % resp_version)
40 | 
41 |         self._settings = kwargs
42 | 
43 |         environment = get_environment(**kwargs)
44 |         connect_retry += 1
45 |         while connect_retry:
46 |             try:
47 |                 self._socket = environment.socket(socket_factory, **kwargs)
48 |                 break
49 |             except Exception as e:
50 |                 connect_retry -= 1
51 |                 if not connect_retry:
52 |                     raise CommunicationError() from e
53 |         self._encoder = RedisRespEncoder(**kwargs)
54 |         self._decoder = RedisRespDecoder(**kwargs)
55 |         self._seen_eof = False
56 |         self._peername = self._socket.peername()
57 |         self._seen_moved = False
58 |         self._seen_ask = False
59 |         self._allow_multi = False
60 |         self._default_database = self._last_database = database
61 | 
62 |         connected = False
63 |         # Try to negotiate RESP3 first if RESP2 is not forced
64 |         if resp_version != 2:
65 |             args = [b"HELLO", b"3"]
66 |             if password is not None:
67 |                 if username is not None:
68 |                     args.extend((b"AUTH", username, password))
69 |                 else:
70 |                     args.extend((b"AUTH", b"default", password))
71 |             if client_name:
72 |                 args.extend((b"SETNAME", client_name))
73 |             try:
74 |                 # TODO (misc) do something with the result ?
75 |                 self._command(*args)
76 |                 connected = True
77 |             except Error as e:
78 |                 # This is to separate a login error from the server not supporting RESP3
79 |                 if e.args[0].startswith("ERR "):
80 |                     if resp_version == 3:
81 |                         # TODO (misc) this won't have a __cause__, is that ok ? what exception to throw here ?
82 |                         raise Exception("Server does not support RESP3 protocol")
83 |                 else:
84 |                     raise
85 |         if not connected:
86 |             if password:
87 |                 if username:
88 |                     self._command(b"AUTH", username, password)
89 |                 else:
90 |                     self._command(b"AUTH", password)
91 |             if client_name:
92 |                 self._command(b"CLIENT", b"SETNAME", client_name)
93 |             if database != 0:
94 |                 self._command(b"SELECT", database)
95 | 
96 |     def __del__(self):
97 |         self.close()
98 | 
99 |     def close(self):
100 |         if self._socket:
101 |             try:
102 |                 self._socket.close()
103 |             except Exception:
104 |                 pass
105 |             self._socket = None
106 |         self._encoder = None
107 |         self._decoder = None
108 | 
109 |     # TODO (misc) better check ? (maybe it's closed, but the socket doesn't know it yet..., will be known the next time though)
110 |     def closed(self):
111 |         return self._socket is None
112 | 
113 |     def peername(self):
114 |         return self._peername
115 | 
116 |     def _send(self, *cmd):
117 |         try:
118 |             if is_multiple_commands(*cmd):
119 |                 self._encoder.encode_multiple(*cmd)
120 |             else:
121 |                 self._encoder.encode(*cmd)
122 |             while True:
123 |                 data = self._encoder.extract()
124 |                 if data is None:
125 |                     break
126 |                 self._socket.send(data)
127 |         except ValueError as e:
128 |             raise
129 |         except Exception as e:
130 |             self.close()
131 |             raise CommunicationError("I/O error while trying to send a command") from e
132 |         except BaseException:
133 |             self.close()
134 |             raise
135 | 
136 |     # TODO (misc) should a decoding error be considered a CommunicationError ?
137 |     def _recv(self, timeout=False):
138 |         try:
139 |             while True:
140 |                 res = self._decoder.extract()
141 |                 if res == need_more_data:
142 |                     if self._seen_eof:
143 |                         self.close()
144 |                         raise EOFError("Connection reached EOF")
145 |                     else:
146 |                         data = self._socket.recv(timeout)
147 |                         if data == b"":
148 |                             self._seen_eof = True
149 |                         elif data is None:
150 |                             return timeout_error
151 |                         else:
152 |                             # TODO This check is because another context can close us while we were reading (we can instead simply not remove self._decoder on close)
153 |                             if not self._decoder:
154 |                                 raise Exception("Connection already closed")
155 |                             self._decoder.feed(data)
156 |                     continue
157 |                 return res
158 |         except Exception as e:
159 |             self.close()
160 |             raise CommunicationError("Error while trying to read a reply") from e
161 |         except BaseException:
162 |             self.close()
163 |             raise
164 | 
165 |     def pushed_message(self, timeout=False, decoder=False, attributes=None):
166 |         orig_decoder = None
167 |         if decoder != False or attributes is not None:
168 |             orig_decoder = self._decoder
169 |             kwargs = self._settings.copy()
170 |             if decoder != False:
171 |                 kwargs["decoder"] = decoder
172 |             if attributes is not None:
173 |                 kwargs["attributes"] = attributes
174 |             self._decoder = RedisRespDecoder(**kwargs)
175 |         try:
176 |             res = self._recv(timeout)
177 |             if res == timeout_error:
178 |                 return None
179 |             return res
180 |         finally:
181 |             if orig_decoder is not None:
182 |                 self._decoder = orig_decoder
183 | 
184 |     def push_command(self, *cmd):
185 |         self._send(*cmd)
186 | 
187 |     def set_database(self, database):
188 |         if database is None:
189 |             if self._default_database != self._last_database:
190 |                 self._command(b"SELECT", self._default_database)
191 |                 self._last_database = self._default_database
192 |         else:
193 |             if database != self._last_database:
194 |                 self._command(b"SELECT", database)
195 |                 self._last_database = database
196 | 
197 |     def __call__(self, *cmd, decoder=False, attributes=None, database=None, asking=False):
198 |         if not cmd:
199 |             raise ValueError("No command provided")
200 |         orig_decoder = None
201 |         if decoder != False or attributes is not None:
202 |             orig_decoder = self._decoder
203 |             kwargs = self._settings.copy()
204 |             if decoder != False:
205 |                 kwargs["decoder"] = decoder
206 |             if attributes is not None:
207 |                 kwargs["attributes"] = attributes
208 |             self._decoder = RedisRespDecoder(**kwargs)
209 |         try:
210 |             self.set_database(database)
211 |             if is_multiple_commands(*cmd):
212 |                 return self._commands(*cmd)
213 |             else:
214 |                 if asking:
215 |                     self._command(b"ASKING")
216 |                 return self._command(*cmd)
217 |         finally:
218 |             if orig_decoder is not None:
219 |                 self._decoder = orig_decoder
220 | 
221 |     def _command(self, *cmd):
222 |         command_name = get_command_name(cmd)
223 |         if command_name in not_allowed_push_commands:
224 |             raise ValueError("Command %s is not allowed to be called directly, use the appropriate API instead" % cmd)
225 |         if command_name == b"MULTI" and not self._allow_multi:
226 |             raise ValueError("Take a connection if you want to use the MULTI command.")
227 |         self._send(*cmd)
228 |         res = self._recv()
229 |         if isinstance(res, Error):
230 |             if res.args[0].startswith("MOVED "):
231 |                 self._seen_moved = True
232 |             if res.args[0].startswith("ASK "):
233 |                 _, _, address = res.args[0].split(" ")
234 |                 self._seen_ask = address
235 |             raise res
236 |         if res == timeout_error:
237 |             self.close()
238 |             raise timeout_error
239 |         return res
240 | 
241 |     def _commands(self, *cmds):
242 |         for cmd in cmds:
243 |             command_name = get_command_name(cmd)
244 |             if command_name in not_allowed_push_commands:
245 |                 raise ValueError("Command %s is not allowed to be called directly, use the appropriate API instead" % cmd)
246 |             if command_name == b"MULTI" and not self._allow_multi:
247 |                 raise ValueError("Take a connection if you want to use the MULTI command.")
248 |         self._send(*cmds)
249 |         res = []
250 |         found_errors = False
251 |         for _ in cmds:
252 |             try:
253 |                 result = self._recv()
254 |                 if isinstance(result, Error):
255 |                     if result.args[0].startswith("MOVED "):
256 |                         self._seen_moved = True
257 |                     found_errors = True
258 |                 if result == timeout_error:
259 |                     self.close()
260 |             except Exception as e:
261 |                 result = e
262 |                 found_errors = True
263 |             res.append(result)
264 |         if found_errors:
265 |             raise PipelinedExceptions(res)
266 |         return res
267 | 
268 |     def seen_moved(self):
269 |         if self._seen_moved:
270 |             self._seen_moved = False
271 |             return True
272 |         return False
273 | 
274 |     def seen_asked(self):
275 |         if self._seen_ask:
276 |             ret = self._seen_ask
277 |             self._seen_ask = False
278 |             return ret
279 |         return False
280 | 
281 |     def allow_multi(self, allow):
282 |         self._allow_multi = allow
283 | 
--------------------------------------------------------------------------------
/justredis/sync/connectionpool.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | from contextlib import contextmanager
3 | 
4 | 
5 | from .connection import Connection
6 | from ..errors import ConnectionPoolError
7 | from ..decoder import Error
8 | from .environment import get_environment
9 | 
10 | 
11 | # TODO (misc) can we relax the _lock ?
12 | 
13 | 
14 | class ConnectionPool:
15 |     def __init__(self, max_connections=None, wait_timeout=None, **kwargs):
16 |         self._max_connections = max_connections
17 |         self._wait_timeout = wait_timeout
18 |         self._connection_settings = kwargs
19 | 
20 |         self._lock = get_environment(**kwargs).lock()
21 |         self._limit = get_environment(**kwargs).semaphore(max_connections) if max_connections else None
22 |         self._connections_available = deque()
23 |         self._connections_in_use = set()
24 |         self._closed = False
25 | 
26 |     def __del__(self):
27 |         self.close()
28 | 
29 |     def close(self):
30 |         with self._lock:
31 |             if not self._closed:
32 |                 # We do this first, so if another thread calls release it won't get back to the pool
33 |                 for connection in self._connections_available:
34 |                     connection.close()
35 |                 for connection in self._connections_in_use:
36 |                     connection.close()
37 |                 self._connections_available.clear()
38 |                 self._connections_in_use.clear()
39 |                 self._limit = get_environment(**self._connection_settings).semaphore(self._max_connections) if self._max_connections else None
40 |                 self._closed = True
41 | 
42 |     def take(self):
43 |         if self._closed:
44 |             raise ConnectionPoolError("Pool already closed")
45 |         # TODO (correctness) cluster depends on this failing if closed ! guess we should add a health check
46 |         try:
47 |             while True:
48 |                 conn = self._connections_available.popleft()
49 |                 if not conn.closed():
50 |                     break
51 |                 if self._limit is not None:
52 |                     self._limit.release()
53 |         except IndexError:
54 |             if self._limit is not None and not self._limit.acquire(self._wait_timeout):
55 |                 raise ConnectionPoolError("Could not acquire a connection from the pool")
56 |             try:
57 |                 conn = Connection.create(**self._connection_settings)
58 |             except Exception:
59 |                 if self._limit is not None:
60 |                     self._limit.release()
61 |                 raise
62 |         self._connections_in_use.add(conn)
63 |         return conn
64 | 
65 |     def release(self, conn):
66 |         with self._lock:
67 |             try:
68 |                 self._connections_in_use.remove(conn)
69 |             # TODO (correctness) should we release the self._limit here as well ? (or just make close forever)
70 |             # If this fails, it's a connection from a previous cycle, don't reuse it
71 |             except KeyError:
72 |                 conn.close()
73 |                 return
74 |             if not conn.closed():
75 |                 self._connections_available.append(conn)
76 |             elif self._limit is not None:
77 |                 self._limit.release()
78 | 
79 |     def __call__(self, *cmd, **kwargs):
80 |         if not cmd:
81 |             raise ValueError("No command provided")
82 |         conn = self.take()
83 |         try:
84 |             return conn(*cmd, **kwargs)
85 |         finally:
86 |             self.release(conn)
87 | 
88 |     @contextmanager
89 |     def connection(self, **kwargs):
90 |         conn = self.take()
91 |         try:
92 |             conn.allow_multi(True)
93 |             yield conn
94 |         finally:
95 |             # We need to clean up the connection back to a normal state.
96 |             try:
97 |                 conn._command(b"DISCARD")
98 |             except Exception:
99 |                 pass
100 |             conn.allow_multi(False)
101 |             self.release(conn)
102 | 
103 |     def endpoints(self):
104 |         conn = self.take()
105 |         try:
106 |             return [(conn.peername(), {"type": "regular"})]
107 |         finally:
108 |             self.release(conn)
109 | 
--------------------------------------------------------------------------------
/justredis/sync/environment.py:
--------------------------------------------------------------------------------
1 | from .environments.threaded import ThreadedEnvironment
2 | 
3 | 
4 | def get_environment(environment=ThreadedEnvironment, **kargs):
5 |     if environment == "threaded":
6 |         environment = ThreadedEnvironment
7 |     return environment
8 | 
--------------------------------------------------------------------------------
/justredis/sync/environments/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tzickel/justredis/12c63bba8a83f1c7934d34ea081ada6efbc563ad/justredis/sync/environments/__init__.py
--------------------------------------------------------------------------------
/justredis/sync/environments/threaded.py:
--------------------------------------------------------------------------------
1 | from threading import Lock, Semaphore
2 | import socket
3 | import sys
4 | import ssl
5 | 
6 | 
7 | platform = ""
8 | if sys.platform.startswith("linux"):
9 |     platform = "linux"
10 | elif sys.platform.startswith("darwin"):
11 |     platform = "darwin"
12 | elif sys.platform.startswith("win"):
13 |     platform = "windows"
14 | 
15 | 
16 | def tcpsocket(address=None, connect_timeout=None, socket_timeout=None, tcp_keepalive=None, tcp_nodelay=True, **kwargs):
17 |     if address is None:
18 |         address = ("localhost", 6379)
19 |     sock = socket.create_connection(address, connect_timeout)  # AWAIT
20 |     sock.settimeout(socket_timeout)
21 | 
22 |     if tcp_nodelay is not None:
23 |         if tcp_nodelay:
24 |             sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
25 |         else:
26 |             sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 0)
27 |     if tcp_keepalive:
28 |         sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
29 |         if platform == "linux":
30 |             sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, tcp_keepalive)
31 |             sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, tcp_keepalive // 3)
32 |             sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 3)
33 |         elif platform == "darwin":
34 |             sock.setsockopt(socket.IPPROTO_TCP, 0x10, tcp_keepalive // 3)
35 |         elif platform == "windows":
36 |             sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, tcp_keepalive * 1000, tcp_keepalive // 3 * 1000))
37 |         else:
38 |             # TODO (misc) maybe warn instead ?
39 |             raise NotImplementedError("Unknown platform, cannot set tcp_keepalive")
40 |     return sock
41 | 
42 | 
43 | def unixsocket(address=None, connect_timeout=None, socket_timeout=None, **kwargs):
44 |     if address is None:
45 |         address = "/tmp/redis.sock"
46 |     sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
47 |     sock.settimeout(connect_timeout)
48 |     sock.connect(address)  # AWAIT
49 |     sock.settimeout(socket_timeout)
50 |     return sock
51 | 
52 | 
53 | # TODO (misc) should we enable server hostname enforcement ? give it as an option ? what about cluster ?
54 | def sslsocket(address=None, ssl_context=None, **kwargs):
55 |     if address is None:
56 |         address = ("localhost", 6379)
57 |     sock = tcpsocket(address=address, **kwargs)
58 |     if ssl_context is None:
59 |         ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
60 |         cafile = kwargs.get("ssl_cafile")
61 |         if cafile:
62 |             ssl_context.load_verify_locations(cafile)
63 |         certfile = kwargs.get("ssl_certfile")
64 |         keyfile = kwargs.get("ssl_keyfile")
65 |         if certfile:
66 |             ssl_context.load_cert_chain(certfile, keyfile)
67 |     return ssl_context.wrap_socket(sock, server_hostname=address[0])
68 | 
69 | 
70 | class SocketWrapper:
71 |     def __init__(self, socket_factory, buffersize=2 ** 16, **kwargs):
72 |         self._buffer = bytearray(buffersize)
73 |         self._view = memoryview(self._buffer)
74 |         self._socket = socket_factory(**kwargs)
75 | 
76 |     def close(self):
77 |         self._socket.close()
78 | 
79 |     def send(self, data):
80 |         self._socket.sendall(data)  # AWAIT
81 | 
82 |     # If you override this, make sure to return an empty bytes for EOF and a None for timeout !
83 |     def recv(self, timeout=False):
84 |         if timeout != False:
85 |             old_timeout = self._socket.gettimeout()
86 |             self._socket.settimeout(timeout)
87 |             try:
88 |                 r = self._socket.recv_into(self._buffer)  # AWAIT
89 |             except socket.timeout:
90 |                 return None
91 |             finally:
92 |                 # TODO (misc) if the socket is already closed, this may fail again; use else instead of finally ?
93 |                 self._socket.settimeout(old_timeout)
94 |         else:
95 |             r = self._socket.recv_into(self._buffer)  # AWAIT
96 |         return self._view[:r]
97 | 
98 |     def peername(self):
99 |         peername = self._socket.getpeername()
100 |         # TODO (misc) is there a lib where this is not the case ? we can also just return the peername in the connect functions.
101 |         if self._socket.family == socket.AF_INET6:
102 |             peername = peername[:2]
103 |         return peername
104 | 
105 | 
106 | class OurSemaphore:
107 |     def __init__(self, value):
108 |         self._semaphore = Semaphore(value)
109 | 
110 |     def release(self):
111 |         self._semaphore.release()
112 | 
113 |     def acquire(self, timeout=None):
114 |         return self._semaphore.acquire(True, timeout)  # AWAIT
115 | 
116 | 
117 | class OurLock:
118 |     def __init__(self):
119 |         self._lock = Lock()
120 | 
121 |     def __enter__(self):
122 |         self._lock.acquire()  # AWAIT
123 | 
124 |     def __exit__(self, *args):
125 |         self._lock.release()  # AWAIT
126 | 
127 | 
128 | class ThreadedEnvironment:
129 |     @staticmethod
130 |     def socket(socket_type="tcp", **kwargs):
131 |         if socket_type == "tcp":
132 |             socket_type = tcpsocket
133 |         elif socket_type == "unix":
134 |             socket_type = unixsocket
135 |         elif socket_type == "ssl":
136 |             socket_type = sslsocket
137 |         else:
138 |             raise NotImplementedError("Unknown socket type: %s" % socket_type)
139 |         return SocketWrapper(socket_type, **kwargs)
140 | 
141 |     @staticmethod
142 |     def semaphore(limit):
143 |         return OurSemaphore(limit)
144 | 
145 |     @staticmethod
146 |     def lock():
147 |         return OurLock()
148 | 
--------------------------------------------------------------------------------
/justredis/sync/redis.py:
--------------------------------------------------------------------------------
1 | from .connectionpool import ConnectionPool
2 | from .cluster import ClusterConnectionPool
3 | from ..decoder import Error
4 | from ..utils import parse_url, merge_dicts
5 | 
6 | 
7 | # TODO (misc) document all the kwargs everywhere
8 | # TODO (api) internal remove from connectionpool the __enter__/__exit__ and use take(**kwargs)/release
9 | 
10 | 
11 | # We do this separation to allow changing per-command and per-connection settings easily
12 | class ModifiedRedis:
13 |     def __init__(self, connection_pool, custom_command_class=None, **kwargs):
14 |         self._connection_pool = connection_pool
15 |         self._custom_command_class = custom_command_class
16 |         self._custom_command = custom_command_class(self) if self._custom_command_class else None
17 |         self._settings = kwargs
18 | 
19 |     def __del__(self):
20 |         self.close()
21 | 
22 |     def close(self):
23 |         self._connection_pool = self._settings = None
24 | 
25 |     def __enter__(self):
26 |         return self
27 | 
28 |     def __exit__(self, *args):
29 |         self.close()
30 | 
31 |     def __call__(self, *cmd, **kwargs):
32 |         settings = merge_dicts(self._settings, kwargs)
33 |         if settings is None:
34 |             return self._connection_pool(*cmd)
35 |         else:
36 |             return self._connection_pool(*cmd, **settings)
37 | 
38 |     def connection(self, *args, push=False, **kwargs):
39 |         if args:
40 |             raise ValueError("Please specify the connection arguments as named arguments (i.e. push=..., key=...)")
41 |         wrapper = PushConnection if push else Connection
42 |         return wrapper.create(self._connection_pool.connection(**kwargs), **self._settings)
43 | 
44 |     def endpoints(self):
45 |         return self._connection_pool.endpoints()
46 | 
47 |     def modify(self, **kwargs):
48 |         settings = self._settings.copy()
49 |         settings.update(kwargs)
50 |         return ModifiedRedis(self._connection_pool, custom_command_class=self._custom_command_class, **settings)
51 | 
52 |     def __getattr__(self, attribute):
53 |         if not self._custom_command:
54 |             raise AttributeError("No such attribute: %s" % attribute)
55 |         return getattr(self._custom_command, attribute)
56 | 
57 | 
58 | # TODO (api) should we implement a callback for when slots have changed ?
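# An illustrative sketch (not part of the library) of the synchronous API
# defined in this module: the same call shapes as the async variant, without
# await. The URL and key names are assumptions for the example.
#
#     with Redis.from_url("redis://localhost:6379") as r:
#         r(b"SET", b"hello", b"world")
#         with r.connection(key=b"counter") as conn:
#             conn(b"MULTI")
#             conn(b"INCR", b"counter")
#             conn(b"EXEC")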
59 | class Redis(ModifiedRedis):
60 |     @classmethod
61 |     def from_url(cls, url, **kwargs):
62 |         res = parse_url(url)
63 |         res.update(kwargs)
64 |         return cls(**res)
65 | 
66 |     def __init__(self, pool_factory=ClusterConnectionPool, custom_command_class=None, **kwargs):
67 |         """
68 |         Currently documented in README.md
69 |         """
70 |         self._connection_pool = None
71 |         if pool_factory == "pool":
72 |             pool_factory = ConnectionPool
73 |         elif pool_factory in ("auto", "cluster"):
74 |             pool_factory = ClusterConnectionPool
75 |         if not callable(pool_factory):
76 |             raise AttributeError("A valid pool_factory is required; if you want to set an address, use .from_url() or address=(host, port)")
77 |         super(Redis, self).__init__(pool_factory(**kwargs), custom_command_class=custom_command_class)
78 | 
79 |     def __del__(self):
80 |         self.close()
81 | 
82 |     def close(self):
83 |         if self._connection_pool:
84 |             self._connection_pool.close()
85 |         self._connection_pool = None
86 | 
87 | 
88 | # TODO (api) add a modified_class here as well.
89 | class ModifiedConnection:
90 |     def __init__(self, connection, **kwargs):
91 |         self._connection = connection
92 |         self._settings = kwargs
93 | 
94 |     def __del__(self):
95 |         self.close()
96 | 
97 |     def close(self):
98 |         self._connection = self._settings = None
99 | 
100 |     def __enter__(self):
101 |         return self
102 | 
103 |     def __exit__(self, *args):
104 |         self.close()
105 | 
106 |     def __call__(self, *cmd, **kwargs):
107 |         settings = merge_dicts(self._settings, kwargs)
108 |         if settings is None:
109 |             return self._connection(*cmd)
110 |         else:
111 |             return self._connection(*cmd, **settings)
112 | 
113 |     def modify(self, **kwargs):
114 |         settings = self._settings.copy()
115 |         settings.update(kwargs)
116 |         return ModifiedConnection(self._connection, **settings)
117 | 
118 | 
119 | class Connection(ModifiedConnection):
120 |     @classmethod
121 |     def create(cls, connection, **kwargs):
122 |         conn = connection.__enter__()
123 |         ret = cls(conn, **kwargs)
124 |         ret._connection_context = connection
125 |         return ret
126 | 
127 |     def __init__(self, connection, **kwargs):
128 |         super(Connection, self).__init__(connection, **kwargs)
129 |         self._connection_context = None  # set by create(); initialized here so close()/__del__ cannot hit an AttributeError on a directly constructed instance

130 |     def __del__(self):
131 |         self.close()
132 | 
133 |     def close(self):
134 |         if self._connection_context:
135 |             # TODO (correctness) is this correct ?
136 |             self._connection_context.__exit__(None, None, None)
137 |         self._connection = None
138 |         self._connection_context = None
139 |         self._settings = None
140 | 
141 | 
142 | class PushConnection(Connection):
143 |     def close(self):
144 |         if self._connection:
145 |             # We close the connection here, since it is hard to reset a push connection's state, and closing one is usually done rarely / at a low frequency.
146 |             self._connection.close()
147 |         super(PushConnection, self).close()
148 | 
149 |     def __call__(self, *cmd, **kwargs):
150 |         settings = merge_dicts(self._settings, kwargs)
151 |         if settings is None:
152 |             return self._connection.push_command(*cmd)
153 |         else:
154 |             return self._connection.push_command(*cmd, **settings)
155 | 
156 |     def next_message(self, *args, timeout=None, **kwargs):
157 |         if args:
158 |             raise ValueError("Please specify the next_message arguments as named arguments (i.e. timeout=...)")
159 |         settings = merge_dicts(self._settings, kwargs)
160 |         if settings is None:
161 |             return self._connection.pushed_message(timeout=timeout)
162 |         else:
163 |             return self._connection.pushed_message(timeout=timeout, **settings)
164 | 
165 |     def __iter__(self):
166 |         return self
167 | 
168 |     def __next__(self):
169 |         return self._connection.pushed_message()
170 | 
-------------------------------------------------------------------------------- /justredis/utils.py: --------------------------------------------------------------------------------
1 | from urllib.parse import urlparse, parse_qsl
2 | 
3 | 
4 | # TODO (correctness) do we need to URL-decode / escape anything here ?
5 | # TODO (misc) the validation and conversion from strings should be done on the other side
6 | def parse_url(url):
7 |     result = urlparse(url)
8 |     res = {}
9 | 
10 |     if result.scheme == "redis":
11 |         res["socket_factory"] = "tcp"
12 |     elif result.scheme == "redis-socket" or result.scheme == "unix":
13 |         res["socket_factory"] = "unix"
14 |     elif result.scheme == "rediss" or result.scheme == "ssl":
15 |         res["socket_factory"] = "ssl"
16 |     else:
17 |         raise NotImplementedError("Not implemented connection scheme: %s" % result.scheme)
18 | 
19 |     if result.username:
20 |         if result.password:
21 |             res["username"] = result.username
22 |             res["password"] = result.password
23 |         else:
24 |             res["password"] = result.username  # a URL of the form redis://secret@host is treated as a password-only login
25 | 
26 |     if res["socket_factory"] == "unix":
27 |         res["address"] = result.path
28 |     else:
29 |         addresses = result.netloc.split("@")[-1].split(",")
30 |         parsed_addresses = []
31 |         for address in addresses:
32 |             data = address.split(":", 1)
33 |             if data[0] == "":
34 |                 data[0] = "localhost"
35 |             if len(data) == 1:
36 |                 parsed_addresses.append((data[0], 6379))
37 |             else:
38 |                 parsed_addresses.append((data[0], int(data[1])))
39 | 
40 |         if len(parsed_addresses) == 1:
41 |             res["address"] = parsed_addresses[0]
42 |         else:
43 |             res["addresses"] = parsed_addresses
44 | 
45 |     if result.path and result.path != "/":  # TODO (correctness) for unix sockets the path is the socket file, so database probably should not be derived from it
46 |         res["database"] = result.path[1:]
47 | 
48 |     if result.query:
49 |         res.update(dict(parse_qsl(result.query)))
50 | 
51 |     return res
52 | 
53 | 
54 | def merge_dicts(parent, child):
55 |     if not parent and not child:
56 |         return None
57 |     elif not parent:
58 |         return child
59 |     elif not child:
60 |         return parent
61 |     tmp = parent.copy()
62 |     tmp.update(child)
63 |     return tmp
64 | 
65 | 
66 | # TODO (misc) can we do all those commands better, maybe with a special class for CustomCommand parameters?
67 | def get_command_name(cmd):
68 |     cmd = cmd[0]
69 |     cmd = cmd.upper()
70 |     if isinstance(cmd, str):
71 |         cmd = cmd.encode()
72 |     return cmd
73 | 
74 | 
75 | def is_multiple_commands(*cmd):
76 |     if isinstance(cmd[0], (tuple, list)):
77 |         return True
78 |     else:
79 |         return False
80 | 
-------------------------------------------------------------------------------- /setup.cfg: --------------------------------------------------------------------------------
1 | [metadata]
2 | name = justredis
3 | version = 0.0.1a3
4 | description = A Redis client for Python supporting many Redis features and Python synchronous (Python 3.5+) and asynchronous (Python 3.6+) communication.
5 | long_description = file: README.md 6 | long_description_content_type = text/markdown 7 | keywords = redis 8 | author = tzickel 9 | project_urls = 10 | Source code = https://github.com/tzickel/justredis 11 | Issue tracker = https://github.com/tzickel/justredis/issues 12 | license = MIT 13 | url = https://github.com/tzickel/justredis 14 | classifiers = 15 | Development Status :: 3 - Alpha 16 | License :: OSI Approved :: MIT License 17 | Programming Language :: Python :: Implementation :: CPython 18 | Programming Language :: Python :: Implementation :: PyPy 19 | Programming Language :: Python :: 3 :: Only 20 | Programming Language :: Python :: 3.5 21 | Programming Language :: Python :: 3.6 22 | Programming Language :: Python :: 3.7 23 | Programming Language :: Python :: 3.8 24 | 25 | [options] 26 | packages = find: 27 | python_requires = >= 3.5 28 | 29 | [options.packages.find] 30 | exclude = 31 | tests 32 | 33 | [tool:pytest] 34 | testpaths = tests 35 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup() 4 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzickel/justredis/12c63bba8a83f1c7934d34ea081ada6efbc563ad/tests/__init__.py -------------------------------------------------------------------------------- /tests/async/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzickel/justredis/12c63bba8a83f1c7934d34ea081ada6efbc563ad/tests/async/__init__.py -------------------------------------------------------------------------------- /tests/async/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | 4 | from justredis import AsyncRedis 5 | 6 | # TODO (misc) in the future allow direct redis instances ? 7 | # TODO (misc) enforce REDIS_PATH as input, don't blindly accept it (so tests run twice) 8 | 9 | 10 | def get_runtime_params_for_redis(dockerimage="redis"): 11 | redis_6_path = os.getenv("REDIS_6_PATH") 12 | redis_5_path = os.getenv("REDIS_5_PATH") 13 | if dockerimage in ("redis", "redis:6") and redis_6_path: 14 | return {"extrapath": redis_6_path} 15 | elif dockerimage == "redis:5" and redis_5_path: 16 | return {"extrapath": os.getenv("REDIS_5_PATH")} 17 | elif os.getenv("REDIS_USE_DOCKER") is not None: 18 | return {"dockerimage": dockerimage} 19 | else: 20 | return {} 21 | 22 | 23 | async def redis_with_client(dockerimage="redis", extraparams="", **kwargs): 24 | from .. import redis_server 25 | 26 | if isinstance(dockerimage, (tuple, list)): 27 | dockerimage = dockerimage[0] 28 | instance = redis_server.RedisServer(extraparams=extraparams, **get_runtime_params_for_redis(dockerimage)) 29 | try: 30 | async with AsyncRedis(address=("localhost", instance.port), resp_version=-1, **kwargs) as r: 31 | yield r 32 | finally: 33 | instance.close() 34 | 35 | 36 | async def redis_cluster_with_client(dockerimage="redis", extraparams=""): 37 | from .. 
import redis_server 38 | 39 | if isinstance(dockerimage, (tuple, list)): 40 | dockerimage = dockerimage[0] 41 | servers, stdout = redis_server.start_cluster(3, extraparams=extraparams, **get_runtime_params_for_redis(dockerimage)) 42 | try: 43 | async with AsyncRedis(address=("localhost", servers[0].port), resp_version=-1) as r: 44 | import anyio 45 | 46 | wait = 60 47 | while wait: 48 | result = await r(b"CLUSTER", b"INFO", endpoint="masters") 49 | ready = True 50 | for res in result.values(): 51 | if isinstance(res, Exception): 52 | raise res 53 | if b"cluster_state:ok" not in res: 54 | ready = False 55 | break 56 | if ready: 57 | break 58 | await anyio.sleep(1) 59 | wait -= 1 60 | if not wait: 61 | raise Exception("Cluster is down, could not run test") 62 | yield r 63 | finally: 64 | for server in servers: 65 | server.close() 66 | 67 | 68 | # TODO (misc) No better way to do it pytest ? 69 | def generate_fixture_params(cluster=True): 70 | params = [] 71 | versions = ("5", "6") 72 | for version in versions: 73 | if cluster != "only": 74 | params.append(("redis:%s" % version, False)) 75 | if cluster: 76 | params.append(("redis:%s" % version, True)) 77 | return params 78 | 79 | 80 | @pytest.fixture(params=generate_fixture_params()) 81 | async def client(request): 82 | if request.param[1]: 83 | async for item in redis_cluster_with_client(request.param[0]): 84 | yield item 85 | else: 86 | async for item in redis_with_client(request.param[0]): 87 | yield item 88 | 89 | 90 | @pytest.fixture(params=generate_fixture_params("only")) 91 | async def cluster_client(request): 92 | if request.param[1]: 93 | async for item in redis_cluster_with_client(request.param[0]): 94 | yield item 95 | else: 96 | async for item in redis_with_client(request.param[0]): 97 | yield item 98 | 99 | 100 | @pytest.fixture(params=generate_fixture_params(False)) 101 | async def no_cluster_client(request): 102 | if request.param[1]: 103 | async for item in redis_cluster_with_client(request.param[0]): 104 | yield item 105 | else: 106 | async for item in redis_with_client(request.param[0]): 107 | yield item 108 | 109 | 110 | @pytest.fixture(params=generate_fixture_params()) 111 | async def client_with_blah_password(request): 112 | async for item in redis_with_client(request.param, extraparams="--requirepass blah", password="blah"): 113 | yield item 114 | -------------------------------------------------------------------------------- /tests/async/test_async.py: -------------------------------------------------------------------------------- 1 | try: 2 | import anyio 3 | except: 4 | pass 5 | import pytest 6 | from justredis import AsyncRedis, Error, CommunicationError 7 | 8 | 9 | @pytest.mark.anyio 10 | async def test_connection_error(): 11 | with pytest.raises(CommunicationError): 12 | async with AsyncRedis(address=("127.0.0.222", 11121)) as r: 13 | await r("set", "a", "b") 14 | 15 | 16 | @pytest.mark.anyio 17 | async def test_auth(client_with_blah_password): 18 | address = (await client_with_blah_password.endpoints())[0][0] 19 | # No password 20 | async with AsyncRedis(address=address) as r: 21 | with pytest.raises(Error) as exc_info: 22 | await r("set", "auth_a", "b") 23 | assert exc_info.value.args[0].startswith("NOAUTH ") 24 | 25 | # Wrong password 26 | async with AsyncRedis(address=address, password="nop") as r: 27 | with pytest.raises(Error) as exc_info: 28 | await r("set", "auth_a", "b") 29 | # Changes between Redis 5 and Redis 6 30 | assert exc_info.value.args[0].startswith("WRONGPASS ") or 
exc_info.value.args[0].startswith("ERR invalid password") 31 | 32 | # Correct password 33 | async with AsyncRedis(address=address, password="blah") as r: 34 | assert await r("set", "auth_a", "b") == b"OK" 35 | 36 | 37 | @pytest.mark.anyio 38 | async def test_simple(client): 39 | r = client 40 | assert await r("set", "simple_a", "a") == b"OK" 41 | assert await r("set", "simple_b", "b") == b"OK" 42 | assert await r("set", "simple_c", "c") == b"OK" 43 | assert await r("set", "simple_{a}b", "d") == b"OK" 44 | assert await r("get", "simple_a") == b"a" 45 | assert await r("get", "simple_b") == b"b" 46 | assert await r("get", "simple_c") == b"c" 47 | assert await r("get", "simple_{a}b") == b"d" 48 | 49 | 50 | @pytest.mark.anyio 51 | async def test_modify_database(no_cluster_client): 52 | r = no_cluster_client 53 | await r("set", "modify_database_a_0", "a") 54 | # TODO (api) is this ok ? 55 | async with await r.modify(database=1).connection(key="a") as c: 56 | assert await c("get", "modify_database_a_0") == None 57 | assert await c("set", "modify_database_a_1", "a") == b"OK" 58 | 59 | 60 | @pytest.mark.anyio 61 | async def test_modify_database_cluster(cluster_client): 62 | r = cluster_client 63 | await r("set", "modify_database_cluster_a_0", "a") 64 | with pytest.raises(Error) as exc_info: 65 | async with await r.modify(database=1).connection(key="a") as c: 66 | assert await c("get", "modify_database_a_0") == None 67 | assert await c("set", "modify_database_a_1", "a") == b"OK" 68 | assert exc_info.value.args[0] == ("ERR SELECT is not allowed in cluster mode") 69 | 70 | 71 | @pytest.mark.anyio 72 | async def test_notallowed(client): 73 | r = client 74 | with pytest.raises(Error) as exc_info: 75 | await r("auth", "asd") 76 | assert exc_info.value.args[0].startswith("ERR ") 77 | 78 | 79 | @pytest.mark.anyio 80 | async def test_some_encodings(client): 81 | r = client 82 | with pytest.raises(ValueError): 83 | await r("set", "a", True) 84 | assert await r("incrbyfloat", "float_check", 0.1) == b"0.1" 85 | with pytest.raises(ValueError): 86 | await r("set", "a", [1, 2]) 87 | await r("set", "{check}_a", "a") 88 | await r("set", "{check}_b", "b") 89 | assert await r("get", "{check}_a", decoder="utf8") == "a" 90 | assert await r("mget", "{check}_a", "{check}_b", decoder="utf8") == ["a", "b"] 91 | 92 | 93 | @pytest.mark.anyio 94 | async def test_chunk_encoded_command(client): 95 | r = client 96 | assert await r("set", "test_chunk_encoded_command_a", b"test_chunk_encoded_command_a" * 10 * 1024) == b"OK" 97 | assert await r("get", "test_chunk_encoded_command_a") == b"test_chunk_encoded_command_a" * 10 * 1024 98 | assert await r("mget", "test_chunk_encoded_command_a" * 3500, "test_chunk_encoded_command_a" * 3500, "test_chunk_encoded_command_a" * 3500) == [None, None, None] 99 | 100 | 101 | @pytest.mark.anyio 102 | async def test_eval(client): 103 | r = client 104 | assert await r("set", "evaltest", "a") == b"OK" 105 | assert await r("eval", "return redis.call('get',KEYS[1])", 1, "evaltest") == b"a" 106 | assert await r("eval", "return redis.call('get',KEYS[1])", 1, "evaltestno") == None 107 | assert await r("eval", "return redis.call('get',KEYS[1])", 1, "evaltest") == b"a" 108 | assert await r("eval", "return redis.call('get',KEYS[1])", 1, "evaltestno") == None 109 | assert await r("script", "flush") == b"OK" 110 | assert await r("eval", "return redis.call('get',KEYS[1])", 1, "evaltest") == b"a" 111 | assert await r("eval", "return redis.call('get',KEYS[1])", 1, "evaltestno") == None 112 | 113 | 114 | 
@pytest.mark.anyio
115 | async def test_pipeline(client):
116 |     r = client
        # Passing several command tuples in one call pipelines them over a single connection and returns a list of replies
117 |     assert await r(("set", "abc", "def"), ("get", "abc")) == [b"OK", b"def"]
118 | 
119 | 
120 | # TODO (misc) add some extra checks here for invalid states
121 | @pytest.mark.anyio
122 | async def test_multi(client):
123 |     r = client
124 |     async with await r.connection(key="a") as c:
125 |         await c("multi")
126 |         await c("set", "a", "b")
127 |         await c("get", "a")
128 |         assert await c("exec") == [b"OK", b"b"]
129 | 
130 |     # TODO (misc) kinda lame
131 |     try:
132 |         async with await r.modify(database=2).connection(key="a") as c1:
133 |             await c1("multi")
134 |             await c1("set", "a", "b")
135 |             async with await r.modify(database=3).connection(key="a") as c2:
136 |                 await c2("multi")
137 |                 await c2("set", "a1", "c")
138 |                 await c2("get", "a")
139 |                 assert await c2("exec") == [b"OK", None]
140 |             await c1("mget", "a", "a1")
141 |             assert await c1("exec") == [b"OK", [b"b", None]]
142 |         assert await r.modify(database=2)("get", "a") == b"b"
143 |     except Error as e:
144 |         if e.args[0] == "ERR SELECT is not allowed in cluster mode":
145 |             pass
146 |         else:
147 |             raise
148 | 
149 | 
150 | @pytest.mark.anyio
151 | async def test_multidiscard(client):
152 |     r = client
153 |     async with await r.connection(key="a") as c:
154 |         await c("multi")
155 |         with pytest.raises(Error):
156 |             await c("nothing")  # the server rejects unknown commands even while a MULTI is queued
157 |         await c("discard")
158 |         await c("multi")
159 |         await c("set", "a", "b")
160 |         assert await c("exec") == [b"OK"]
161 | 
162 | 
163 | @pytest.mark.anyio
164 | async def test_pubsub(client):
165 |     r = client
166 |     async with await r.connection(push=True) as p:
167 |         await p("subscribe", "hi")
168 |         await p("psubscribe", b"bye")
169 |         assert await p.next_message() == [b"subscribe", b"hi", 1]
170 |         assert await p.next_message() == [b"psubscribe", b"bye", 2]
171 |         assert await p.next_message(timeout=0.1) == None
172 |         await r("publish", "hi", "there")
173 |         assert await p.next_message(timeout=0.1) == [b"message", b"hi", b"there"]
174 |         await r("publish", "bye", "there")
175 |         assert await p.next_message(timeout=0.1) == [b"pmessage", b"bye", b"bye", b"there"]
176 |         await p("ping")
177 |         # RESP2 and RESP3 behave differently here, so check for both
178 |         assert await p.next_message() in (b"PONG", [b"pong", b""])
179 |         await p("ping", b"hi")
180 |         assert await p.next_message() in (b"hi", [b"pong", b"hi"])
181 |         await p("unsubscribe", "hi")
182 | 
183 | 
184 | @pytest.mark.anyio
185 | async def test_misc(client):
186 |     r = client
187 |     # This tests a command for which the Redis server reports that its keys start at index 2.
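    # (OBJECT HELP itself takes no key, so the client must not blindly route it by the reported key position.)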
188 |     await r("object", "help")
189 |     # Check a command with no keys
190 |     await r("client", "list")
191 | 
192 | 
193 | @pytest.mark.anyio
194 | async def test_server_no_cluster(no_cluster_client):
195 |     r = no_cluster_client
196 |     assert await r("set", "cluster_aa", "a") == b"OK"
197 |     assert await r("set", "cluster_bb", "b") == b"OK"
198 |     assert await r("set", "cluster_cc", "c") == b"OK"
199 |     result = await r("keys", "cluster_*", endpoint="masters")
200 |     result = list(result.values())
201 |     result = [i for s in result for i in s]
202 |     assert set(result) == set([b"cluster_aa", b"cluster_bb", b"cluster_cc"])
203 | 
204 | 
205 | @pytest.mark.anyio
206 | async def test_server_cluster(cluster_client):
207 |     r = cluster_client
208 |     # TODO (misc) split the keys between the 3 cluster members
209 |     assert await r("set", "cluster_aa", "a") == b"OK"
210 |     assert await r("set", "cluster_bb", "b") == b"OK"
211 |     assert await r("set", "cluster_cc", "c") == b"OK"
212 |     result = await r("keys", "cluster_*", endpoint="masters")
213 |     assert len(result) == 3
214 |     result = list(result.values())
215 |     result = [i for s in result for i in s]
216 |     assert set(result) == set([b"cluster_aa", b"cluster_bb", b"cluster_cc"])
217 | 
218 | 
219 | @pytest.mark.anyio
220 | async def test_moved_no_cluster(no_cluster_client):
221 |     r = no_cluster_client
222 |     assert await r("set", "aa", "a") == b"OK"
223 |     assert await r("set", "bb", "b") == b"OK"
224 |     assert await r("set", "cc", "c") == b"OK"
225 |     result = await r("get", "aa", endpoint="masters")
226 |     result = list(result.values())
227 |     assert result == [b"a"]
228 | 
229 | 
230 | @pytest.mark.anyio
231 | async def test_moved_cluster(cluster_client):
232 |     r = cluster_client
233 |     assert await r("set", "aa", "a") == b"OK"
234 |     assert await r("set", "bb", "b") == b"OK"
235 |     assert await r("set", "cc", "c") == b"OK"
236 |     assert await r("get", "aa") == b"a"
237 |     assert await r("get", "bb") == b"b"
238 |     assert await r("get", "cc") == b"c"
239 |     result = await r("get", "aa", endpoint="masters")
240 |     result = list(result.values())
241 |     assert b"a" in result
242 |     assert len([x for x in result if isinstance(x, Error) and x.args[0].startswith("MOVED ")]) == 2
243 | 
244 | 
245 | @pytest.mark.anyio
246 | async def test_cancel(client):
247 |     r = client
248 | 
249 |     async with anyio.create_task_group() as tg:
250 |         await tg.spawn(r, "blpop", "a", 20)
251 |         await anyio.sleep(1)
252 |         await tg.cancel_scope.cancel()
253 | 
254 |     await anyio.sleep(1)
255 |     async with anyio.create_task_group() as tg:
256 |         await tg.spawn(r, "blpop", "a", 1)
-------------------------------------------------------------------------------- /tests/conftest.py: --------------------------------------------------------------------------------
1 | import pytest
2 | import os
3 | import warnings
4 | 
5 | 
6 | # Disable async tests if anyio is not installed
7 | try:
8 |     import anyio
9 | except:
10 |     anyio = None
11 | 
12 | 
13 | @pytest.mark.hookwrapper
14 | def pytest_ignore_collect(path, config):
15 |     outcome = yield
16 |     if "async" in str(path) and anyio is None:
17 |         warnings.warn("Skipping async tests because AnyIO is not installed.")
18 |         outcome.force_result(True)
19 | 
20 | 
21 | from justredis import Redis
22 | 
23 | # TODO (misc) in the future allow direct redis instances ?
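# NOTE: the fixtures below intentionally mirror tests/async/conftest.py; if you change one file, keep the other in sync.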
24 | # TODO (misc) enforce REDIS_PATH as input, don't blindly accept it (so tests run twice) 25 | 26 | 27 | def get_runtime_params_for_redis(dockerimage="redis"): 28 | redis_6_path = os.getenv("REDIS_6_PATH") 29 | redis_5_path = os.getenv("REDIS_5_PATH") 30 | if dockerimage in ("redis", "redis:6") and redis_6_path: 31 | return {"extrapath": redis_6_path} 32 | elif dockerimage == "redis:5" and redis_5_path: 33 | return {"extrapath": os.getenv("REDIS_5_PATH")} 34 | elif os.getenv("REDIS_USE_DOCKER"): 35 | return {"dockerimage": dockerimage} 36 | else: 37 | return {} 38 | 39 | 40 | def redis_with_client(dockerimage="redis", extraparams="", **kwargs): 41 | from . import redis_server 42 | 43 | if isinstance(dockerimage, (tuple, list)): 44 | dockerimage = dockerimage[0] 45 | instance = redis_server.RedisServer(extraparams=extraparams, **get_runtime_params_for_redis(dockerimage)) 46 | with Redis(address=("localhost", instance.port), resp_version=-1, **kwargs) as r: 47 | try: 48 | yield r 49 | finally: 50 | instance.close() 51 | 52 | 53 | def redis_cluster_with_client(dockerimage="redis", extraparams=""): 54 | from . import redis_server 55 | 56 | if isinstance(dockerimage, (tuple, list)): 57 | dockerimage = dockerimage[0] 58 | servers, stdout = redis_server.start_cluster(3, extraparams=extraparams, **get_runtime_params_for_redis(dockerimage)) 59 | try: 60 | with Redis(address=("localhost", servers[0].port), resp_version=-1) as r: 61 | import time 62 | 63 | wait = 60 64 | while wait: 65 | result = r(b"CLUSTER", b"INFO", endpoint="masters") 66 | ready = True 67 | for res in result.values(): 68 | if isinstance(res, Exception): 69 | raise res 70 | if b"cluster_state:ok" not in res: 71 | ready = False 72 | break 73 | if ready: 74 | break 75 | time.sleep(1) 76 | wait -= 1 77 | if not wait: 78 | raise Exception("Cluster is down, could not run test") 79 | yield r 80 | finally: 81 | for server in servers: 82 | server.close() 83 | 84 | 85 | # TODO (misc) No better way to do it pytest ? 
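# Each generated param is a (docker image, run_as_cluster) pair; cluster may be True, False, or "only" (cluster-only).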
86 | def generate_fixture_params(cluster=True):
87 |     params = []
88 |     versions = ("5", "6")
89 |     for version in versions:
90 |         if cluster != "only":
91 |             params.append(("redis:%s" % version, False))
92 |         if cluster:
93 |             params.append(("redis:%s" % version, True))
94 |     return params
95 | 
96 | 
97 | @pytest.fixture(scope="module", params=generate_fixture_params())
98 | def client(request):
99 |     if request.param[1]:
100 |         for item in redis_cluster_with_client(request.param[0]):
101 |             yield item
102 |     else:
103 |         for item in redis_with_client(request.param[0]):
104 |             yield item
105 | 
106 | 
107 | @pytest.fixture(scope="module", params=generate_fixture_params("only"))
108 | def cluster_client(request):
109 |     if request.param[1]:
110 |         for item in redis_cluster_with_client(request.param[0]):
111 |             yield item
112 |     else:
113 |         for item in redis_with_client(request.param[0]):
114 |             yield item
115 | 
116 | 
117 | @pytest.fixture(scope="module", params=generate_fixture_params(False))
118 | def no_cluster_client(request):
119 |     if request.param[1]:
120 |         for item in redis_cluster_with_client(request.param[0]):
121 |             yield item
122 |     else:
123 |         for item in redis_with_client(request.param[0]):
124 |             yield item
125 | 
126 | 
127 | @pytest.fixture(scope="module", params=generate_fixture_params())
128 | def client_with_blah_password(request):
129 |     for item in redis_with_client(request.param, extraparams="--requirepass blah", password="blah"):
130 |         yield item
131 | 
-------------------------------------------------------------------------------- /tests/redis_server.py: --------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import signal
4 | import subprocess
5 | import sys
6 | 
7 | 
8 | platform = ""
9 | if sys.platform.startswith("linux"):
10 |     platform = "linux"
11 | elif sys.platform.startswith("darwin"):
12 |     platform = "darwin"
13 | elif sys.platform.startswith("win"):
14 |     platform = "windows"
15 | 
16 | 
17 | # TODO (misc) make this a contextmanager to clean up properly on failures (although currently the caller handles this)
18 | def start_cluster(masters, dockerimage=None, extraparams="", extrapath="", ipv4=True):
19 |     addr = "127.0.0.1" if ipv4 else "::1"
20 |     ret = []
21 |     if dockerimage is None:
22 |         subprocess.call("rm /tmp/justredis_cluster*.conf", shell=True)
23 |     for x in range(masters):
24 |         ret.append(RedisServer(dockerimage=dockerimage, extraparams="--cluster-enabled yes --cluster-config-file /tmp/justredis_cluster%d.conf" % x, extrapath=extrapath))
25 |     # Requires redis-cli from version 5 for cluster management support
26 |     orig_path = os.getenv("PATH")
27 |     try:
28 |         if extrapath:
29 |             os.putenv("PATH", os.pathsep.join((os.getenv("PATH"), extrapath)))
30 |         if dockerimage:
31 |             stdout = subprocess.Popen(
32 |                 "docker run -i --rm --net=host " + dockerimage + " redis-cli --cluster create " + " ".join(["%s:%d" % (addr, server.port) for server in ret]),
33 |                 stdin=subprocess.PIPE,
34 |                 stdout=subprocess.PIPE,
35 |                 shell=True,
36 |             ).communicate(b"yes\n")
37 |         else:
38 |             stdout = subprocess.Popen(
39 |                 "redis-cli --cluster create " + " ".join(["%s:%d" % (addr, server.port) for server in ret]), stdin=subprocess.PIPE, stdout=subprocess.PIPE, shell=True
40 |             ).communicate(b"yes\n")
41 |     finally:
42 |         os.putenv("PATH", orig_path)
43 |     if stdout[0] == b"":
44 |         raise Exception("Empty output from redis-cli ?")
45 |     return ret, stdout[0]
46 | 
47 | 
48 | class RedisServer(object):
49 |     def __init__(self, dockerimage=None, extraparams="",
extrapath=""): 50 | self._proc = None 51 | while True: 52 | self.close() 53 | self._port = random.randint(1025, 65535) 54 | if dockerimage: 55 | # cmd = 'docker run --rm -p {port}:6379 {image} --save {extraparams}'.format(port=self.port, image=dockerimage, extraparams=extraparams) 56 | cmd = "docker run --rm --net=host {image} --save --port {port} {extraparams}".format(port=self.port, image=dockerimage, extraparams=extraparams) 57 | else: 58 | cmd = "redis-server --save --port {port} {extraparams}".format(port=self.port, extraparams=extraparams) 59 | kwargs = {} 60 | if platform == "linux": 61 | kwargs["preexec_fn"] = os.setsid 62 | orig_path = os.getenv("PATH") 63 | try: 64 | if extrapath: 65 | os.putenv("PATH", os.pathsep.join((os.getenv("PATH"), extrapath))) 66 | self._proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **kwargs) 67 | finally: 68 | os.putenv("PATH", orig_path) 69 | seen_redis = False 70 | while True: 71 | line = self._proc.stdout.readline() 72 | if b"Redis" in line: 73 | seen_redis = True 74 | if line == b"": 75 | self._port = None 76 | break 77 | elif b"Ready to accept connections" in line: 78 | break 79 | elif b"Opening Unix socket" in line and b"Address already in use" in line: 80 | raise Exception("Unix domain already in use") 81 | # usually could not find docker image 82 | if not seen_redis: 83 | raise Exception("Could not run redis") 84 | if self._port: 85 | break 86 | 87 | @property 88 | def port(self): 89 | return self._port 90 | 91 | def close(self): 92 | if self._proc: 93 | try: 94 | self._proc.stdout.close() 95 | if platform == "linux": 96 | os.killpg(os.getpgid(self._proc.pid), signal.SIGTERM) 97 | else: 98 | self._proc.kill() 99 | self._proc.wait() 100 | except Exception: 101 | pass 102 | self._proc = None 103 | 104 | def __del__(self): 105 | self.close() 106 | -------------------------------------------------------------------------------- /tests/test_basic.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from justredis import Redis, Error, CommunicationError 3 | 4 | 5 | # TODO (misc) copy all of misc/example.py into here 6 | 7 | 8 | def test_connection_error(): 9 | with pytest.raises(CommunicationError): 10 | with Redis(address=("127.0.0.222", 11121)) as r: 11 | r("set", "a", "b") 12 | 13 | 14 | def test_auth(client_with_blah_password): 15 | address = client_with_blah_password.endpoints()[0][0] 16 | # No password 17 | with Redis(address=address) as r: 18 | with pytest.raises(Error) as exc_info: 19 | r("set", "auth_a", "b") 20 | assert exc_info.value.args[0].startswith("NOAUTH ") 21 | 22 | # Wrong password 23 | with Redis(address=address, password="nop") as r: 24 | with pytest.raises(Error) as exc_info: 25 | r("set", "auth_a", "b") 26 | # Changes between Redis 5 and Redis 6 27 | assert exc_info.value.args[0].startswith("WRONGPASS ") or exc_info.value.args[0].startswith("ERR invalid password") 28 | 29 | # Correct password 30 | with Redis(address=address, password="blah") as r: 31 | assert r("set", "auth_a", "b") == b"OK" 32 | 33 | 34 | def test_simple(client): 35 | r = client 36 | assert r("set", "simple_a", "a") == b"OK" 37 | assert r("set", "simple_b", "b") == b"OK" 38 | assert r("set", "simple_c", "c") == b"OK" 39 | assert r("set", "simple_{a}b", "d") == b"OK" 40 | assert r("get", "simple_a") == b"a" 41 | assert r("get", "simple_b") == b"b" 42 | assert r("get", "simple_c") == b"c" 43 | assert r("get", "simple_{a}b") == b"d" 44 | 45 | 46 | def 
test_modify_database(no_cluster_client): 47 | r = no_cluster_client 48 | r("set", "modify_database_a_0", "a") 49 | with r.modify(database=1).connection(key="a") as c: 50 | assert c("get", "modify_database_a_0") == None 51 | assert c("set", "modify_database_a_1", "a") == b"OK" 52 | 53 | 54 | def test_modify_database_cluster(cluster_client): 55 | r = cluster_client 56 | r("set", "modify_database_cluster_a_0", "a") 57 | with pytest.raises(Error) as exc_info: 58 | with r.modify(database=1).connection(key="a") as c: 59 | assert c("get", "modify_database_a_0") == None 60 | assert c("set", "modify_database_a_1", "a") == b"OK" 61 | assert exc_info.value.args[0] == ("ERR SELECT is not allowed in cluster mode") 62 | 63 | 64 | def test_notallowed(client): 65 | r = client 66 | with pytest.raises(Error) as exc_info: 67 | r("auth", "asd") 68 | assert exc_info.value.args[0].startswith("ERR ") 69 | 70 | 71 | def test_some_encodings(client): 72 | r = client 73 | with pytest.raises(ValueError): 74 | r("set", "a", True) 75 | assert r("incrbyfloat", "float_check", 0.1) == b"0.1" 76 | with pytest.raises(ValueError): 77 | r("set", "a", [1, 2]) 78 | r("set", "{check}_a", "a") 79 | r("set", "{check}_b", "b") 80 | assert r("get", "{check}_a", decoder="utf8") == "a" 81 | assert r("mget", "{check}_a", "{check}_b", decoder="utf8") == ["a", "b"] 82 | 83 | 84 | def test_chunk_encoded_command(client): 85 | r = client 86 | assert r("set", "test_chunk_encoded_command_a", b"test_chunk_encoded_command_a" * 10 * 1024) == b"OK" 87 | assert r("get", "test_chunk_encoded_command_a") == b"test_chunk_encoded_command_a" * 10 * 1024 88 | assert r("mget", "test_chunk_encoded_command_a" * 3500, "test_chunk_encoded_command_a" * 3500, "test_chunk_encoded_command_a" * 3500) == [None, None, None] 89 | 90 | 91 | def test_eval(client): 92 | r = client 93 | assert r("set", "evaltest", "a") == b"OK" 94 | assert r("eval", "return redis.call('get',KEYS[1])", 1, "evaltest") == b"a" 95 | assert r("eval", "return redis.call('get',KEYS[1])", 1, "evaltestno") == None 96 | assert r("eval", "return redis.call('get',KEYS[1])", 1, "evaltest") == b"a" 97 | assert r("eval", "return redis.call('get',KEYS[1])", 1, "evaltestno") == None 98 | assert r("script", "flush") == b"OK" 99 | assert r("eval", "return redis.call('get',KEYS[1])", 1, "evaltest") == b"a" 100 | assert r("eval", "return redis.call('get',KEYS[1])", 1, "evaltestno") == None 101 | 102 | 103 | def test_pipeline(client): 104 | r = client 105 | assert r(("set", "abc", "def"), ("get", "abc")) == [b"OK", b"def"] 106 | 107 | 108 | # TODO (misc) add some extra checks here for invalid states 109 | def test_multi(client): 110 | r = client 111 | with r.connection(key="a") as c: 112 | c("multi") 113 | c("set", "a", "b") 114 | c("get", "a") 115 | assert c("exec") == [b"OK", b"b"] 116 | 117 | # TODO (misc) kinda lame 118 | try: 119 | with r.modify(database=2).connection(key="a") as c1: 120 | c1("multi") 121 | c1("set", "a", "b") 122 | with r.modify(database=3).connection(key="a") as c2: 123 | c2("multi") 124 | c2("set", "a1", "c") 125 | c2("get", "a") 126 | assert c2("exec") == [b"OK", None] 127 | c1("mget", "a", "a1") 128 | assert c1("exec") == [b"OK", [b"b", None]] 129 | assert r.modify(database=2)("get", "a") == b"b" 130 | except Error as e: 131 | if e.args[0] == "ERR SELECT is not allowed in cluster mode": 132 | pass 133 | else: 134 | raise 135 | 136 | 137 | def test_multidiscard(client): 138 | r = client 139 | with r.connection(key="a") as c: 140 | c("multi") 141 | with pytest.raises(Error): 142 | 
c("nothing") 143 | c("discard") 144 | c("multi") 145 | c("set", "a", "b") 146 | assert c("exec") == [b"OK"] 147 | 148 | 149 | def test_pubsub(client): 150 | r = client 151 | with r.connection(push=True) as p: 152 | p("subscribe", "hi") 153 | p("psubscribe", b"bye") 154 | assert p.next_message() == [b"subscribe", b"hi", 1] 155 | assert p.next_message() == [b"psubscribe", b"bye", 2] 156 | assert p.next_message(timeout=0.1) == None 157 | r("publish", "hi", "there") 158 | assert p.next_message(timeout=0.1) == [b"message", b"hi", b"there"] 159 | r("publish", "bye", "there") 160 | assert p.next_message(timeout=0.1) == [b"pmessage", b"bye", b"bye", b"there"] 161 | p("ping") 162 | # RESP2 and RESP3 behave differently here, so check for both 163 | assert p.next_message() in (b"PONG", [b"pong", b""]) 164 | p("ping", b"hi") 165 | assert p.next_message() in (b"hi", [b"pong", b"hi"]) 166 | p("unsubscribe", "hi") 167 | 168 | 169 | def test_misc(client): 170 | r = client 171 | # This tests an command which redis server says keys start in index 2. 172 | r("object", "help") 173 | # Check command with no keys 174 | r("client", "list") 175 | 176 | 177 | def test_server_no_cluster(no_cluster_client): 178 | r = no_cluster_client 179 | r("set", "cluster_aa", "a") == b"OK" 180 | r("set", "cluster_bb", "b") == b"OK" 181 | r("set", "cluster_cc", "c") == b"OK" 182 | result = r("keys", "cluster_*", endpoint="masters") 183 | result = list(result.values()) 184 | result = [i for s in result for i in s] 185 | assert set(result) == set([b"cluster_aa", b"cluster_bb", b"cluster_cc"]) 186 | 187 | 188 | def test_server_cluster(cluster_client): 189 | r = cluster_client 190 | # TODO (misc) split keys to 3 comps 191 | r("set", "cluster_aa", "a") == b"OK" 192 | r("set", "cluster_bb", "b") == b"OK" 193 | r("set", "cluster_cc", "c") == b"OK" 194 | result = r("keys", "cluster_*", endpoint="masters") 195 | assert len(result) == 3 196 | result = list(result.values()) 197 | result = [i for s in result for i in s] 198 | assert set(result) == set([b"cluster_aa", b"cluster_bb", b"cluster_cc"]) 199 | 200 | 201 | def test_moved_no_cluster(no_cluster_client): 202 | r = no_cluster_client 203 | r("set", "aa", "a") == b"OK" 204 | r("set", "bb", "b") == b"OK" 205 | r("set", "cc", "c") == b"OK" 206 | result = r("get", "aa", endpoint="masters") 207 | result = list(result.values()) 208 | assert result == [b"a"] 209 | 210 | 211 | def test_moved_cluster(cluster_client): 212 | r = cluster_client 213 | r("set", "aa", "a") == b"OK" 214 | r("set", "bb", "b") == b"OK" 215 | r("set", "cc", "c") == b"OK" 216 | assert r("get", "aa") == b"a" 217 | assert r("get", "bb") == b"b" 218 | assert r("get", "cc") == b"c" 219 | result = r("get", "aa", endpoint="masters") 220 | result = list(result.values()) 221 | assert b"a" in result 222 | assert len([x for x in result if isinstance(x, Error) and x.args[0].startswith("MOVED ")]) == 2 223 | -------------------------------------------------------------------------------- /tests/test_example.py: -------------------------------------------------------------------------------- 1 | from justredis import Redis, Error 2 | 3 | 4 | def example(): 5 | # Let's connect to localhost:6379 and decode the string results as utf-8 strings. 
6 |     r = Redis(decoder="utf8")
7 |     assert r("set", "a", "b") == "OK"
8 |     assert r("get", "a") == "b"
9 |     assert r("get", "a", decoder=None) == b"b"  # But this can be changed on the fly
10 | 
11 |     with r.modify(database=1) as r1:
12 |         assert r1("get", "a") == None  # In this database, a was not set to b
13 | 
14 |     # Here we can use a transactional set of commands
15 |     with r.connection(key="a") as c:  # Notice we pass the key used below (not required if you never plan on connecting to a cluster)
16 |         c("multi")
17 |         c("set", "a", "b")
18 |         c("get", "a")
19 |         assert c("exec") == ["OK", "b"]
20 | 
21 |     # Or we can just pipeline them.
22 |     with r.connection(key="a") as c:
23 |         result = c(("multi",), ("set", "a", "b"), ("get", "a"), ("exec",))[-1]
24 |         assert result == ["OK", "b"]
25 | 
26 |     # Here is the famous increment example
27 |     # Notice we take the connection inside the loop, to make sure that if the cluster moved the key, we will reconnect to the correct server.
28 |     while True:
29 |         with r.connection(key="counter") as c:
30 |             c("watch", "counter")
31 |             value = int(c("get", "counter") or 0)
32 |             c("multi")
33 |             c("set", "counter", value + 1)
34 |             if c("exec") is None:
35 |                 continue
36 |             value += 1  # The value is updated if we got here
37 |             break
38 | 
39 |     # Let's show some publish & subscribe commands; here we use a push connection (where commands get no direct response)
40 |     with r.connection(push=True) as p:
41 |         p("subscribe", "hello")
42 |         assert p.next_message() == ["subscribe", "hello", 1]
43 |         assert p.next_message(timeout=0.1) == None  # Let's wait 0.1 seconds for another result
44 |         r("publish", "hello", ", World !")
45 |         assert p.next_message() == ["message", "hello", ", World !"]
46 | 
47 | 
48 | if __name__ == "__main__":
49 |     example()
50 | 
-------------------------------------------------------------------------------- /tox.ini: --------------------------------------------------------------------------------
1 | [tox]
2 | envlist = clean,py{35,36,37,38,39,py3}
3 | #isolated_build = true
4 | skip_missing_interpreters = true
5 | 
6 | [testenv]
7 | deps =
8 |     pytest
9 |     pytest-cov
10 |     py{36,37,38,39,py3}: anyio[trio,curio]
11 | commands =
12 |     py{35,36,37,38,39,py3}: pytest --cov={toxinidir}/justredis --cov={toxinidir}/tests --cov-append --cov-report=term-missing {posargs}
13 | passenv =
14 |     REDIS_6_PATH
15 |     REDIS_5_PATH
16 |     REDIS_USE_DOCKER
17 | 
18 | [testenv:clean]
19 | deps = coverage
20 | skip_install = true
21 | commands = coverage erase
22 | 
--------------------------------------------------------------------------------