├── .gitignore
├── LICENSE
├── NOTES.md
├── README.md
├── benchmark_client.py
├── benchmark_server.py
├── miniredis
│   ├── __init__.py
│   ├── aioserver.py
│   ├── client.py
│   ├── haystack.py
│   ├── server.py
│   └── sset.py
├── setup.py
└── tests
    ├── __init__.py
    ├── helpers.py
    ├── helpers_async.py
    ├── test_keys.py
    ├── test_keys_async.py
    ├── test_strings.py
    └── test_strings_async.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | *.db
3 | *.bin
4 | *.idx
5 | *.pyc
6 | .DS_Store
7 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Portions Copyright (C) 2013 Rui Carmo. Some rights reserved.
2 | Portions Copyright (C) 2010 Benjamin Pollack. All rights reserved.
3 | 
4 | Permission is hereby granted, free of charge, to any person obtaining a copy
5 | of this software and associated documentation files (the "Software"), to deal
6 | in the Software without restriction, including without limitation the rights
7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 | copies of the Software, and to permit persons to whom the Software is
9 | furnished to do so, subject to the following conditions:
10 | 
11 | The above copyright notice and this permission notice shall be included in
12 | all copies or substantial portions of the Software.
13 | 
14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 | THE SOFTWARE.
21 | 
--------------------------------------------------------------------------------
/NOTES.md:
--------------------------------------------------------------------------------
1 | # MiniRedis Project Notes
2 | 
3 | ## Project Overview
4 | 
5 | MiniRedis is a simplified Redis server implementation in Python. It supports two server implementations:
6 | 
7 | 1. A synchronous server (server.py)
8 | 2. An asynchronous server (aioserver.py) using asyncio
9 | 
10 | The project includes a client implementation and test suites for both server types.
11 | 
12 | ## Current State
13 | 
14 | As of May 6, 2025, the implementation has been fixed to address issues in the asynchronous server implementation.
15 | 
16 | ### What Works
17 | 
18 | - Both synchronous and asynchronous server implementations now pass their respective test suites
19 | - Redis protocol parsing and response formatting
20 | - Connection handling in both server versions
21 | - All basic Redis operations (GET/SET, KEYS, DEL, etc.)
22 | - Database selection and persistence across client operations
23 | - Key expiration via TTL/EXPIRE/EXPIREAT
24 | - Data persistence via the Haystack implementation
25 | - Multi-database support
26 | 
27 | ### Code Structure
28 | 
29 | - **server.py**: Synchronous Redis server implementation
30 | - **aioserver.py**: Asynchronous Redis server implementation using asyncio
31 | - **client.py**: Redis client implementation that works with both servers
32 | - **haystack.py**: Simple key-value data persistence implementation
33 | - **sset.py**: Sorted set implementation for Redis ZSET functionality
34 | - **benchmark_client.py/benchmark_server.py**: Performance benchmarking tools
35 | 
36 | ## Implementation Details
37 | 
38 | ### Redis Protocol Support
39 | 
40 | The implementation handles the Redis Serialization Protocol (RESP) for commands and responses:
41 | 
42 | - Bulk Strings for most data responses
43 | - Simple Strings for status responses (e.g., "OK")
44 | - Integers for numeric responses
45 | - Error responses for invalid commands or errors
46 | - Arrays for multi-part responses
47 | 
48 | ### Key Commands Implemented
49 | 
50 | - DEL - Delete keys
51 | - EXISTS - Check if keys exist
52 | - EXPIRE/EXPIREAT - Set key timeout
53 | - TTL/PTTL - Get key timeout
54 | - KEYS - Find keys matching pattern
55 | - TYPE - Get type of value stored at key
56 | 
57 | ### String Commands Implemented
58 | 
59 | - GET - Get value of key
60 | - SET - Set value of key
61 | - APPEND - Append value to key
62 | - INCR/INCRBY - Increment value of key
63 | - DECR/DECRBY - Decrement value of key
64 | - MGET - Get multiple keys
65 | - GETSET - Set key and return previous value
66 | - SETEX - Set key with expiration
67 | - SETNX - Set key if it doesn't exist
68 | 
69 | ### Other Data Types
70 | 
71 | - Lists (LPUSH, RPUSH, LPOP, RPOP, LRANGE, LLEN)
72 | - Hashes (HSET, HGET, HGETALL, HDEL, HEXISTS, HINCRBY, HKEYS, HVALS, HLEN)
73 | - Sorted Sets (ZADD, ZRANGE)
74 | - PubSub functionality (PUBLISH, SUBSCRIBE, UNSUBSCRIBE)
75 | 
76 | ### AsyncRedisServer Class Architecture
77 | 
78 | - Uses asyncio streams for network I/O
79 | - Maintains connection contexts via task attributes
80 | - Table-based multi-database support
81 | - Background tasks for key expiration checking and auto-saving
82 | - Command handlers follow a consistent pattern for error handling and Redis protocol response generation
83 | 
84 | ### Data Persistence
85 | 
86 | - Uses the Haystack class for simple key-value storage
87 | - Data is persisted between server restarts
88 | - Separate files for sync and async servers (redisdb.bin/idx vs redisdb_async.bin/idx)
89 | - Auto-save functionality for data durability
90 | 
91 | ## Test Suite Organization
92 | 
93 | ### Test Files
94 | 
95 | - **test_keys.py**: Tests for key operations on sync server
96 | - **test_keys_async.py**: Tests for key operations on async server
97 | - **test_strings.py**: Tests for string operations on sync server
98 | - **test_strings_async.py**: Tests for string operations on async server
99 | - **helpers.py**: Helper functions for sync server tests
100 | - **helpers_async.py**: Helper functions for async server tests
101 | 
102 | ### Test Coverage
103 | 
104 | - Basic connectivity and protocol handling
105 | - Key management operations
106 | - String value operations
107 | - Expiry functionality
108 | - Multi-database operations
109 | - Error handling and edge cases
110 | 
111 | ## Running the Tests
112 | 
113 | To run the tests, use the following commands:
114 | 
115 | ```bash
116 | # Run all tests
117 | python -m pytest
118 | 
119 | # Run only synchronous tests
120 | python -m pytest tests/test_keys.py tests/test_strings.py
121 | 
122 | # Run only asynchronous tests
123 | python -m pytest tests/test_keys_async.py tests/test_strings_async.py
124 | 
125 | # Run a specific test (with verbose output)
126 | python -m pytest tests/test_keys_async.py::TestAsyncKeysCommands::test_put_get -v
127 | 
128 | # Run all async key tests with verbose output
129 | python -m pytest tests/test_keys_async.py -v
130 | ```
131 | 
132 | ### Test Fixtures
133 | 
134 | - Each test module has a fixture that starts the appropriate server
135 | - For sync tests, the server runs in the same process
136 | - For async tests, the server runs in a separate process via multiprocessing
137 | - All tests use the same RedisClient implementation to interact with either server
138 | - `redis_client_async` fixture in test_keys_async.py handles server startup/teardown
139 | 
140 | ### Async Test Setup Details
141 | 
142 | The async test suite:
143 | 
144 | 1. Finds a free port using `find_free_port()`
145 | 2. Starts the AsyncRedisServer in a separate process with `start_async_server()`
146 | 3. Connects to the server with the RedisClient
147 | 4. Flushes the database before each test
148 | 5. Runs the tests against the server
149 | 6. Tears down the server process with `stop_async_server()`
150 | 
151 | ## Challenges Solved
152 | 
153 | 1. **Connection Context Management**: Fixed issues with maintaining DB selection and connection state across client operations. The AsyncRedisServer now correctly associates each client with its selected database.
154 | 
155 | 2. **Proper Task Cleanup**: Ensured all asyncio tasks are properly tracked and cleaned up when connections close or the server shuts down.
156 | 
157 | 3. **Database Persistence**: Fixed how databases are stored and retrieved from the Haystack persistence layer.
158 | 
159 | 4. **Key Expiration**: Implemented a proper background task for checking key expiration and removing expired keys.
160 | 
161 | 5. **Protocol Handling**: Enhanced the Redis protocol implementation to correctly handle all required data types and error conditions.
162 | 
163 | 6. **TTL Management in SET Operations**: Fixed the SET command to properly remove any existing TTL when a key is updated with a new value, which aligns with standard Redis behavior.
164 | 
165 | 7. **SETEX Response Format**: Corrected the SETEX command to return "OK" as a string rather than returning a boolean value, fixing client parsing errors.
166 | 
167 | ## Recent Fixes (May 6, 2025)
168 | 
169 | ### Fixed SET Command TTL Handling
170 | 
171 | The async implementation of the SET command was not removing existing TTL values when updating a key with a new value. This has been fixed to match Redis behavior, where setting a key removes any existing expiration.
172 | 
173 | ```python
174 | # In handle_set method:
175 | # Remove any expiration when setting a key (Redis behavior)
176 | if isinstance(db_num, int) and f"{db_num} {key}" in self.timeouts:
177 |     del self.timeouts[f"{db_num} {key}"]
178 | ```
179 | 
180 | ### Fixed SETEX Response Type
181 | 
182 | The async implementation of SETEX was returning a boolean True value instead of the string "OK" as required by the Redis protocol. This was causing client-side parsing errors when handling responses.
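
At the wire level the difference is a single framing byte. The following is a minimal illustrative sketch (not the server's actual encoder) of the two RESP framings involved, showing why clients choked on the boolean:

```python
# Illustrative sketch of the two RESP framings involved in this fix.
def encode(value):
    if value == "OK":
        return b"+OK\r\n"  # simple-string reply: what SETEX should send
    if isinstance(value, bool):
        return f":{1 if value else 0}\r\n".encode()  # integer reply: what True produced
    raise TypeError(f"unsupported value: {value!r}")

assert encode("OK") == b"+OK\r\n"  # parsed by clients as a status reply
assert encode(True) == b":1\r\n"   # parsed as an integer, breaking SETEX callers
```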
183 | 
184 | ```python
185 | # In handle_setex method:
186 | # Set the key and expiration
187 | self.tables[db_num][key] = value
188 | self.timeouts[f"{db_num} {key}"] = time.time() + ttl
189 | return "OK"  # Return "OK" instead of True
190 | ```
191 | 
192 | These fixes ensure full compatibility with Redis clients and consistent behavior between the synchronous and asynchronous implementations.
193 | 
194 | ## Future Improvements
195 | 
196 | - Add benchmarking results to compare sync vs async implementation performance
197 | - Implement more Redis commands (sorting, bit operations, more advanced list/hash operations)
198 | - Add more comprehensive logging for debugging and monitoring
199 | - Consider connection pooling for more efficient client handling
200 | - Implement transactions (MULTI/EXEC/WATCH)
201 | - Add support for Redis modules
202 | - Add more extensive error handling and recovery mechanisms
203 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # `miniredis`
2 | 
3 | `miniredis` is a pure Python server that supports a subset of the Redis protocol.
4 | 
5 | ## Why?
6 | 
7 | The original intent was to have a minimally working (if naïve) PubSub implementation in order to get to grips with the [protocol spec](http://redis.io/topics/protocol), but I eventually realised that a more complete server would be useful for testing and inclusion in some of my projects.
8 | 
9 | ## Performance
10 | 
11 | Pure Python performance, dependent on runtime and workload. Your mileage may vary.
12 | 
13 | ## Credits
14 | 
15 | I started out by forking [coderanger/miniredis](https://github.com/coderanger/miniredis) for experimentation, and things kind of accreted from there as I started implementing more commands.
--------------------------------------------------------------------------------
/benchmark_client.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | """
4 | Created by Rui Carmo on 2013-03-12
5 | Published under the MIT license.
6 | """
7 | 
8 | import os, sys, logging
9 | from miniredis.client import RedisClient
10 | from multiprocessing import Pool, current_process
11 | import time
12 | import random
13 | 
14 | log = logging.getLogger()
15 | 
16 | # Assume server is running on localhost:6379 or configure as needed
17 | REDIS_HOST = "localhost"
18 | REDIS_PORT = 6379
19 | 
20 | # Moved timed_worker function outside of __main__ block to make it picklable for multiprocessing
21 | def timed_worker(args):
22 |     count, db_index = args
23 |     worker_name = current_process().name
24 |     c = None
25 |     ops_done = 0
26 |     try:
27 |         log.info(f"{worker_name}: Connecting to {REDIS_HOST}:{REDIS_PORT}")
28 |         c = RedisClient(host=REDIS_HOST, port=REDIS_PORT)
29 |         c.select(db_index)  # Use a different DB per worker if desired
30 |         log.info(f"{worker_name}: Selected DB {db_index}")
31 | 
32 |         # Use a range relevant to the number of operations
33 |         key_range = list(range(0, max(10000, count // 2)))  # Adjust key space size
34 |         if not key_range:
35 |             log.warning(
36 |                 f"{worker_name}: Key range is empty, cannot perform operations."
37 |             )
38 |             return 0.0, 0  # Return time and ops count
39 | 
40 |         log.info(f"{worker_name}: Starting {count} SET/GET operations...")
41 |         start_time = time.monotonic()
42 | 
43 |         for i in range(count):
44 |             try:
45 |                 # Generate key within the pool for this worker
46 |                 key_suffix = random.choice(key_range)
47 |                 key = f"bench:{worker_name}:{key_suffix}"
48 |                 value = f"v-{i}"
49 | 
50 |                 c.set(key, value)
51 |                 ops_done += 1
52 | 
53 |                 # Occasionally get a key
54 |                 if i % 5 == 0:
55 |                     get_key_suffix = random.choice(key_range)
56 |                     get_key = f"bench:{worker_name}:{get_key_suffix}"
57 |                     c.get(get_key)
58 |                     ops_done += 1
59 | 
60 |             except ConnectionError as ce:
61 |                 log.error(
62 |                     f"{worker_name}: Connection error during operation {i}: {ce}"
63 |                 )
64 |                 raise  # Re-raise to stop this worker
65 |             except Exception as e:
66 |                 log.error(
67 |                     f"{worker_name}: Error during operation {i} (key: {key}): {e}"
68 |                 )
69 |                 # Decide whether to continue or stop
70 |                 # continue
71 |                 raise  # Re-raise to stop this worker
72 | 
73 |         elapsed_time = time.monotonic() - start_time
74 |         log.info(
75 |             f"{worker_name}: Finished {ops_done} operations in {elapsed_time:.4f} seconds."
76 |         )
77 |         return elapsed_time, ops_done
78 | 
79 |     except ConnectionRefusedError:
80 |         log.error(
81 |             f"{worker_name}: Connection refused. Is the miniredis server running on {REDIS_HOST}:{REDIS_PORT}?"
82 |         )
83 |         return float("inf"), ops_done  # Indicate failure
84 |     except Exception as e:
85 |         log.error(f"{worker_name}: Unhandled exception in worker: {e}")
86 |         return float("inf"), ops_done  # Indicate failure
87 |     finally:
88 |         if c:
89 |             try:
90 |                 c.close()
91 |                 log.info(f"{worker_name}: Connection closed.")
92 |             except Exception as e:
93 |                 log.error(f"{worker_name}: Error closing connection: {e}")
94 | 
95 | if __name__ == "__main__":
96 |     logging.basicConfig(
97 |         level=logging.INFO,
98 |         format="%(asctime)s - %(processName)s - %(levelname)s - %(message)s",
99 |     )
100 | 
101 |     num_workers = 4
102 |     # Total operations roughly split among workers
103 |     # Let's aim for ~100k total ops (SETs + some GETs)
104 |     total_target_ops = 100000
105 |     # Estimate ops per worker (mostly SETs, some GETs)
106 |     # Each loop does 1 SET and sometimes 1 GET (avg 1.2 ops/loop)
107 |     loops_per_worker = int(total_target_ops / num_workers / 1.2)
108 | 
109 |     worker_args = [
110 |         (loops_per_worker, i) for i in range(num_workers)
111 |     ]  # Assign DB index i to worker i
112 | 
113 |     log.info(
114 |         f"Starting benchmark with {num_workers} workers, approx {loops_per_worker} loops each."
115 |     )
116 | 
117 |     total_time = 0.0
118 |     total_ops_completed = 0
119 |     successful_workers = 0
120 | 
121 |     # Use try-except around the Pool to catch potential setup issues
122 |     try:
123 |         with Pool(num_workers) as p:
124 |             results = p.map(timed_worker, worker_args)
125 | 
126 |         for elapsed, ops_count in results:
127 |             if elapsed != float("inf"):
128 |                 total_time += elapsed
129 |                 total_ops_completed += ops_count
130 |                 successful_workers += 1
131 |             else:
132 |                 log.warning("A worker failed to complete.")
133 | 
134 |     except Exception as e:
135 |         log.error(f"Error during multiprocessing pool execution: {e}")
136 |         # Exit or handle as appropriate
137 |         sys.exit(1)
138 | 
139 |     log.info("Benchmark finished.")
140 | 
141 |     if successful_workers > 0 and total_time > 0:
142 |         # Calculate ops/sec based on the sum of time spent by successful workers
143 |         # and the total operations they completed.
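        # Note that total_time is the SUM of per-worker elapsed times: dividing by it
        # gives the mean per-worker rate, not wall-clock aggregate throughput (with N
        # workers running in parallel, the wall-clock rate is roughly N times higher).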
144 |         ops_sec = total_ops_completed / total_time
145 |         print(f"\n--- Benchmark Summary ---")
146 |         print(f"Successful Workers: {successful_workers}/{num_workers}")
147 |         print(f"Total Operations Completed: {total_ops_completed}")
148 |         print(f"Total Worker Time (summed): {total_time:.4f} seconds")
149 |         print(f"Mean Per-Worker Operations/Second: {ops_sec:.2f}")
150 |         print(f"(Note: This is the mean per-worker rate, not wall-clock aggregate throughput)")
151 |     elif successful_workers == 0:
152 |         print("\n--- Benchmark Failed ---")
153 |         print(
154 |             "All workers encountered errors. Please check logs and ensure the server is running."
155 |         )
156 |     else:  # total_time is 0 or total_ops_completed is 0
157 |         print("\n--- Benchmark Result ---")
158 |         print("No time elapsed or no operations completed by successful workers.")
159 | 
--------------------------------------------------------------------------------
/benchmark_server.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | """
4 | First modified by Rui Carmo on 2013-03-12
5 | Published under the MIT license.
6 | """
7 | 
8 | import os, sys, logging, signal
9 | from miniredis.server import RedisServer
10 | 
11 | # Configure logging
12 | logging.basicConfig(
13 |     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
14 | )
15 | log = logging.getLogger(__name__)
16 | 
17 | # Global server instance for signal handling
18 | server_instance = None
19 | 
20 | 
21 | def shutdown_handler(signum, frame):
22 |     """Gracefully shut down the server on SIGINT or SIGTERM."""
23 |     log.info(f"Received signal {signal.Signals(signum).name}. Shutting down...")
24 |     if server_instance:
25 |         server_instance.stop()
26 |     # The run loop should exit after stop() sets halt=True
27 |     # sys.exit(0) # Avoid exiting here, let the main loop finish
28 | 
29 | 
30 | def main():
31 |     global server_instance
32 |     # Register signal handlers for graceful shutdown
33 |     signal.signal(signal.SIGINT, shutdown_handler)
34 |     signal.signal(signal.SIGTERM, shutdown_handler)
35 | 
36 |     # TODO: Add argument parsing for host, port, db_path, etc.
37 |     host = "127.0.0.1"
38 |     port = 6379
39 |     db_path = "."
40 | 
41 |     log.info(f"Starting miniredis server on {host}:{port}")
42 |     log.info(f"Database path: {os.path.abspath(db_path)}")
43 |     server_instance = RedisServer(host=host, port=port, db_path=db_path)
44 | 
45 |     try:
46 |         server_instance.run()
47 |     except Exception as e:
48 |         log.exception("An unexpected error occurred in the server run loop")
49 |     finally:
50 |         log.info("Server has stopped.")
51 |         # stop() should handle saving, but ensure it's called if not via signal
52 |         if server_instance and not server_instance.halt:
53 |             log.info("Performing final shutdown sequence.")
54 |             server_instance.stop()
55 | 
56 |     sys.exit(0)
57 | 
58 | 
59 | if __name__ == "__main__":
60 |     main()
61 | 
--------------------------------------------------------------------------------
/miniredis/__init__.py:
--------------------------------------------------------------------------------
1 | # placeholder
2 | 
--------------------------------------------------------------------------------
/miniredis/aioserver.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | """
4 | Asyncio-based Redis server implementation for miniredis.
5 | 
6 | Based on the original synchronous server.py.
7 | """ 8 | 9 | import asyncio 10 | import logging 11 | import time 12 | import signal 13 | import re 14 | from collections import deque 15 | from dataclasses import dataclass, field 16 | from pathlib import Path 17 | from typing import Any, Dict, List, Optional, Tuple, Union, Set 18 | 19 | # Import Haystack and SortedSet 20 | from .haystack import Haystack 21 | from .sset import SortedSet 22 | 23 | log = logging.getLogger() 24 | 25 | # --- Simplified Error/Message Classes (similar to server.py) --- 26 | 27 | class RedisError(Exception): 28 | def __init__(self, message: str): 29 | self.message = message 30 | super().__init__(message) 31 | 32 | def __str__(self) -> str: 33 | return f"-ERR {self.message}\r\n" 34 | 35 | class RedisMessage: 36 | def __init__(self, message: str): 37 | self.message = message 38 | 39 | def __str__(self) -> str: 40 | return f"+{self.message}\r\n" 41 | 42 | # --- Client Connection State --- 43 | @dataclass 44 | class AsyncRedisConnection: 45 | reader: asyncio.StreamReader 46 | writer: asyncio.StreamWriter 47 | db: int = 0 48 | 49 | # --- Redis Constants for Common Values --- 50 | class RedisConstant: 51 | def __init__(self, type: str) -> None: 52 | self.type = type 53 | 54 | def __len__(self) -> int: 55 | return 0 56 | 57 | def __repr__(self) -> str: 58 | return f"" 59 | 60 | EMPTY_SCALAR = RedisConstant("EmptyScalar") 61 | EMPTY_LIST = RedisConstant("EmptyList") 62 | BAD_VALUE = RedisError("Operation against a key holding the wrong kind of value") 63 | 64 | # --- Async Server Implementation --- 65 | 66 | class AsyncRedisServer: 67 | def __init__(self, host: str = "127.0.0.1", port: int = 6379, db_path: str = ".") -> None: 68 | self.host = host 69 | self.port = port 70 | # Use tables for multi-db support 71 | self.tables: Dict[int, Dict[str, Any]] = {} 72 | self._server: Optional[asyncio.AbstractServer] = None 73 | self._tasks: set[asyncio.Task] = set() 74 | self.path = Path(db_path) 75 | self.meta = Haystack(self.path, "redisdb_async") # Use different filename 76 | # Expiry management 77 | self.timeouts: Dict[str, float] = {} 78 | # Keep track of last save 79 | self.lastsave = int(time.time()) 80 | # Channels for PubSub 81 | self.channels: Dict[str, List[AsyncRedisConnection]] = {} 82 | # Track client connections by peername 83 | self.client_connections: Dict[str, AsyncRedisConnection] = {} 84 | # Load initial data 85 | self._load_data() 86 | log.info(f"AsyncRedisServer initialized for {host}:{port}, DB path: {self.path}") 87 | 88 | def _load_data(self) -> None: 89 | """Loads data from Haystack storage.""" 90 | try: 91 | # Load timeouts 92 | self.timeouts = self.meta.get('timeouts', {}) 93 | 94 | # Load tables for each DB found in meta 95 | db_keys = [k for k in self.meta.keys() if k.startswith('db_')] 96 | for db_key in db_keys: 97 | try: 98 | db_num = int(db_key.split('_')[1]) 99 | self.tables[db_num] = self.meta.get(db_key, {}) 100 | log.info(f"Loaded data for DB {db_num}") 101 | except (ValueError, IndexError): 102 | log.warning(f"Could not parse DB number from key: {db_key}") 103 | 104 | # Ensure DB 0 exists if no other DBs were loaded 105 | if 0 not in self.tables: 106 | self.tables[0] = {} 107 | 108 | log.info("Data loading complete.") 109 | except Exception as e: 110 | log.exception(f"Error loading data from {self.path}: {e}") 111 | # Ensure DB 0 exists even if loading fails 112 | if 0 not in self.tables: 113 | self.tables[0] = {} 114 | 115 | async def save_data(self) -> None: 116 | """Saves current data to Haystack storage.""" 117 | log.info("Saving 
data...") 118 | try: 119 | # Save timeouts 120 | self.meta['timeouts'] = self.timeouts 121 | 122 | # Save each DB table 123 | for db_num, table in self.tables.items(): 124 | self.meta[f'db_{db_num}'] = table 125 | 126 | await asyncio.to_thread(self.meta.commit) # Run sync commit in thread 127 | self.lastsave = int(time.time()) 128 | log.info("Data saved successfully.") 129 | except Exception as e: 130 | log.exception(f"Error saving data: {e}") 131 | 132 | async def check_ttl(self, db_num: int, key: str) -> bool: 133 | """Check if a key has expired. Returns True if key exists and is valid.""" 134 | k = f"{db_num} {key}" 135 | if k in self.timeouts: 136 | if self.timeouts[k] <= time.time(): 137 | # Key has expired - remove it 138 | if key in self.tables[db_num]: 139 | del self.tables[db_num][key] 140 | del self.timeouts[k] 141 | return False 142 | return key in self.tables[db_num] 143 | 144 | async def _encode_response(self, value: Any) -> bytes: 145 | """Encodes a Python value into the Redis protocol response.""" 146 | if isinstance(value, bytes): 147 | return f"${len(value)}\r\n".encode() + value + b"\r\n" 148 | elif isinstance(value, str): 149 | # Special handling for "OK" responses - use simple string format 150 | if value == "OK": 151 | return b"+OK\r\n" 152 | encoded_value = value.encode() 153 | return f"${len(encoded_value)}\r\n".encode() + encoded_value + b"\r\n" 154 | elif isinstance(value, int): 155 | return f":{value}\r\n".encode() 156 | elif isinstance(value, RedisError): 157 | return str(value).encode() 158 | elif isinstance(value, RedisMessage): 159 | return str(value).encode() 160 | elif value is None: 161 | return b"$-1\r\n" # Null Bulk String 162 | elif isinstance(value, list): 163 | encoded_items = b"".join([await self._encode_response(item) for item in value]) 164 | return f"*{len(value)}\r\n".encode() + encoded_items 165 | elif isinstance(value, bool): # Convert bool to integer (1=True, 0=False) for Redis protocol 166 | return f":{1 if value else 0}\r\n".encode() 167 | else: 168 | # Fallback for unknown types 169 | log.warning(f"Encoding unknown type: {type(value)}") 170 | return str(RedisError(f"Cannot encode type {type(value).__name__}")).encode() 171 | 172 | async def handle_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: 173 | """Handles a single client connection.""" 174 | peername = writer.get_extra_info('peername') 175 | connection = AsyncRedisConnection(reader=reader, writer=writer) 176 | connection_key = str(peername) if peername else f"anon-{id(connection)}" 177 | 178 | log.info(f"Client connected: {connection_key} (DB {connection.db})") 179 | task = asyncio.current_task() 180 | self._tasks.add(task) 181 | task.add_done_callback(self._tasks.discard) 182 | 183 | # Store connection in the task for command handlers to access 184 | setattr(task, 'connection', connection) 185 | 186 | # Track the client connection with peername as key 187 | self.client_connections[connection_key] = connection 188 | log.info(f"Added client connection {connection_key} to tracking") 189 | 190 | # Ensure the client's selected DB exists 191 | if connection.db not in self.tables: 192 | self.tables[connection.db] = {} 193 | 194 | try: 195 | while True: 196 | try: 197 | # 1. 
197 |                     # 1. Read the command type and count
198 |                     line = await reader.readline()
199 |                     if not line or line == b'':  # Connection closed
200 |                         log.info(f"Client disconnected: {connection_key}")
201 |                         break
202 |                     if not line.startswith(b'*'):
203 |                         writer.write(str(RedisError("Protocol error: expected array")).encode())
204 |                         await writer.drain()
205 |                         continue
206 | 
207 |                     try:
208 |                         item_count = int(line[1:].strip())
209 |                     except ValueError:
210 |                         writer.write(str(RedisError("Protocol error: invalid array length")).encode())
211 |                         await writer.drain()
212 |                         continue
213 | 
214 |                     # 2. Read command arguments
215 |                     args: List[bytes] = []
216 |                     for _ in range(item_count):
217 |                         # Read bulk string length
218 |                         len_line = await reader.readline()
219 |                         if not len_line or not len_line.startswith(b'$'):
220 |                             raise RedisError("Protocol error: expected bulk string length")
221 |                         try:
222 |                             length = int(len_line[1:].strip())
223 |                         except ValueError:
224 |                             raise RedisError("Protocol error: invalid bulk string length")
225 | 
226 |                         if length == -1:
227 |                             args.append(b"")  # Represent null bulk string as empty bytes for simplicity here
228 |                         else:
229 |                             # Read bulk string data + CRLF
230 |                             data = await reader.readexactly(length + 2)
231 |                             if data[-2:] != b'\r\n':
232 |                                 raise RedisError("Protocol error: expected CRLF after bulk string")
233 |                             args.append(data[:-2])
234 | 
235 |                     if not args:
236 |                         continue  # Should not happen if item_count > 0
237 | 
238 |                     # 3. Decode command and dispatch
239 |                     command = args[0].decode().lower()
240 |                     decoded_args = [arg.decode() for arg in args[1:]]  # Decode remaining args
241 | 
242 |                     log.debug(f"Client {connection_key} executing command: {command} {decoded_args}")
243 | 
244 |                     handler_name = f"handle_{command}"
245 |                     response: Any
246 |                     if hasattr(self, handler_name):
247 |                         handler = getattr(self, handler_name)
248 |                         # Get current table based on connection's selected DB
249 |                         if connection.db not in self.tables:
250 |                             self.tables[connection.db] = {}
251 |                         current_table = self.tables[connection.db]
252 | 
253 |                         # Pass the current table to commands that operate directly on data
254 |                         if command in ('set', 'get'):
255 |                             response = await handler(current_table, *decoded_args)
256 |                         elif command == 'select':
257 |                             response = await handler(connection, *decoded_args)
258 |                             # Update current_table if SELECT was successful
259 |                             if not isinstance(response, RedisError):
260 |                                 if connection.db not in self.tables:
261 |                                     self.tables[connection.db] = {}
262 |                                 log.info(f"Client {connection_key} switched to DB {connection.db}")
263 |                         elif command in ('subscribe', 'unsubscribe', 'psubscribe', 'punsubscribe'):
264 |                             await handler(connection, *decoded_args)
265 |                             continue
266 |                         else:
267 |                             # For all other commands, temporarily ensure the task has the connection
268 |                             # This creates a consistent context for all command handlers
269 |                             old_connection = getattr(task, 'connection', None)
270 |                             setattr(task, 'connection', connection)
271 |                             try:
272 |                                 response = await handler(*decoded_args)
273 |                             finally:
274 |                                 # Restore the original connection if there was one
275 |                                 if old_connection:
276 |                                     setattr(task, 'connection', old_connection)
277 |                                 else:
278 |                                     setattr(task, 'connection', connection)
279 |                     else:
280 |                         response = RedisError(f"unknown command '{command}'")
281 | 
282 |                     # 4. Send response
283 |                     encoded_response = await self._encode_response(response)
284 |                     writer.write(encoded_response)
285 |                     await writer.drain()
286 | 
287 |                     # Special case for QUIT
288 |                     if command == 'quit':
289 |                         log.info(f"Client requested QUIT: {connection_key}")
290 |                         break
291 | 
292 |                 except (asyncio.IncompleteReadError, ConnectionResetError, BrokenPipeError) as e:
293 |                     log.info(f"Connection error with {connection_key}: {e}")
294 |                     break
295 |                 except RedisError as e:
296 |                     log.warning(f"Redis error for {connection_key}: {e.message}")
297 |                     writer.write(str(e).encode())
298 |                     await writer.drain()
299 |                 except Exception as e:
300 |                     log.exception(f"Unexpected error handling client {connection_key}: {e}")
301 |                     try:
302 |                         writer.write(str(RedisError(f"Internal server error: {e}")).encode())
303 |                         await writer.drain()
304 |                     except (ConnectionResetError, BrokenPipeError):
305 |                         pass  # Client likely disconnected
306 |                     break  # Stop handling this client on unexpected errors
307 |         finally:
308 |             # Cleanup
309 |             try:
310 |                 writer.close()
311 |                 await writer.wait_closed()
312 |             except Exception as e:
313 |                 log.debug(f"Error during writer close for {connection_key}: {e}")
314 | 
315 |             log.info(f"Connection closed for {connection_key}")
316 |             # Remove connection from task when done
317 |             if hasattr(task, 'connection'):
318 |                 delattr(task, 'connection')
319 | 
320 |             # Remove the client connection from tracking
321 |             if connection_key in self.client_connections:
322 |                 log.info(f"Removing client connection {connection_key} from tracking")
323 |                 del self.client_connections[connection_key]
324 | 
325 |     # --- Command Handlers ---
326 | 
327 |     async def handle_ping(self, *args: str) -> RedisMessage:
328 |         log.debug("Handling PING")
329 |         if len(args) == 1:
330 |             return RedisMessage(args[0])  # Echo message
331 |         elif len(args) == 0:
332 |             return RedisMessage("PONG")
333 |         else:
334 |             return RedisError("wrong number of arguments for 'ping' command")
335 | 
336 |     async def handle_set(self, table: Dict[str, Any], key: str, value: str, *options: str) -> str:
337 |         # Find the DB number safely for logging
338 |         db_nums = [db for db, t in self.tables.items() if t is table]
339 |         db_num = db_nums[0] if db_nums else "unknown"
340 |         log.debug(f"Handling SET {key} = {value} in DB {db_num}")
341 | 
342 |         if not key or value is None:  # Basic validation
343 |             return RedisError("wrong number of arguments for 'set' command")
344 | 
345 |         # Remove any expiration when setting a key (Redis behavior)
346 |         if isinstance(db_num, int) and f"{db_num} {key}" in self.timeouts:
347 |             del self.timeouts[f"{db_num} {key}"]
348 | 
349 |         table[key] = value  # Use the passed table
350 |         return "OK"  # Returns "+OK" in Redis protocol
351 | 
352 |     async def handle_get(self, table: Dict[str, Any], key: str) -> Optional[str]:
353 |         # Find the DB number safely for logging
354 |         db_nums = [db for db, t in self.tables.items() if t is table]
355 |         db_num = db_nums[0] if db_nums else "unknown"
356 |         log.debug(f"Handling GET {key} in DB {db_num}")
357 | 
358 |         if not key:
359 |             return RedisError("wrong number of arguments for 'get' command")
360 |         return table.get(key)  # Use the passed table
361 | 
362 |     async def handle_select(self, connection: AsyncRedisConnection, db_index_str: str) -> str:
363 |         """Select the DB for the current connection."""
364 |         try:
365 |             db_index = int(db_index_str)
366 |             if db_index < 0:
367 |                 raise ValueError("DB index must be non-negative")
368 |         except ValueError:
369 |             return RedisError("invalid DB index")
370 | 
371 |         # Ensure the target DB exists in self.tables
372 |         if db_index not in self.tables:
373 |             self.tables[db_index] = {}
374 | 
375 |         connection.db = db_index
376 |         # The current_table in handle_client will be updated after this returns
377 |         return "OK"  # Encoded as +OK, as Redis does for SELECT
378 | 
379 |     async def handle_save(self) -> str:
380 |         """Explicitly trigger data saving."""
381 |         await self.save_data()
382 |         return "OK"  # Encoded as +OK
383 | 
384 |     async def handle_quit(self) -> RedisMessage:
385 |         # Response is sent before connection is closed by the handler loop
386 |         return RedisMessage("OK")
387 | 
388 |     # --- Server Management Commands ---
389 | 
390 |     async def handle_flushdb(self) -> int:
391 |         """Remove all keys from the current DB."""
392 |         task = asyncio.current_task()  # Use the connection's DB if set, else fall back to all DBs
393 |         if task and hasattr(task, 'connection'):
394 |             db_nums = [getattr(task, 'connection').db]
395 |         else:
396 |             db_nums = list(self.tables.keys())
397 | 
398 |         for db_num in db_nums:
399 |             # Clear the DB
400 |             self.tables[db_num] = {}
401 | 
402 |             # Remove any timeouts for this DB
403 |             for timeout_key in list(self.timeouts.keys()):
404 |                 if timeout_key.startswith(f"{db_num} "):
405 |                     del self.timeouts[timeout_key]
406 | 
407 |         return 1  # Return integer 1 for Redis protocol OK
408 | 
409 |     async def handle_flushall(self) -> int:
410 |         """Remove all keys from all DBs."""
411 |         # Clear all DBs
412 |         for db_num in list(self.tables.keys()):
413 |             self.tables[db_num] = {}
414 | 
415 |         # Clear all timeouts
416 |         self.timeouts = {}
417 | 
418 |         return 1  # Return integer 1 for Redis protocol OK
419 | 
420 |     # --- PubSub Commands ---
421 | 
422 |     async def handle_publish(self, channel: str, message: str) -> int:
423 |         """Publish a message to a channel"""
424 |         if not channel or message is None:
425 |             return RedisError("wrong number of arguments for 'publish' command")
426 | 
427 |         published_count = 0
428 | 
429 |         # Check for exact channel matches
430 |         if channel in self.channels:
431 |             for connection in self.channels[channel]:
432 |                 try:
433 |                     # Format message in Redis protocol
434 |                     # *3\r\n$7\r\nmessage\r\n${len(channel)}\r\n{channel}\r\n${len(message)}\r\n{message}\r\n
435 |                     msg = [
436 |                         "message",
437 |                         channel,
438 |                         message
439 |                     ]
440 |                     encoded = await self._encode_response(msg)
441 |                     connection.writer.write(encoded)
442 |                     await connection.writer.drain()
443 |                     published_count += 1
444 |                 except (ConnectionError, BrokenPipeError, asyncio.CancelledError) as e:
445 |                     log.warning(f"Error publishing to client: {e}")
446 |                     # Will remove broken connections during next subscription
447 | 
448 |         # Check for pattern matches
449 |         for pattern, connections in self.channels.items():
450 |             # Skip exact channels we already processed
451 |             if pattern == channel:
452 |                 continue
453 | 
454 |             # Check if pattern matches
455 |             try:
456 |                 if '*' in pattern or '?' in pattern:
457 |                     # Convert Redis glob pattern to regex
458 |                     regex_pattern = pattern.replace('*', '.*').replace('?', '.')
459 |                     if re.match(f"^{regex_pattern}$", channel):
460 |                         for connection in connections:
461 |                             try:
462 |                                 # Format pmessage in Redis protocol
463 |                                 msg = [
464 |                                     "pmessage",
465 |                                     pattern,
466 |                                     channel,
467 |                                     message
468 |                                 ]
469 |                                 encoded = await self._encode_response(msg)
470 |                                 connection.writer.write(encoded)
471 |                                 await connection.writer.drain()
472 |                                 published_count += 1
473 |                             except (ConnectionError, BrokenPipeError, asyncio.CancelledError) as e:
474 |                                 log.warning(f"Error publishing to pattern subscriber: {e}")
475 |             except re.error:
476 |                 # Skip invalid patterns
477 |                 continue
478 | 
479 |         return published_count
480 | 
481 |     async def handle_subscribe(self, connection: AsyncRedisConnection, *channels: str) -> None:
482 |         """Subscribe to channels"""
483 |         if not channels:
484 |             return RedisError("wrong number of arguments for 'subscribe' command")
485 | 
486 |         # Subscription count
487 |         count = 0
488 | 
489 |         # For each channel
490 |         for channel in channels:
491 |             # Create channel list if it doesn't exist
492 |             if channel not in self.channels:
493 |                 self.channels[channel] = []
494 | 
495 |             # Add connection to channel subscribers if not already there
496 |             if connection not in self.channels[channel]:
497 |                 self.channels[channel].append(connection)
498 | 
499 |             # Send subscription confirmation to client
500 |             # Format: *3\r\n$9\r\nsubscribe\r\n${len(channel)}\r\n{channel}\r\n:{count}\r\n
501 |             try:
502 |                 count += 1
503 |                 msg = [
504 |                     "subscribe",
505 |                     channel,
506 |                     count
507 |                 ]
508 |                 encoded = await self._encode_response(msg)
509 |                 connection.writer.write(encoded)
510 |                 await connection.writer.drain()
511 |             except (ConnectionError, BrokenPipeError, asyncio.CancelledError) as e:
512 |                 log.warning(f"Error sending subscribe confirmation: {e}")
513 | 
514 |         # Note: For true Redis behavior, connections in subscription mode
515 |         # should only accept subscription-related commands until unsubscribed.
516 |         # This would require modifying the handle_client method.
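        # (One way to do that, as a sketch: track a per-connection set of active
        # subscriptions and, while it is non-empty, have the dispatch loop reject
        # every command other than [P]SUBSCRIBE/[P]UNSUBSCRIBE/PING/QUIT.)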
517 |         return None
518 | 
519 |     async def handle_unsubscribe(self, connection: AsyncRedisConnection, *channels: str) -> None:
520 |         """Unsubscribe from channels"""
521 |         # If no channels specified, unsubscribe from all
522 |         if not channels:
523 |             channels_to_check = list(self.channels.keys())
524 |         else:
525 |             channels_to_check = channels
526 | 
527 |         count = 0
528 | 
529 |         for channel in channels_to_check:
530 |             if channel in self.channels and connection in self.channels[channel]:
531 |                 # Remove connection from channel subscribers
532 |                 self.channels[channel].remove(connection)
533 | 
534 |                 # If channel has no subscribers, remove it
535 |                 if not self.channels[channel]:
536 |                     del self.channels[channel]
537 | 
538 |                 # Send unsubscription confirmation to client
539 |                 try:
540 |                     count += 1
541 |                     msg = [
542 |                         "unsubscribe",
543 |                         channel,
544 |                         count
545 |                     ]
546 |                     encoded = await self._encode_response(msg)
547 |                     connection.writer.write(encoded)
548 |                     await connection.writer.drain()
549 |                 except (ConnectionError, BrokenPipeError, asyncio.CancelledError) as e:
550 |                     log.warning(f"Error sending unsubscribe confirmation: {e}")
551 | 
552 |         return None
553 | 
554 |     async def handle_psubscribe(self, connection: AsyncRedisConnection, *patterns: str) -> None:
555 |         """Subscribe to channels matching patterns"""
556 |         if not patterns:
557 |             return RedisError("wrong number of arguments for 'psubscribe' command")
558 | 
559 |         count = 0
560 | 
561 |         for pattern in patterns:
562 |             # Create pattern list if it doesn't exist
563 |             if pattern not in self.channels:
564 |                 self.channels[pattern] = []
565 | 
566 |             # Add connection to pattern subscribers if not already there
567 |             if connection not in self.channels[pattern]:
568 |                 self.channels[pattern].append(connection)
569 | 
570 |             # Send subscription confirmation to client
571 |             try:
572 |                 count += 1
573 |                 msg = [
574 |                     "psubscribe",
575 |                     pattern,
576 |                     count
577 |                 ]
578 |                 encoded = await self._encode_response(msg)
579 |                 connection.writer.write(encoded)
580 |                 await connection.writer.drain()
581 |             except (ConnectionError, BrokenPipeError, asyncio.CancelledError) as e:
582 |                 log.warning(f"Error sending psubscribe confirmation: {e}")
583 | 
584 |         return None
585 | 
586 |     async def handle_punsubscribe(self, connection: AsyncRedisConnection, *patterns: str) -> None:
587 |         """Unsubscribe from channels matching patterns"""
588 |         if not patterns:
589 |             # Unsubscribe from all patterns
590 |             # In a real implementation, we would need to differentiate between patterns and channels
591 |             # For this simple implementation, we assume patterns contain * or ?
592 |             patterns_to_check = [p for p in self.channels.keys() if '*' in p or '?' in p]
593 |         else:
594 |             patterns_to_check = patterns
595 | 
596 |         count = 0
597 | 
598 |         for pattern in patterns_to_check:
599 |             if pattern in self.channels and connection in self.channels[pattern]:
600 |                 # Remove connection from pattern subscribers
601 |                 self.channels[pattern].remove(connection)
602 | 
603 |                 # If pattern has no subscribers, remove it
604 |                 if not self.channels[pattern]:
605 |                     del self.channels[pattern]
606 | 
607 |                 # Send unsubscription confirmation to client
608 |                 try:
609 |                     count += 1
610 |                     msg = [
611 |                         "punsubscribe",
612 |                         pattern,
613 |                         count
614 |                     ]
615 |                     encoded = await self._encode_response(msg)
616 |                     connection.writer.write(encoded)
617 |                     await connection.writer.drain()
618 |                 except (ConnectionError, BrokenPipeError, asyncio.CancelledError) as e:
619 |                     log.warning(f"Error sending punsubscribe confirmation: {e}")
620 | 
621 |         return None
622 | 
623 |     # --- Redis Key Commands ---
624 | 
625 |     async def handle_del(self, *args: str) -> int:
626 |         """Delete one or more keys, returns the number of keys removed"""
627 |         if not args:
628 |             return RedisError("wrong number of arguments for 'del' command")
629 | 
630 |         # Get the correct DB from the current task's connection
631 |         task = asyncio.current_task()
632 |         if task and hasattr(task, 'connection'):
633 |             connection = getattr(task, 'connection')
634 |             db_num = connection.db
635 |         else:
636 |             db_num = 0  # Default to DB 0
637 | 
638 |         count = 0
639 |         for key in args:
640 |             # Check if key exists and remove timeouts
641 |             timeout_key = f"{db_num} {key}"
642 |             if timeout_key in self.timeouts:
643 |                 del self.timeouts[timeout_key]
644 | 
645 |             # Delete the key from DB
646 |             if key in self.tables[db_num]:
647 |                 del self.tables[db_num][key]
648 |                 count += 1
649 | 
650 |         return count
651 | 
652 |     async def handle_exists(self, *keys: str) -> int:
653 |         """Check if one or more keys exist"""
654 |         if not keys:
655 |             return RedisError("wrong number of arguments for 'exists' command")
656 | 
657 |         # Get the correct DB from the current task's connection
658 |         task = asyncio.current_task()
659 |         if task and hasattr(task, 'connection'):
660 |             connection = getattr(task, 'connection')
661 |             db_num = connection.db
662 |         else:
663 |             db_num = 0  # Default to DB 0 if no connection context
664 | 
665 |         count = 0
666 |         for key in keys:
667 |             # Check if key exists in the current DB
668 |             if key in self.tables[db_num]:
669 |                 # Check TTL - skip if expired
670 |                 if await self.check_ttl(db_num, key):
671 |                     count += 1
672 | 
673 |         return count
674 | 
675 |     async def handle_expire(self, key: str, seconds: str) -> int:
676 |         """Set a key's time to live in seconds"""
677 |         if not key or not seconds:
678 |             return RedisError("wrong number of arguments for 'expire' command")
679 | 
680 |         try:
681 |             ttl = int(seconds)
682 |         except ValueError:
683 |             return RedisError("value is not an integer or out of range")
684 | 
685 |         # Get the correct DB from the current task's connection
686 |         task = asyncio.current_task()
687 |         if task and hasattr(task, 'connection'):
688 |             connection = getattr(task, 'connection')
689 |             db_num = connection.db
690 | 
691 |             # Check if key exists in this specific db
692 |             if key in self.tables[db_num]:
693 |                 self.timeouts[f"{db_num} {key}"] = time.time() + ttl
694 |                 return 1
695 |             return 0  # Key doesn't exist
696 |         else:
697 |             # Fallback to checking all DBs
698 |             db_nums = list(self.tables.keys())
699 |             for db_num in db_nums:
700 |                 if key in self.tables[db_num]:
701 |                     self.timeouts[f"{db_num} {key}"] = time.time() + ttl
702 |                     return 1
703 |             return 0
704 | 
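    # Expiry bookkeeping, for reference: EXPIRE/EXPIREAT/TTL all key self.timeouts
    # by the string f"{db_num} {key}", so EXPIRE foo 10 issued against DB 2 stores
    # self.timeouts["2 foo"] = time.time() + 10; check_ttl() later compares that
    # deadline with time.time() and deletes both the key and its timeout entry
    # once the deadline has passed.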
705 |     async def handle_expireat(self, key: str, timestamp: str) -> int:
706 |         """Set the expiration for a key at a UNIX timestamp"""
707 |         if not key or not timestamp:
708 |             return RedisError("wrong number of arguments for 'expireat' command")
709 | 
710 |         try:
711 |             ts = int(timestamp)
712 |         except ValueError:
713 |             return RedisError("value is not an integer or out of range")
714 | 
715 |         # Get the correct DB from the current task's connection
716 |         task = asyncio.current_task()
717 |         if task and hasattr(task, 'connection'):
718 |             connection = getattr(task, 'connection')
719 |             db_nums = [connection.db]
720 |         else:
721 |             db_nums = list(self.tables.keys())
722 | 
723 |         for db_num in db_nums:
724 |             if key in self.tables[db_num]:
725 |                 self.timeouts[f"{db_num} {key}"] = ts
726 |                 return 1
727 | 
728 |         return 0  # Key does not exist
729 | 
730 |     async def handle_ttl(self, key: str) -> int:
731 |         """Get the time to live for a key in seconds"""
732 |         if not key:
733 |             return RedisError("wrong number of arguments for 'ttl' command")
734 | 
735 |         # Get the correct DB from the current task's connection
736 |         task = asyncio.current_task()
737 |         if task and hasattr(task, 'connection'):
738 |             connection = getattr(task, 'connection')
739 |             db_nums = [connection.db]
740 |         else:
741 |             db_nums = list(self.tables.keys())
742 | 
743 |         for db_num in db_nums:
744 |             # Check if the key exists
745 |             if not await self.check_ttl(db_num, key):
746 |                 continue
747 | 
748 |             if key not in self.tables[db_num]:
749 |                 continue
750 | 
751 |             timeout_key = f"{db_num} {key}"
752 |             if timeout_key in self.timeouts:
753 |                 ttl = int(self.timeouts[timeout_key] - time.time())
754 |                 return max(0, ttl)  # Return at least 0
755 |             else:
756 |                 return -1  # Key exists but has no TTL
757 | 
758 |         return -2  # Key does not exist
759 | 
760 |     async def handle_pttl(self, key: str) -> int:
761 |         """Get the time to live for a key in milliseconds"""
762 |         ttl = await self.handle_ttl(key)
763 |         if isinstance(ttl, RedisError) or ttl < 0:
764 |             return ttl  # Pass through error or special values
765 |         return ttl * 1000  # Convert to milliseconds
766 | 
767 |     async def handle_pexpire(self, key: str, milliseconds: str) -> int:
768 |         """Set a key's time to live in milliseconds"""
769 |         if not key or not milliseconds:
770 |             return RedisError("wrong number of arguments for 'pexpire' command")
771 | 
772 |         try:
773 |             ms = int(milliseconds)
774 |             seconds = ms // 1000  # handle_expire parses whole seconds; sub-second TTLs are truncated
775 |         except ValueError:
776 |             return RedisError("value is not an integer or out of range")
777 | 
778 |         return await self.handle_expire(key, str(seconds))
779 | 
780 |     async def handle_persist(self, key: str) -> int:
781 |         """Remove the expiration from a key"""
782 |         if not key:
783 |             return RedisError("wrong number of arguments for 'persist' command")
784 | 
785 |         task = asyncio.current_task()  # Use the connection's DB if set, else fall back to all DBs
786 |         if task and hasattr(task, 'connection'):
787 |             db_nums = [getattr(task, 'connection').db]
788 |         else:
789 |             db_nums = list(self.tables.keys())
790 | 
791 |         for db_num in db_nums:
792 |             if key in self.tables[db_num]:
793 |                 timeout_key = f"{db_num} {key}"
794 |                 if timeout_key in self.timeouts:
795 |                     del self.timeouts[timeout_key]
796 |                     return 1
797 | 
798 |         return 0
799 | 
800 |     async def handle_keys(self, pattern: str) -> List[bytes]:
801 |         """Find all keys matching the given pattern"""
802 |         if not pattern:
803 |             return RedisError("wrong number of arguments for 'keys' command")
804 | 
805 |         # Get the correct DB from the current task's connection
806 |         task = asyncio.current_task()
807 |         if task and hasattr(task, 'connection'):
808 |             connection = getattr(task, 'connection')
809 |             db_num = connection.db
810 |         else:
811 |             db_num = 0  # Default to DB 0 if no connection context
812 | 
813 |         matching_keys = []
814 | 
815 |         # Properly escape regex special characters in the pattern except * and ?
816 |         pattern_regex = re.escape(pattern).replace('\\*', '.*').replace('\\?', '.')
817 |         regex = re.compile(f"^{pattern_regex}$")
818 | 
819 |         # Look through each key in the specific DB only
820 |         if db_num in self.tables:
821 |             for key in list(self.tables[db_num].keys()):
822 |                 # Skip expired keys
823 |                 if not await self.check_ttl(db_num, key):
824 |                     continue
825 | 
826 |                 # Check if key matches pattern
827 |                 if regex.match(key):
828 |                     # Return key as bytes for Redis protocol compatibility
829 |                     matching_keys.append(key.encode())
830 | 
831 |         return matching_keys
832 | 
833 |     async def handle_type(self, key: str) -> RedisMessage:
834 |         """Determine the type stored at key"""
835 |         if not key:
836 |             return RedisError("wrong number of arguments for 'type' command")
837 | 
838 |         task = asyncio.current_task()  # Use the connection's DB if set, else fall back to all DBs
839 |         if task and hasattr(task, 'connection'):
840 |             db_nums = [getattr(task, 'connection').db]
841 |         else:
842 |             db_nums = list(self.tables.keys())
843 | 
844 |         for db_num in db_nums:
845 |             if not await self.check_ttl(db_num, key):
846 |                 continue
847 | 
848 |             if key not in self.tables[db_num]:
849 |                 continue
850 | 
851 |             data = self.tables[db_num][key]
852 | 
853 |             if isinstance(data, str):
854 |                 return RedisMessage("string")
855 |             elif isinstance(data, dict):
856 |                 return RedisMessage("hash")
857 |             elif isinstance(data, deque):
858 |                 return RedisMessage("list")
859 |             elif isinstance(data, set):
860 |                 return RedisMessage("set")
861 |             elif isinstance(data, SortedSet):
862 |                 return RedisMessage("zset")
863 |             else:
864 |                 return RedisMessage("unknown")
865 | 
866 |         return RedisMessage("none")
867 | 
868 |     # --- String Commands ---
869 | 
870 |     async def handle_append(self, key: str, value: str) -> int:
871 |         """Append a value to a key"""
872 |         if not key or not value:
873 |             return RedisError("wrong number of arguments for 'append' command")
874 | 
875 |         task = asyncio.current_task()  # Use the connection's DB if set, else fall back to all DBs
876 |         if task and hasattr(task, 'connection'):
877 |             db_nums = [getattr(task, 'connection').db]
878 |         else:
879 |             db_nums = list(self.tables.keys())
880 | 
881 |         for db_num in db_nums:
882 |             # Check if the key exists
883 |             if await self.check_ttl(db_num, key):
884 |                 if key in self.tables[db_num]:
885 |                     data = self.tables[db_num][key]
886 |                     # Check if it's a string type
887 |                     if isinstance(data, str):
888 |                         self.tables[db_num][key] = data + value
889 |                         return len(self.tables[db_num][key])
890 |                     else:
891 |                         return BAD_VALUE
892 | 
893 |             # If key doesn't exist, create it
894 |             self.tables[db_num][key] = value
895 |             return len(value)
896 | 
897 |         # Should not reach here if at least one DB exists
898 |         return 0
899 | 
900 |     async def handle_incr(self, key: str) -> int:
901 |         """Increment the integer value of a key by one"""
902 |         return await self.handle_incrby(key, "1")
903 | 
904 |     async def handle_decr(self, key: str) -> int:
905 |         """Decrement the integer value of a key by one"""
906 |         return await self.handle_incrby(key, "-1")
907 | 
908 |     async def handle_incrby(self, key: str, increment: str) -> int:
909 |         """Increment the integer value of a key by the given amount"""
910 |         if not key or not increment:
911 |             return RedisError("wrong number of arguments for 'incrby' command")
912 | 
913 |         try:
914 |             incr = int(increment)
915 |         except ValueError:
out of range") 917 | 918 | # Try all DBs if we don't have connection context 919 | if 'connection' in locals(): 920 | db_nums = [connection.db] 921 | else: 922 | db_nums = list(self.tables.keys()) 923 | 924 | for db_num in db_nums: 925 | # Check if key exists and is valid 926 | if await self.check_ttl(db_num, key): 927 | if key in self.tables[db_num]: 928 | current_value = self.tables[db_num][key] 929 | try: 930 | # Try to convert current value to int 931 | current_int = int(current_value) 932 | new_value = current_int + incr 933 | self.tables[db_num][key] = str(new_value) 934 | return new_value 935 | except (ValueError, TypeError): 936 | return RedisError("value is not an integer or out of range") 937 | 938 | # Key doesn't exist - create with value of increment 939 | self.tables[db_num][key] = str(incr) 940 | return incr 941 | 942 | # Should not reach here if at least one DB exists 943 | return 0 944 | 945 | async def handle_decrby(self, key: str, decrement: str) -> int: 946 | """Decrement the integer value of a key by the given amount""" 947 | try: 948 | decr = int(decrement) 949 | except ValueError: 950 | return RedisError("value is not an integer or out of range") 951 | # Use incrby with negated value 952 | return await self.handle_incrby(key, str(-decr)) 953 | 954 | async def handle_mget(self, *keys: str) -> List[Optional[str]]: 955 | """Get the values of all specified keys""" 956 | result = [] 957 | 958 | # Try all DBs if we don't have connection context 959 | if 'connection' in locals(): 960 | db_nums = [connection.db] 961 | else: 962 | db_nums = list(self.tables.keys()) 963 | 964 | for key in keys: 965 | found = False 966 | for db_num in db_nums: 967 | if await self.check_ttl(db_num, key): 968 | if key in self.tables[db_num]: 969 | data = self.tables[db_num][key] 970 | if isinstance(data, str): 971 | result.append(data) 972 | found = True 973 | break 974 | else: 975 | # Wrong type 976 | result.append(None) 977 | found = True 978 | break 979 | if not found: 980 | result.append(None) 981 | 982 | return result 983 | 984 | async def handle_getset(self, key: str, value: str) -> Optional[str]: 985 | """Set the string value of a key and return its old value""" 986 | if not key or value is None: 987 | return RedisError("wrong number of arguments for 'getset' command") 988 | 989 | # Try all DBs if we don't have connection context 990 | if 'connection' in locals(): 991 | db_nums = [connection.db] 992 | else: 993 | db_nums = list(self.tables.keys()) 994 | 995 | for db_num in db_nums: 996 | old_value = None 997 | if await self.check_ttl(db_num, key): 998 | if key in self.tables[db_num]: 999 | old_value = self.tables[db_num][key] 1000 | if not isinstance(old_value, str): 1001 | return BAD_VALUE 1002 | 1003 | # Set the new value 1004 | self.tables[db_num][key] = value 1005 | return old_value 1006 | 1007 | # Should not reach here if at least one DB exists 1008 | return None 1009 | 1010 | async def handle_setex(self, key: str, seconds: str, value: str) -> str: 1011 | """Set the value and expiration of a key""" 1012 | if not key or not seconds or value is None: 1013 | return RedisError("wrong number of arguments for 'setex' command") 1014 | 1015 | try: 1016 | ttl = int(seconds) 1017 | if ttl <= 0: 1018 | return RedisError("invalid expire time in 'setex' command") 1019 | except ValueError: 1020 | return RedisError("value is not an integer or out of range") 1021 | 1022 | # Try all DBs if we don't have connection context 1023 | task = asyncio.current_task() 1024 | if task and hasattr(task, 
1024 |         if task and hasattr(task, 'connection'):
1025 |             connection = getattr(task, 'connection')
1026 |             db_nums = [connection.db]
1027 |         else:
1028 |             db_nums = list(self.tables.keys())
1029 | 
1030 |         for db_num in db_nums:
1031 |             # Set the key
1032 |             self.tables[db_num][key] = value
1033 |             # Set expiration
1034 |             self.timeouts[f"{db_num} {key}"] = time.time() + ttl
1035 |             return "OK"  # Return "OK" instead of True
1036 | 
1037 |         # Should not reach here if at least one DB exists
1038 |         return "OK"  # Default to OK
1039 | 
1040 |     async def handle_setnx(self, key: str, value: str) -> int:
1041 |         """Set the value of a key, only if the key does not exist"""
1042 |         if not key or value is None:
1043 |             return RedisError("wrong number of arguments for 'setnx' command")
1044 | 
1045 |         task = asyncio.current_task()  # Use the connection's DB if set, else fall back to all DBs
1046 |         if task and hasattr(task, 'connection'):
1047 |             db_nums = [getattr(task, 'connection').db]
1048 |         else:
1049 |             db_nums = list(self.tables.keys())
1050 | 
1051 |         for db_num in db_nums:
1052 |             # Check if key exists
1053 |             if await self.check_ttl(db_num, key) and key in self.tables[db_num]:
1054 |                 return 0  # Key exists, do nothing
1055 | 
1056 |             # Set key if it doesn't exist
1057 |             self.tables[db_num][key] = value
1058 |             return 1
1059 | 
1060 |         # Should not reach here if at least one DB exists
1061 |         return 0
1062 | 
1063 |     # --- List Commands ---
1064 | 
1065 |     async def handle_lpush(self, key: str, value: str, *values: str) -> int:
1066 |         """Push one or more values to the head of a list"""
1067 |         if not key or value is None:
1068 |             return RedisError("wrong number of arguments for 'lpush' command")
1069 | 
1070 |         task = asyncio.current_task()  # Use the connection's DB if set, else fall back to all DBs
1071 |         if task and hasattr(task, 'connection'):
1072 |             db_nums = [getattr(task, 'connection').db]
1073 |         else:
1074 |             db_nums = list(self.tables.keys())
1075 | 
1076 |         for db_num in db_nums:
1077 |             # Check if key exists
1078 |             if await self.check_ttl(db_num, key) and key in self.tables[db_num]:
1079 |                 # Get the existing list
1080 |                 data = self.tables[db_num][key]
1081 |                 if not isinstance(data, deque):
1082 |                     return BAD_VALUE
1083 | 
1084 |                 # Add new values
1085 |                 data.appendleft(value)
1086 |                 for v in values:
1087 |                     data.appendleft(v)
1088 |                 return len(data)
1089 | 
1090 |             # Create a new list
1091 |             all_values = [value] + list(values)
1092 |             q = deque(all_values)
1093 |             self.tables[db_num][key] = q
1094 |             return len(q)
1095 | 
1096 |         # Should not reach here if at least one DB exists
1097 |         return 0
1098 | 
1099 |     async def handle_rpush(self, key: str, value: str, *values: str) -> int:
1100 |         """Push one or more values to the tail of a list"""
1101 |         if not key or value is None:
1102 |             return RedisError("wrong number of arguments for 'rpush' command")
1103 | 
1104 |         task = asyncio.current_task()  # Use the connection's DB if set, else fall back to all DBs
1105 |         if task and hasattr(task, 'connection'):
1106 |             db_nums = [getattr(task, 'connection').db]
1107 |         else:
1108 |             db_nums = list(self.tables.keys())
1109 | 
1110 |         for db_num in db_nums:
1111 |             # Check if key exists
1112 |             if await self.check_ttl(db_num, key) and key in self.tables[db_num]:
1113 |                 # Get the existing list
1114 |                 data = self.tables[db_num][key]
1115 |                 if not isinstance(data, deque):
1116 |                     return BAD_VALUE
1117 | 
1118 |                 # Add new values
1119 |                 data.append(value)
1120 |                 for v in values:
1121 |                     data.append(v)
1122 |                 return len(data)
1123 | 
1124 |             # Create a new list
1125 |             all_values = [value] + list(values)
1126 |             q = deque(all_values)
1127 |             self.tables[db_num][key] = q
1128 |             return len(q)
1129 | 
1130 |         # Should not reach here if at least one DB exists
1131 |         return 0
1132 | 
1133 |     async def handle_lpop(self, key: str) -> str:
1134 |         """Remove and return the first element of a list"""
1135 |         if not key:
1136 |             return RedisError("wrong number of arguments for 'lpop' command")
1137 | 
1138 |         # Use the task's connection context if available, otherwise try all DBs
1139 |         if (task := asyncio.current_task()) and hasattr(task, 'connection'):
1140 |             db_nums = [getattr(task, 'connection').db]
1141 |         else:
1142 |             db_nums = list(self.tables.keys())
1143 | 
1144 |         for db_num in db_nums:
1145 |             # Check if key exists
1146 |             if not await self.check_ttl(db_num, key) or key not in self.tables[db_num]:
1147 |                 continue
1148 | 
1149 |             data = self.tables[db_num][key]
1150 |             if not isinstance(data, deque):
1151 |                 return BAD_VALUE
1152 | 
1153 |             if len(data) == 0:
1154 |                 return None
1155 | 
1156 |             return data.popleft()
1157 | 
1158 |         return None  # Key does not exist
1159 | 
1160 |     async def handle_rpop(self, key: str) -> str:
1161 |         """Remove and return the last element of a list"""
1162 |         if not key:
1163 |             return RedisError("wrong number of arguments for 'rpop' command")
1164 | 
1165 |         # Use the task's connection context if available, otherwise try all DBs
1166 |         if (task := asyncio.current_task()) and hasattr(task, 'connection'):
1167 |             db_nums = [getattr(task, 'connection').db]
1168 |         else:
1169 |             db_nums = list(self.tables.keys())
1170 | 
1171 |         for db_num in db_nums:
1172 |             # Check if key exists
1173 |             if not await self.check_ttl(db_num, key) or key not in self.tables[db_num]:
1174 |                 continue
1175 | 
1176 |             data = self.tables[db_num][key]
1177 |             if not isinstance(data, deque):
1178 |                 return BAD_VALUE
1179 | 
1180 |             if len(data) == 0:
1181 |                 return None
1182 | 
1183 |             return data.pop()
1184 | 
1185 |         return None  # Key does not exist
1186 | 
1187 |     async def handle_llen(self, key: str) -> int:
1188 |         """Get the length of a list"""
1189 |         if not key:
1190 |             return RedisError("wrong number of arguments for 'llen' command")
1191 | 
1192 |         # Use the task's connection context if available, otherwise try all DBs
1193 |         if (task := asyncio.current_task()) and hasattr(task, 'connection'):
1194 |             db_nums = [getattr(task, 'connection').db]
1195 |         else:
1196 |             db_nums = list(self.tables.keys())
1197 | 
1198 |         for db_num in db_nums:
1199 |             # Check if key exists
1200 |             if not await self.check_ttl(db_num, key) or key not in self.tables[db_num]:
1201 |                 continue
1202 | 
1203 |             data = self.tables[db_num][key]
1204 |             if not isinstance(data, deque):
1205 |                 return BAD_VALUE
1206 | 
1207 |             return len(data)
1208 | 
1209 |         return 0  # Key does not exist
1210 | 
1211 |     async def handle_lrange(self, key: str, start: str, stop: str) -> list:
1212 |         """Get a range of elements from a list"""
1213 |         if not key or start is None or stop is None:
1214 |             return RedisError("wrong number of arguments for 'lrange' command")
1215 | 
1216 |         try:
1217 |             start_idx = int(start)
1218 |             stop_idx = int(stop)
1219 |         except ValueError:
1220 |             return RedisError("value is not an integer or out of range")
1221 | 
1222 |         # Use the task's connection context if available, otherwise try all DBs
1223 |         if (task := asyncio.current_task()) and hasattr(task, 'connection'):
1224 |             db_nums = [getattr(task, 'connection').db]
1225 |         else:
1226 |             db_nums = list(self.tables.keys())
1227 | 
1228 |         for db_num in db_nums:
1229 |             # Check if key exists
1230 |             if not await self.check_ttl(db_num, key) or key not in self.tables[db_num]:
1231 |                 continue
1232 | 
1233 |             data = self.tables[db_num][key]
1234 |             if not isinstance(data, deque):
1235 |                 return BAD_VALUE
1236 | 
1237 |             # Convert to list for easier slicing
1238 |             l = list(data)
1239 | 
1240 |             # Adjust negative indices
1241 |             if start_idx < 0:
1242 |                 start_idx = len(l) + start_idx
1243 |             if stop_idx < 0:
1244 |                 stop_idx = len(l) + stop_idx
1245 | 
1246 |             # Clamp indices
1247 |             start_idx = max(0, start_idx)
1248 |             stop_idx = min(len(l) - 1, stop_idx)
1249 | 
1250 |             # Return the range (inclusive on both ends like Redis)
1251 |             return l[start_idx:stop_idx + 1] if start_idx <= stop_idx else []
1252 | 
1253 |         return []  # Key does not exist
1254 | 
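    # Worked example of the LRANGE index math above: for a four-element list
    # ["a", "b", "c", "d"], LRANGE key 0 -2 gives start_idx = 0 and
    # stop_idx = 4 + (-2) = 2, so the inclusive slice l[0:3] returns
    # ["a", "b", "c"] -- the same answer Redis would give.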
1255 |     # --- Hash Commands ---
1256 | 
1257 |     async def handle_hset(self, key: str, field: str, value: str) -> int:
1258 |         """Set field in the hash stored at key to value"""
1259 |         if not key or not field or value is None:
1260 |             return RedisError("wrong number of arguments for 'hset' command")
1261 | 
1262 |         # Use the task's connection context if available, otherwise try all DBs
1263 |         if (task := asyncio.current_task()) and hasattr(task, 'connection'):
1264 |             db_nums = [getattr(task, 'connection').db]
1265 |         else:
1266 |             db_nums = list(self.tables.keys())
1267 | 
1268 |         for db_num in db_nums:
1269 |             # Check if key exists
1270 |             if await self.check_ttl(db_num, key) and key in self.tables[db_num]:
1271 |                 data = self.tables[db_num][key]
1272 |                 if not isinstance(data, dict):
1273 |                     return BAD_VALUE
1274 | 
1275 |                 is_new = field not in data
1276 |                 data[field] = value
1277 |                 return 1 if is_new else 0
1278 | 
1279 |             # Create a new hash
1280 |             self.tables[db_num][key] = {field: value}
1281 |             return 1
1282 | 
1283 |         # Should not reach here if at least one DB exists
1284 |         return 0
1285 | 
1286 |     async def handle_hget(self, key: str, field: str) -> str:
1287 |         """Get the value of a hash field"""
1288 |         if not key or not field:
1289 |             return RedisError("wrong number of arguments for 'hget' command")
1290 | 
1291 |         # Use the task's connection context if available, otherwise try all DBs
1292 |         if (task := asyncio.current_task()) and hasattr(task, 'connection'):
1293 |             db_nums = [getattr(task, 'connection').db]
1294 |         else:
1295 |             db_nums = list(self.tables.keys())
1296 | 
1297 |         for db_num in db_nums:
1298 |             # Check if key exists
1299 |             if not await self.check_ttl(db_num, key) or key not in self.tables[db_num]:
1300 |                 continue
1301 | 
1302 |             data = self.tables[db_num][key]
1303 |             if not isinstance(data, dict):
1304 |                 return BAD_VALUE
1305 | 
1306 |             if field not in data:
1307 |                 return None
1308 | 
1309 |             return data[field]
1310 | 
1311 |         return None  # Key does not exist
1312 | 
1313 |     async def handle_hgetall(self, key: str) -> Union[list, RedisError]:
1314 |         """Get all fields and values in a hash"""
1315 |         if not key:
1316 |             return RedisError("wrong number of arguments for 'hgetall' command")
1317 | 
1318 |         # Use the task's connection context if available, otherwise try all DBs
1319 |         if (task := asyncio.current_task()) and hasattr(task, 'connection'):
1320 |             db_nums = [getattr(task, 'connection').db]
1321 |         else:
1322 |             db_nums = list(self.tables.keys())
1323 | 
1324 |         for db_num in db_nums:
1325 |             # Check if key exists
1326 |             if not await self.check_ttl(db_num, key) or key not in self.tables[db_num]:
1327 |                 continue
1328 | 
1329 |             data = self.tables[db_num][key]
1330 |             if not isinstance(data, dict):
1331 |                 return BAD_VALUE
1332 | 
1333 |             # Convert dict to flat list of [key, value, key, value, ...]
1334 |             result = []
1335 |             for k, v in data.items():
1336 |                 result.append(k)
1337 |                 result.append(v)
1338 |             return result
1339 | 
1340 |         return []  # Key does not exist or is empty
1341 | 
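    # HGETALL flattens {"f1": "v1", "f2": "v2"} into ["f1", "v1", "f2", "v2"]
    # for the RESP array reply. A client-side sketch (hypothetical helper,
    # not part of this repo) to fold the flat reply back into a dict:
    #
    #     flat = c.hgetall("user:1")
    #     as_dict = dict(zip(flat[::2], flat[1::2]))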
1342 |     async def handle_hdel(self, key: str, *fields: str) -> int:
1343 |         """Delete one or more hash fields"""
1344 |         if not key or not fields:
1345 |             return RedisError("wrong number of arguments for 'hdel' command")
1346 | 
1347 |         # Use the task's connection context if available, otherwise try all DBs
1348 |         if (task := asyncio.current_task()) and hasattr(task, 'connection'):
1349 |             db_nums = [getattr(task, 'connection').db]
1350 |         else:
1351 |             db_nums = list(self.tables.keys())
1352 | 
1353 |         for db_num in db_nums:
1354 |             # Check if key exists
1355 |             if not await self.check_ttl(db_num, key) or key not in self.tables[db_num]:
1356 |                 continue
1357 | 
1358 |             data = self.tables[db_num][key]
1359 |             if not isinstance(data, dict):
1360 |                 return BAD_VALUE
1361 | 
1362 |             removed = 0
1363 |             for field in fields:
1364 |                 if field in data:
1365 |                     del data[field]
1366 |                     removed += 1
1367 | 
1368 |             return removed
1369 | 
1370 |         return 0  # Key does not exist
1371 | 
1372 |     async def handle_hexists(self, key: str, field: str) -> int:
1373 |         """Determine if a hash field exists"""
1374 |         if not key or not field:
1375 |             return RedisError("wrong number of arguments for 'hexists' command")
1376 | 
1377 |         # Use the task's connection context if available, otherwise try all DBs
1378 |         if (task := asyncio.current_task()) and hasattr(task, 'connection'):
1379 |             db_nums = [getattr(task, 'connection').db]
1380 |         else:
1381 |             db_nums = list(self.tables.keys())
1382 | 
1383 |         for db_num in db_nums:
1384 |             # Check if key exists
1385 |             if not await self.check_ttl(db_num, key) or key not in self.tables[db_num]:
1386 |                 continue
1387 | 
1388 |             data = self.tables[db_num][key]
1389 |             if not isinstance(data, dict):
1390 |                 return BAD_VALUE
1391 | 
1392 |             return 1 if field in data else 0
1393 | 
1394 |         return 0  # Key does not exist
1395 | 
1396 |     async def handle_hincrby(self, key: str, field: str, increment: str) -> int:
1397 |         """Increment the integer value of a hash field"""
1398 |         if not key or not field or not increment:
1399 |             return RedisError("wrong number of arguments for 'hincrby' command")
1400 | 
1401 |         try:
1402 |             incr = int(increment)
1403 |         except ValueError:
1404 |             return RedisError("value is not an integer or out of range")
1405 | 
1406 |         # Use the task's connection context if available, otherwise try all DBs
1407 |         if (task := asyncio.current_task()) and hasattr(task, 'connection'):
1408 |             db_nums = [getattr(task, 'connection').db]
1409 |         else:
1410 |             db_nums = list(self.tables.keys())
1411 | 
1412 |         for db_num in db_nums:
1413 |             # Check if key exists or create a new hash
1414 |             if not await self.check_ttl(db_num, key) or key not in self.tables[db_num]:
1415 |                 self.tables[db_num][key] = {}
1416 | 
1417 |             data = self.tables[db_num][key]
1418 |             if not isinstance(data, dict):
1419 |                 return BAD_VALUE
1420 | 
1421 |             try:
1422 |                 # Get current value or default to 0
1423 |                 current = data.get(field, "0")
1424 |                 current_int = int(current)
1425 |                 new_value = current_int + incr
1426 |                 # Store result as string to match Redis behavior
1427 |                 data[field] = str(new_value)
1428 |                 return new_value
1429 |             except ValueError:
1430 |                 return RedisError("hash value is not an integer")
1431 | 
1432 |         # Should not reach here if at least one DB exists
1433 |         return 0
1434 | 
1435 |     async def handle_hkeys(self, key: str) -> list:
1436 |         """Get all the fields in a hash"""
1437 |         if not key:
1438 |             return RedisError("wrong number of arguments for 'hkeys' command")
1439 | 
1440 |         # Use the task's connection context if available, otherwise try all DBs
1441 |         if (task := asyncio.current_task()) and hasattr(task, 'connection'):
1442 |             db_nums = [getattr(task, 'connection').db]
1443 |         else:
1444 |             db_nums = list(self.tables.keys())
1445 | 
1446 |         for db_num in db_nums:
1447 |             # Check if key exists
1448 |             if not await self.check_ttl(db_num, key) or key not in self.tables[db_num]:
1449 |                 continue
1450 | 
1451 |             data = self.tables[db_num][key]
1452 |             if not isinstance(data, dict):
1453 |                 return BAD_VALUE
1454 | 
1455 |             return list(data.keys())
1456 | 
1457 |         return []  # Key does not exist or is empty
1458 | 
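    # HINCRBY round-trip sketch (assumes a running server and the bundled
    # client; values come back as bytes because bulk strings are binary):
    #
    #     c.hset("stats", "hits", 10)      # stored as the string "10"
    #     c.hincrby("stats", "hits", 5)    # returns 15
    #     c.hget("stats", "hits")          # b"15" -- stored back as a string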
1459 |     async def handle_hvals(self, key: str) -> list:
1460 |         """Get all the values in a hash"""
1461 |         if not key:
1462 |             return RedisError("wrong number of arguments for 'hvals' command")
1463 | 
1464 |         # Use the task's connection context if available, otherwise try all DBs
1465 |         if (task := asyncio.current_task()) and hasattr(task, 'connection'):
1466 |             db_nums = [getattr(task, 'connection').db]
1467 |         else:
1468 |             db_nums = list(self.tables.keys())
1469 | 
1470 |         for db_num in db_nums:
1471 |             # Check if key exists
1472 |             if not await self.check_ttl(db_num, key) or key not in self.tables[db_num]:
1473 |                 continue
1474 | 
1475 |             data = self.tables[db_num][key]
1476 |             if not isinstance(data, dict):
1477 |                 return BAD_VALUE
1478 | 
1479 |             return list(data.values())
1480 | 
1481 |         return []  # Key does not exist or is empty
1482 | 
1483 |     async def handle_hlen(self, key: str) -> int:
1484 |         """Get the number of fields in a hash"""
1485 |         if not key:
1486 |             return RedisError("wrong number of arguments for 'hlen' command")
1487 | 
1488 |         # Use the task's connection context if available, otherwise try all DBs
1489 |         if (task := asyncio.current_task()) and hasattr(task, 'connection'):
1490 |             db_nums = [getattr(task, 'connection').db]
1491 |         else:
1492 |             db_nums = list(self.tables.keys())
1493 | 
1494 |         for db_num in db_nums:
1495 |             # Check if key exists
1496 |             if not await self.check_ttl(db_num, key) or key not in self.tables[db_num]:
1497 |                 continue
1498 | 
1499 |             data = self.tables[db_num][key]
1500 |             if not isinstance(data, dict):
1501 |                 return BAD_VALUE
1502 | 
1503 |             return len(data)
1504 | 
1505 |         return 0  # Key does not exist
1506 | 
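    # ZADD takes score/member pairs; any extra pairs arrive flattened in *args
    # and are consumed two at a time, so ZADD board 1 alice 2 bob parses as
    # score="1", member="alice", args=("2", "bob"). A sketch with the bundled
    # client (assumes the SortedSet backing from miniredis/sset.py):
    #
    #     c.zadd("board", 1, "alice", 2, "bob")
    #     c.zrange("board", 0, -1, "withscores")
    #     # -> [b"alice", b"1.0", b"bob", b"2.0"]  (scores formatted via str(float))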
1507 |     # --- Sorted Set Commands ---
1508 | 
1509 |     async def handle_zadd(self, key: str, score: str, member: str, *args: str) -> int:
1510 |         """Add one or more members to a sorted set, or update scores if they already exist"""
1511 |         if not key or score is None or member is None:
1512 |             return RedisError("wrong number of arguments for 'zadd' command")
1513 | 
1514 |         try:
1515 |             # Convert first score to float
1516 |             float(score)
1517 |         except ValueError:
1518 |             return RedisError("value is not a valid float")
1519 | 
1520 |         # Verify additional args come in score-member pairs
1521 |         if len(args) % 2 != 0:
1522 |             return RedisError("syntax error: wrong number of arguments")
1523 | 
1524 |         # Use the task's connection context if available, otherwise try all DBs
1525 |         if (task := asyncio.current_task()) and hasattr(task, 'connection'):
1526 |             db_nums = [getattr(task, 'connection').db]
1527 |         else:
1528 |             db_nums = list(self.tables.keys())
1529 | 
1530 |         for db_num in db_nums:
1531 |             # Check if key exists and is a sorted set
1532 |             if await self.check_ttl(db_num, key) and key in self.tables[db_num]:
1533 |                 data = self.tables[db_num][key]
1534 |                 if not isinstance(data, SortedSet):
1535 |                     return BAD_VALUE
1536 |             else:
1537 |                 # Create new sorted set
1538 |                 data = SortedSet()
1539 |                 self.tables[db_num][key] = data
1540 | 
1541 |             # Add first score-member pair
1542 |             added = 0
1543 |             if data.add(member, float(score)):
1544 |                 added += 1
1545 | 
1546 |             # Add remaining score-member pairs
1547 |             for i in range(0, len(args), 2):
1548 |                 try:
1549 |                     s = float(args[i])
1550 |                     m = args[i + 1]
1551 |                     if data.add(m, s):
1552 |                         added += 1
1553 |                 except ValueError:
1554 |                     # Skip invalid scores but continue processing
1555 |                     continue
1556 | 
1557 |             return added
1558 | 
1559 |         # Should not reach here if at least one DB exists
1560 |         return 0
1561 | 
1562 |     async def handle_zrange(self, key: str, start: str, stop: str, *args: str) -> list:
1563 |         """Return a range of members from a sorted set, by index"""
1564 |         if not key or start is None or stop is None:
1565 |             return RedisError("wrong number of arguments for 'zrange' command")
1566 | 
1567 |         try:
1568 |             start_idx = int(start)
1569 |             stop_idx = int(stop)
1570 |         except ValueError:
1571 |             return RedisError("value is not an integer")
1572 | 
1573 |         # Parse options
1574 |         withscores = False
1575 |         for arg in args:
1576 |             if arg.lower() == "withscores":
1577 |                 withscores = True
1578 | 
1579 |         # Use the task's connection context if available, otherwise try all DBs
1580 |         task = asyncio.current_task()
1581 |         if task and hasattr(task, 'connection'):
1582 |             connection = getattr(task, 'connection')
1583 |             db_nums = [connection.db]
1584 |         else:
1585 |             db_nums = list(self.tables.keys())
1586 | 
1587 |         for db_num in db_nums:
1588 |             # Check if key exists
1589 |             if not await self.check_ttl(db_num, key) or key not in self.tables[db_num]:
1590 |                 continue
1591 | 
1592 |             data = self.tables[db_num][key]
1593 |             if not isinstance(data, SortedSet):
1594 |                 return BAD_VALUE
1595 | 
1596 |             # Handle negative indices like Redis
1597 |             length = len(data)
1598 |             if start_idx < 0:
1599 |                 start_idx = length + start_idx
1600 |             if stop_idx < 0:
1601 |                 stop_idx = length + stop_idx
1602 | 
1603 |             # Clamp indices
1604 |             start_idx = max(0, start_idx)
1605 |             stop_idx = min(length - 1, stop_idx)
1606 | 
1607 |             # Get the range including scores if requested
1608 |             if start_idx <= stop_idx:
1609 |                 result = []
1610 |                 items = data.range_by_rank(start_idx, stop_idx + 1)
1611 |                 for member, score in items:
1612 |                     result.append(member)
1613 |                     if withscores:
1614 |                         result.append(str(score))
1615 |                 return result
1616 |             return []
1617 | 
1618 |         return []  # Key does not exist
1619 | 
1620 |     async def start(self) -> None:
1621 |         """Start the Redis server."""
1622 |         log.info(f"Starting AsyncRedisServer on {self.host}:{self.port}")
1623 |         self._server = await asyncio.start_server(
1624 |             self.handle_client,
1625 |             self.host,
1626 |             self.port
1627 |         )
1628 | 
1629 |         # Create a background task to check for expired keys
1630 |         self._expiry_task = asyncio.create_task(self._check_expirations())
1631 | 
1632 |         # Create a background task to periodically save data
1633 |         self._save_task = asyncio.create_task(self._auto_save())
1634 | 
1635 |         addr = self._server.sockets[0].getsockname() if self._server.sockets else (self.host, self.port)
1636 |         log.info(f"AsyncRedisServer running on {addr[0]}:{addr[1]}")
1637 | 
1638 |     async def stop(self) -> None:
1639 |         """Stop the Redis server and close all connections."""
1640 |         log.info("Stopping async Redis server...")
1641 | 
1642 |         # Cancel background tasks
1643 |         if hasattr(self, '_expiry_task') and self._expiry_task:
1644 |             self._expiry_task.cancel()
1645 | 
1646 |         if hasattr(self, '_save_task') and self._save_task:
1647 |             self._save_task.cancel()
1648 | 
1649 |         # Save data before shutdown
1650 |         await self.save_data()
1651 | 
1652 |         # Close the server
1653 |         if self._server:
1654 |             self._server.close()
1655 |             await self._server.wait_closed()
1656 | 
1657 |         # Cancel any remaining client tasks
1658 |         for task in self._tasks:
1659 |             if not task.done():
1660 |                 task.cancel()
1661 | 
1662 |         log.info("AsyncRedisServer stopped.")
1663 | 
1664 |     async def _check_expirations(self) -> None:
1665 |         """Background task to check for 
expired keys.""" 1666 | while True: 1667 | try: 1668 | # Check each timeout entry 1669 | for key in list(self.timeouts.keys()): 1670 | try: 1671 | db_key = key.split(' ', 1) 1672 | if len(db_key) != 2: 1673 | continue 1674 | 1675 | db_num = int(db_key[0]) 1676 | key_name = db_key[1] 1677 | 1678 | # If expired, remove the key 1679 | if self.timeouts[key] <= time.time(): 1680 | if db_num in self.tables and key_name in self.tables[db_num]: 1681 | del self.tables[db_num][key_name] 1682 | del self.timeouts[key] 1683 | except (ValueError, KeyError): 1684 | # Skip invalid entries 1685 | continue 1686 | 1687 | # Sleep to avoid consuming too many resources 1688 | await asyncio.sleep(0.1) 1689 | except asyncio.CancelledError: 1690 | # Clean exit on cancellation 1691 | break 1692 | except Exception as e: 1693 | log.exception(f"Error in expiration check: {e}") 1694 | # Continue running despite errors 1695 | await asyncio.sleep(1) 1696 | 1697 | async def _auto_save(self) -> None: 1698 | """Background task to periodically save data.""" 1699 | save_interval = 300 # Save every 5 minutes 1700 | last_save = time.time() 1701 | 1702 | while True: 1703 | try: 1704 | current_time = time.time() 1705 | if current_time - last_save >= save_interval: 1706 | await self.save_data() 1707 | last_save = current_time 1708 | 1709 | # Sleep for a bit to avoid frequent checks 1710 | await asyncio.sleep(10) 1711 | except asyncio.CancelledError: 1712 | # Clean exit on cancellation 1713 | break 1714 | except Exception as e: 1715 | log.exception(f"Error in auto-save: {e}") 1716 | # Continue running despite errors 1717 | await asyncio.sleep(60) 1718 | -------------------------------------------------------------------------------- /miniredis/client.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | """ 4 | Based on a minimalist Redis client originally written by Andrew Rowls 5 | 6 | Created by Rui Carmo on 2013-03-12 7 | Published under the MIT license. 8 | """ 9 | 10 | import logging 11 | import random 12 | import socket 13 | import time 14 | from multiprocessing import Pool 15 | from typing import Any, Callable, List, Optional, Tuple, Union, TypeVar, cast, Dict 16 | from typing import Protocol, Generic, ParamSpec 17 | 18 | log = logging.getLogger() 19 | 20 | # More advanced type variables for Python 3.10+ 21 | T = TypeVar("T") 22 | P = ParamSpec("P") # For handling arbitrary parameters in a generic way 23 | 24 | 25 | # Define a Protocol for Redis command results 26 | class RedisCommandResult(Protocol, Generic[T]): 27 | def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ... 28 | 29 | 30 | class RedisClient: 31 | def __init__(self, host: str = "localhost", port: int = 6379) -> None: 32 | """Initialize a new Redis client connection. 
33 | 34 | Args: 35 | host: Redis server hostname or IP 36 | port: Redis server port 37 | """ 38 | self.sock = socket.create_connection((host, port)) 39 | # Use binary mode for reading/writing 40 | self.file = self.sock.makefile("rwb", buffering=0) 41 | 42 | def __getattr__(self, attr: str) -> RedisCommandResult[Any]: 43 | # Map 'delete' attribute to 'DEL' command 44 | command = b"DEL" if attr == "delete" else attr.upper().encode("utf-8") 45 | 46 | def handle(*args: Any) -> Any: 47 | # Encode all arguments to bytes 48 | encoded_args: List[bytes] = [] 49 | for a in args: 50 | if isinstance(a, bytes): 51 | encoded_args.append(a) 52 | elif isinstance(a, str): 53 | encoded_args.append(a.encode("utf-8")) 54 | else: 55 | encoded_args.append(str(a).encode("utf-8")) 56 | 57 | # Build the command array 58 | cmd_parts = [command] + encoded_args 59 | cmd_str = f"*{len(cmd_parts)}\r\n".encode("utf-8") 60 | for part in cmd_parts: 61 | cmd_str += f"${len(part)}\r\n".encode("utf-8") + part + b"\r\n" 62 | 63 | # Send the command 64 | self.file.write(cmd_str) 65 | self.file.flush() # Ensure command is sent immediately 66 | return self.parse_response() 67 | 68 | return handle 69 | 70 | def parse_response(self) -> Any: 71 | """Parse a Redis protocol response from the server. 72 | 73 | Returns: 74 | The parsed response in appropriate Python type 75 | 76 | Raises: 77 | ConnectionError: If the connection is closed 78 | Exception: For Redis errors or protocol violations 79 | """ 80 | rsp = self.file.readline() 81 | if not rsp: 82 | # Connection closed or no response 83 | raise ConnectionError("Socket closed or no response received") 84 | 85 | type_byte, body = rsp[0:1], rsp[1:-2] # Keep as bytes 86 | 87 | match type_byte: # Using Python 3.10+ pattern matching 88 | case b"+": # Simple String 89 | return body.decode("utf-8") 90 | case b"-": # Error 91 | raise Exception(body.decode("utf-8")) 92 | case b":": # Integer 93 | return int(body) 94 | case b"$": # Bulk String 95 | length = int(body) 96 | return self.read_bulk(length) 97 | case b"*": # Array 98 | count = int(body) 99 | if count == -1: 100 | return None # Null array 101 | return [self.parse_response() for _ in range(count)] 102 | case _: 103 | # Should not happen with a conforming server 104 | raise ValueError( 105 | f'Unknown Return Value Type: "{type_byte.decode("utf-8", errors="backslashreplace")}"' 106 | ) 107 | 108 | def read_bulk(self, n: int) -> Optional[bytes]: 109 | """Read a bulk string of specific length. 
110 | 111 | Args: 112 | n: The length of the bulk string 113 | 114 | Returns: 115 | The string data or None for null bulk string 116 | 117 | Raises: 118 | ConnectionError: If insufficient data is read 119 | Exception: If protocol is violated 120 | """ 121 | if n == -1: 122 | return None # Null bulk string 123 | # Read exactly n bytes + 2 for CRLF 124 | data = self.file.read(n + 2) 125 | if len(data) < n + 2: 126 | raise ConnectionError("Incomplete bulk string read") 127 | if data[-2:] != b"\r\n": 128 | raise ValueError("Bulk string missing CRLF") 129 | return data[:-2] # Return the data part as bytes 130 | 131 | def close(self) -> None: 132 | """Close the connection.""" 133 | try: 134 | self.file.close() 135 | self.sock.close() 136 | except (socket.error, OSError) as e: 137 | log.debug( 138 | f"Error closing connection: {e}" 139 | ) # Log the error but don't propagate 140 | 141 | 142 | if __name__ == "__main__": 143 | # Example usage and benchmark setup 144 | logging.basicConfig(level=logging.INFO) 145 | 146 | def timed(count: int) -> float: 147 | """Run a timed benchmark with GET/SET operations. 148 | 149 | Args: 150 | count: Number of GET/SET operation pairs to perform 151 | 152 | Returns: 153 | The elapsed time in seconds 154 | """ 155 | c = None 156 | try: 157 | c = RedisClient() 158 | c.select(0) # Select DB 0 159 | seq = list(range(0, 10000)) 160 | # Pre-populate some keys 161 | for i in range(min(1000, len(seq))): 162 | k = str(random.choice(seq)) 163 | c.set(k, "bar") 164 | 165 | now = time.time() 166 | for _ in range(count): 167 | k_get = str(random.choice(seq)) 168 | try: 169 | c.get(k_get) 170 | except Exception: 171 | # Handle potential errors during GET (e.g., key not found is None, not Exception) 172 | pass 173 | k_set = str(random.choice(seq)) 174 | c.set(k_set, "bar") # Set operation 175 | 176 | elapsed = time.time() - now 177 | return elapsed 178 | except Exception as e: 179 | log.error(f"Error in timed function: {e}") 180 | return float("inf") # Indicate failure 181 | finally: 182 | if c is not None: # Use is not None for more explicit comparison 183 | c.close() 184 | 185 | # Use more modern multiprocessing approach with context manager 186 | def run_benchmark(num_workers: int = 4, ops_per_worker: int = 10000) -> None: 187 | """Run a parallel benchmark with multiple workers. 188 | 189 | Args: 190 | num_workers: Number of parallel workers 191 | ops_per_worker: Operations per worker 192 | """ 193 | total_ops = num_workers * ops_per_worker * 2 # *2 because we do GET and SET 194 | 195 | print( 196 | f"Running benchmark with {num_workers} workers, {ops_per_worker} GET/SET pairs each..." 
197 | ) 198 | 199 | with Pool(num_workers) as p: 200 | results = p.map(timed, [ops_per_worker] * num_workers) 201 | 202 | total_time = sum(r for r in results if r != float("inf")) 203 | successful_workers = sum(1 for r in results if r != float("inf")) 204 | 205 | if successful_workers > 0 and total_time > 0: 206 | # Calculate average time per worker, then overall ops/sec 207 | avg_time_per_worker = total_time / successful_workers 208 | # Estimate total ops based on successful workers 209 | estimated_total_ops = successful_workers * ops_per_worker * 2 210 | ops_sec = estimated_total_ops / total_time 211 | print(f"Benchmark finished.") 212 | print(f"Total operations (estimated): {estimated_total_ops}") 213 | print(f"Total time: {total_time:.4f} seconds") 214 | print(f"Operations per second: {ops_sec:.2f}") 215 | elif successful_workers == 0: 216 | print("Benchmark failed: All workers encountered errors.") 217 | else: # total_time is 0, should not happen unless ops_per_worker is 0 218 | print("Benchmark finished, but no time elapsed or no operations performed.") 219 | 220 | # Run the benchmark with default parameters 221 | run_benchmark() 222 | -------------------------------------------------------------------------------- /miniredis/haystack.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | """ 4 | Haystack.py 5 | 6 | An on-disk cache with a dict-like API, inspired by Facebook's Haystack store 7 | 8 | Created by Rui Carmo on 2010-04-05 9 | Published under the MIT license. 10 | """ 11 | 12 | __author__ = "Rui Carmo http://the.taoofmac.com" 13 | __revision__ = "$Id$" 14 | __version__ = "1.0" 15 | 16 | import os, sys, stat, mmap, threading, time, logging 17 | import pickle # Use standard pickle module in Python 3 18 | from typing import Any, Dict, List, Optional, Union, TypeAlias 19 | 20 | log = logging.getLogger() 21 | 22 | # Define type aliases for better type hinting 23 | IndexEntry: TypeAlias = List[float] # [mtime, length, offset] 24 | Index: TypeAlias = Dict[str, IndexEntry] 25 | 26 | 27 | class Haystack(dict): 28 | 29 | def __init__( 30 | self, 31 | path: str, 32 | basename: str = "haystack", 33 | commit: int = 300, 34 | compact: int = 3600, 35 | ) -> None: 36 | super().__init__() # Modern super() call syntax 37 | self.enabled = True 38 | self.mutex = threading.Lock() 39 | self.commit_interval = commit 40 | self.compact_interval = compact 41 | self.path = path 42 | self.basename = basename 43 | self.cache = os.path.join(self.path, self.basename + ".bin") 44 | self.index = os.path.join(self.path, self.basename + ".idx") 45 | self.temp = os.path.join(self.path, self.basename + ".tmp") 46 | self._rebuild() 47 | self.created = self.modified = self.compacted = self.committed = time.time() 48 | 49 | def _rebuild(self) -> None: 50 | """Rebuild the index from disk.""" 51 | self.mutex.acquire() 52 | try: 53 | os.makedirs(self.path, exist_ok=True) # Use exist_ok for Python 3 54 | except Exception as e: 55 | log.error(f"Error on makedirs({self.path}): {e}") 56 | 57 | # Use context manager for file operations 58 | try: 59 | with open(self.cache, "rb") as cache: 60 | pass # Just check if file exists and can be opened 61 | except (FileNotFoundError, PermissionError) as e: 62 | log.error(f"Error while opening {self.cache} for reading: {e}") 63 | try: 64 | with open(self.cache, "ab") as cache: 65 | pass # Create the file 66 | except OSError as e: 67 | log.error(f"Could not create cache file {self.cache}: {e}") 68 | 
self.enabled = False
69 | 
70 |         try:
71 |             with open(self.index, "rb") as f:
72 |                 self._index: Index = pickle.loads(f.read())
73 |         except Exception as e:
74 |             log.error(f"Index retrieval from disk failed: {e}")
75 |             self._index = {}  # "key": [mtime,length,offset]
76 | 
77 |         self.created = self.modified = self.compacted = self.committed = time.time()
78 |         log.debug(
79 |             f"Rebuild complete, {len(self._index)} items."
80 |         )  # Use len() directly on dict
81 |         self.mutex.release()
82 | 
83 |     def commit(self) -> None:
84 |         """Commit the index to disk."""
85 |         if not self.enabled:
86 |             return
87 |         self.mutex.acquire()
88 |         try:
89 |             with open(self.index, "wb") as f:
90 |                 f.write(pickle.dumps(self._index))
91 |             self.committed = time.time()
92 |             log.debug(f"Index {self.index} committed, {len(self._index)} items.")
93 |         except OSError as e:
94 |             log.error(f"Failed to commit index to {self.index}: {e}")
95 |         finally:
96 |             self.mutex.release()
97 | 
98 |     def purge(self) -> None:
99 |         """Delete all cache files and rebuild the cache."""
100 |         self.mutex.acquire()
101 |         try:
102 |             try:
103 |                 os.unlink(self.index)
104 |             except OSError as e:
105 |                 log.error(f"Could not unlink {self.index}: {e}")
106 |             try:
107 |                 os.unlink(self.cache)
108 |             except OSError as e:
109 |                 log.error(f"Could not unlink {self.cache}: {e}")
110 |         finally:
111 |             self.mutex.release()
112 |         self._rebuild()
113 | 
114 |     def _cleanup(self) -> None:
115 |         """Check if commit or compaction is needed."""
116 |         now = time.time()
117 |         if now > (self.committed + self.commit_interval):
118 |             self.commit()
119 |         if now > (self.compacted + self.compact_interval):
120 |             self._compact()
121 | 
122 |     def __eq__(self, other: object) -> bool:
123 |         raise TypeError("Equality undefined for this kind of dictionary")
124 | 
125 |     def __ne__(self, other: object) -> bool:
126 |         raise TypeError("Equality undefined for this kind of dictionary")
127 | 
128 |     def __lt__(self, other: object) -> bool:
129 |         raise TypeError("Comparison undefined for this kind of dictionary")
130 | 
131 |     def __le__(self, other: object) -> bool:
132 |         raise TypeError("Comparison undefined for this kind of dictionary")
133 | 
134 |     def __gt__(self, other: object) -> bool:
135 |         raise TypeError("Comparison undefined for this kind of dictionary")
136 | 
137 |     def __ge__(self, other: object) -> bool:
138 |         raise TypeError("Comparison undefined for this kind of dictionary")
139 | 
140 |     def __repr__(self) -> str:
141 |         return (
142 |             f"<Haystack: {self.path}/{self.basename}, {len(self._index)} items>"
143 |         )
144 | 
145 |     def expire(self, when: float) -> None:
146 |         """Remove from cache any items older than a specified time"""
147 |         if not self.enabled:
148 |             return
149 |         self.mutex.acquire()
150 |         try:
151 |             # Use list to avoid modification during iteration
152 |             for k in list(self._index.keys()):
153 |                 if self._index[k][0] < when:
154 |                     del self._index[k]
155 |         finally:
156 |             self.mutex.release()
157 |         self._cleanup()
158 | 
159 |     def keys(self) -> list[str]:
160 |         # In Python 3 keys() returns a view, convert to list if needed
161 |         return list(self._index.keys())
162 | 
163 |     def stats(self, key: str) -> IndexEntry:
164 |         """Get index statistics for a key."""
165 |         if not self.enabled:
166 |             raise KeyError(key)
167 |         self.mutex.acquire()
168 |         try:
169 |             stats = self._index[key]
170 |             return stats
171 |         except KeyError:
172 |             raise KeyError(key)
173 |         finally:
174 |             self.mutex.release()
175 | 
176 |     def __setitem__(self, key: str, val: Any) -> None:
177 |         """Store an item in the cache - errors will cause the entire cache to be rebuilt"""
178 |         if not 
self.enabled: 179 | return 180 | self.mutex.acquire() 181 | try: 182 | with open(self.cache, "ab") as cache: 183 | buffer = pickle.dumps(val, protocol=pickle.HIGHEST_PROTOCOL) 184 | offset = cache.tell() 185 | cache.write(buffer) 186 | self.modified = mtime = time.time() 187 | self._index[key] = [mtime, len(buffer), offset] 188 | except Exception as e: 189 | log.error(f"Error while storing {key}: {e}") 190 | raise IOError(f"Error storing item: {e}") 191 | finally: 192 | self.mutex.release() 193 | self._cleanup() # Check if we need to commit/compact 194 | 195 | def __delitem__(self, key: str) -> None: 196 | """Remove item from cache - in practice, we only remove it from the index""" 197 | if not self.enabled: 198 | return 199 | self.mutex.acquire() 200 | try: 201 | del self._index[key] 202 | except KeyError: 203 | raise KeyError(key) 204 | except Exception as e: 205 | log.error(f"Unexpected error while deleting {key}: {e}") 206 | raise 207 | finally: 208 | self.mutex.release() 209 | self._cleanup() # Check if we need to commit/compact 210 | 211 | def get(self, key: str, default: Any = None) -> Any: 212 | """Get an item with a default value if not found.""" 213 | try: 214 | return self.__getitem__(key) 215 | except KeyError: 216 | return default 217 | 218 | def __getitem__(self, key: str) -> Any: 219 | """Retrieve item""" 220 | if not self.enabled: 221 | raise KeyError(key) 222 | self.mutex.acquire() 223 | try: 224 | # Make sure the key exists before trying to read it 225 | if key not in self._index: 226 | raise KeyError(key) 227 | 228 | with open(self.cache, "rb") as cache: 229 | cache.seek(self._index[key][2]) 230 | buffer = cache.read(self._index[key][1]) 231 | item = pickle.loads(buffer) 232 | return item 233 | except (FileNotFoundError, PermissionError) as e: 234 | log.error(f"File operation error while retrieving {key}: {e}") 235 | raise KeyError(key) 236 | except (pickle.PickleError, EOFError) as e: 237 | log.error(f"Pickle error while retrieving {key}: {e}") 238 | raise KeyError(key) 239 | except Exception as e: 240 | log.error(f"Unexpected error while retrieving {key}: {e}") 241 | raise KeyError(key) 242 | finally: 243 | self.mutex.release() 244 | 245 | def mtime(self, key: str) -> float: 246 | """Return the creation/modification time of a cache item""" 247 | if not self.enabled: 248 | raise KeyError(key) 249 | self.mutex.acquire() 250 | try: 251 | item = self._index[key][0] 252 | return item 253 | except KeyError: 254 | raise KeyError(key) 255 | except Exception as e: 256 | log.debug(f"Error while getting modification time for {key}: {e}") 257 | raise KeyError(key) 258 | finally: 259 | self.mutex.release() 260 | 261 | def _compact(self) -> None: 262 | """Compact the cache by rewriting only valid items""" 263 | self.mutex.acquire() 264 | try: 265 | # Use atomic operations where possible 266 | with open(self.cache, "rb") as cache, open(self.temp, "wb") as compacted: 267 | new_index: Index = {} 268 | i = 0 269 | for key in self._index: 270 | try: 271 | cache.seek(self._index[key][2]) 272 | offset = compacted.tell() 273 | data = cache.read(self._index[key][1]) 274 | compacted.write(data) 275 | new_index[key] = [time.time(), self._index[key][1], offset] 276 | i += 1 277 | except Exception as e: 278 | log.error(f"Error while compacting item {key}: {e}") 279 | # Skip this item 280 | continue 281 | 282 | size = compacted.tell() 283 | compacted.flush() 284 | os.fsync(compacted.fileno()) 285 | 286 | os.replace(self.temp, self.cache) # Atomic replacement on most systems 287 | self.compacted = 
time.time() 288 | self._index = new_index 289 | log.debug(f"Compacted {self.cache}: {i} items into {size} bytes") 290 | except OSError as e: 291 | log.error(f"Failed to compact cache: {e}") 292 | finally: 293 | self.mutex.release() 294 | self.commit() 295 | 296 | 297 | if __name__ == "__main__": 298 | # Set up logging 299 | logging.basicConfig( 300 | level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s" 301 | ) 302 | 303 | print("Running Haystack self-test...") 304 | c = Haystack(".", commit=3, compact=4) 305 | 306 | # Test basic operations 307 | c["tired"] = "to expire in 2 seconds" 308 | c["foo"] = {"a": 1, "b": 2} 309 | c["zbr"] = "42" 310 | c["test/path/name"] = "test" 311 | 312 | print("Values stored. Testing retrieval...") 313 | assert c["foo"] == {"a": 1, "b": 2}, "Retrieval test failed" 314 | assert c["zbr"] == "42", "String retrieval test failed" 315 | 316 | print("Testing expiration...") 317 | time.sleep(2) 318 | c.expire(time.time() - 2) 319 | 320 | try: 321 | value = c["tired"] 322 | print(f"ERROR: Retrieved expired item: {value}") 323 | except KeyError: 324 | print("Expired item correctly removed") 325 | 326 | print("Testing deletion...") 327 | del c["foo"] 328 | try: 329 | c["foo"] 330 | print("ERROR: Retrieved deleted item") 331 | except KeyError: 332 | print("Deleted item correctly removed") 333 | 334 | print("Waiting for automatic commit and compact...") 335 | time.sleep(5) 336 | 337 | print("All tests completed successfully!") 338 | -------------------------------------------------------------------------------- /miniredis/server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | """ 4 | Based on a minimalist Redis server originally written by Benjamin Pollack 5 | 6 | First modified by Rui Carmo on 2013-03-12 7 | Published under the MIT license. 
8 | """ 9 | 10 | from collections.abc import Mapping 11 | from collections import deque 12 | import getopt # Add missing import 13 | import logging 14 | import os 15 | import re 16 | import select 17 | import signal 18 | import socket 19 | import sys 20 | import threading 21 | import time 22 | from pathlib import Path 23 | from random import choice, sample 24 | from typing import Any, Dict, List, Optional, Set, Tuple, Union, TypeAlias 25 | 26 | log = logging.getLogger() 27 | 28 | from .haystack import Haystack 29 | 30 | 31 | class RedisConstant: 32 | def __init__(self, type: str) -> None: 33 | self.type = type 34 | 35 | def __len__(self) -> int: 36 | return 0 37 | 38 | def __repr__(self) -> str: 39 | return f"" 40 | 41 | 42 | class RedisMessage: 43 | def __init__(self, message: str) -> None: 44 | self.message = message 45 | 46 | def __str__(self) -> str: 47 | return f"+{self.message}" 48 | 49 | def __repr__(self) -> str: 50 | return f"" 51 | 52 | 53 | class RedisError(RedisMessage): 54 | def __init__(self, message: str) -> None: 55 | self.message = message 56 | 57 | def __str__(self) -> str: 58 | return f"-ERR {self.message}" 59 | 60 | def __repr__(self) -> str: 61 | return f"" 62 | 63 | 64 | EMPTY_SCALAR = RedisConstant("EmptyScalar") 65 | EMPTY_LIST = RedisConstant("EmptyList") 66 | BAD_VALUE = RedisError("Operation against a key holding the wrong kind of value") 67 | 68 | 69 | class RedisConnection: 70 | """Class to represent a client connection""" 71 | 72 | def __init__(self, socket: socket.socket) -> None: 73 | self.socket = socket 74 | self.wfile = socket.makefile("wb") 75 | self.rfile = socket.makefile("rb") 76 | self.db: int = 0 77 | self.table: Dict[str, Any] = {} 78 | 79 | 80 | class RedisServer: 81 | def __init__( 82 | self, host: str = "127.0.0.1", port: int = 6379, db_path: str = "." 
83 | ) -> None: 84 | super().__init__() 85 | self.host = host 86 | self.port = port 87 | self.halt = True 88 | self.clients: Dict[socket.socket, RedisConnection] = {} 89 | self.tables: Dict[int, Dict[str, Any]] = {} 90 | self.channels: Dict[str, List[RedisConnection]] = {} 91 | self.lastsave = int(time.time()) 92 | self.path = db_path 93 | self.meta = Haystack(self.path, "redisdb") 94 | self.timeouts: Dict[str, float] = self.meta.get("timeouts", {}) 95 | 96 | def dump(self, client: RedisConnection, o: Any) -> None: 97 | """Output a result to a client""" 98 | nl = b"\r\n" 99 | if isinstance(o, bool): 100 | if o: 101 | client.wfile.write(b"+OK\r\n") 102 | elif o == EMPTY_SCALAR: 103 | client.wfile.write(b"$-1\r\n") 104 | elif o == EMPTY_LIST: 105 | client.wfile.write(b"*-1\r\n") 106 | elif isinstance(o, int): 107 | client.wfile.write(b":" + str(o).encode() + nl) 108 | elif isinstance(o, str): 109 | o_bytes = o.encode() 110 | client.wfile.write(b"$" + str(len(o_bytes)).encode() + nl) 111 | client.wfile.write(o_bytes + nl) 112 | elif isinstance(o, bytes): 113 | client.wfile.write(b"$" + str(len(o)).encode() + nl) 114 | client.wfile.write(o + nl) 115 | elif isinstance(o, list): 116 | client.wfile.write(b"*" + str(len(o)).encode() + nl) 117 | for val in o: 118 | if isinstance(val, (str, bytes, int, float)): 119 | self.dump(client, val) 120 | elif val is None: 121 | self.dump(client, EMPTY_SCALAR) 122 | else: 123 | self.dump(client, str(val)) 124 | elif isinstance(o, RedisMessage): 125 | client.wfile.write(str(o).encode() + b"\r\n") 126 | elif isinstance(o, dict): 127 | client.wfile.write(b"*" + str(len(o) * 2).encode() + nl) 128 | for k, v in o.items(): 129 | self.dump(client, str(k)) 130 | self.dump(client, str(v) if v is not None else EMPTY_SCALAR) 131 | else: 132 | client.wfile.write(b"return type not yet implemented\r\n") 133 | client.wfile.flush() 134 | 135 | def log(self, client: Optional[RedisConnection], s: str) -> None: 136 | """Server logging""" 137 | try: 138 | who = ( 139 | f"{client.socket.getpeername()[0]}:{client.socket.getpeername()[1]}" 140 | if client 141 | else "SERVER" 142 | ) 143 | except: 144 | who = "" 145 | log.debug(f"{who}: {s}") 146 | 147 | def handle(self, client: RedisConnection) -> None: 148 | """Handle commands""" 149 | 150 | keys_to_check = ( 151 | sample(list(self.timeouts.keys()), len(self.timeouts) // 4) 152 | if self.timeouts 153 | else [] 154 | ) 155 | for e in keys_to_check: 156 | self.check_ttl(client, e.split(" ", 1)[1]) 157 | 158 | line = client.rfile.readline() 159 | if not line: 160 | self.log(client, "client disconnected") 161 | del self.clients[client.socket] 162 | client.socket.close() 163 | return 164 | items = int(line[1:].strip()) 165 | args = [] 166 | for _ in range(items): 167 | length_line = client.rfile.readline().strip() 168 | if not length_line or not length_line.startswith(b"$"): 169 | raise RedisError("Protocol error: expected bulk string length") 170 | length = int(length_line[1:]) 171 | if length == -1: 172 | args.append(None) 173 | else: 174 | data = client.rfile.read(length) 175 | if len(data) < length: 176 | raise RedisError("Protocol error: insufficient data read") 177 | args.append(data) 178 | crlf = client.rfile.read(2) 179 | if crlf != b"\r\n": 180 | raise RedisError("Protocol error: expected CRLF") 181 | 182 | try: 183 | command = args[0].decode("utf-8").lower() 184 | decoded_args = [] 185 | for arg in args[1:]: 186 | if arg is not None: 187 | decoded_args.append(arg.decode("utf-8")) 188 | else: 189 | decoded_args.append(None) 
190 | except UnicodeDecodeError: 191 | raise RedisError("Command or arguments not valid UTF-8") 192 | 193 | handler_name = "handle_" + command 194 | if hasattr(self, handler_name): 195 | self.dump(client, getattr(self, handler_name)(client, *decoded_args)) 196 | else: 197 | self.dump(client, RedisError(f"unknown command '{command}'")) 198 | 199 | def rotate(self) -> None: 200 | """Rotate log file using context manager for better resource handling""" 201 | try: 202 | self.log_file.close() 203 | with open(self.log_name, "w") as new_log_file: 204 | self.log_file = new_log_file 205 | except (FileNotFoundError, PermissionError) as e: 206 | log.error(f"Error rotating log file: {e}") 207 | 208 | def run(self) -> None: 209 | """Main loop for standard socket handling with improved exception handling""" 210 | self.halt = False 211 | server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 212 | server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 213 | try: 214 | server.bind((self.host, self.port)) 215 | server.listen(5) 216 | while not self.halt: 217 | try: 218 | readable, _, _ = select.select( 219 | [server] + list(self.clients.keys()), [], [], 1.0 220 | ) 221 | except select.error as e: 222 | if hasattr(socket, "errno") and e.args[0] == socket.errno.EINTR: 223 | continue 224 | raise 225 | except (ValueError, TypeError) as e: 226 | log.error(f"Error in select: {e}") 227 | continue 228 | 229 | for sock in readable: 230 | if sock == server: 231 | try: 232 | (client_socket, address) = server.accept() 233 | client = RedisConnection(client_socket) 234 | self.clients[client_socket] = client 235 | self.log(client, "client connected") 236 | self.select(client, 0) 237 | except OSError as e: 238 | log.error(f"Error accepting connection: {e}") 239 | else: 240 | client = self.clients.get(sock) 241 | if client: 242 | try: 243 | self.handle(client) 244 | except ( 245 | socket.error, 246 | ConnectionError, 247 | EOFError, 248 | ConnectionResetError, 249 | ) as e: 250 | self.log(client, f"client connection error: {e}") 251 | self.handle_quit(client) 252 | except Exception as e: 253 | self.log(client, f"exception: {e}") 254 | self.handle_quit(client) 255 | finally: 256 | # Ensure proper cleanup on exit 257 | for client_socket in list(self.clients.keys()): 258 | try: 259 | client_socket.close() 260 | except socket.error: 261 | pass 262 | self.clients.clear() 263 | server.close() 264 | 265 | def save(self) -> None: 266 | """Serialize tables to disk""" 267 | self.meta["timeouts"] = self.timeouts 268 | for db in self.tables: 269 | self.meta[db] = self.tables[db] 270 | self.meta.commit() 271 | self.lastsave = int(time.time()) 272 | 273 | def select(self, client: RedisConnection, db: int) -> None: 274 | if db not in self.tables: 275 | self.tables[db] = self.meta.get(db, {}) 276 | client.db = db 277 | client.table = self.tables[db] 278 | 279 | def stop(self) -> None: 280 | if not self.halt: 281 | self.log(None, "STOPPING") 282 | self.save() 283 | self.halt = True 284 | 285 | def check_ttl(self, client: RedisConnection, key: str) -> None: 286 | k = f"{client.db} {key}" 287 | if k in self.timeouts: 288 | if self.timeouts[k] <= time.time(): 289 | self.handle_del(client, key) 290 | 291 | # command handlers, sorted by order of redis.io docs 292 | 293 | # Keys 294 | 295 | def handle_del(self, client, *args): 296 | count = 0 297 | for key in args: 298 | self.handle_persist(client, key) 299 | self.log(client, f"DEL {key}") 300 | if key not in client.table: 301 | continue 302 | del client.table[key] 303 | count += 1 304 
|         return count
305 | 
306 |     def handle_dump(self, client, key):
307 |         self.log(client, f"DUMP {key}")
308 |         # no special internal representation
309 |         return str(client.table[key]) if key in client.table else EMPTY_SCALAR
310 | 
311 |     def handle_exists(self, client, key):
312 |         self.check_ttl(client, key)
313 |         if key in client.table:
314 |             return 1
315 |         return 0
316 | 
317 |     def handle_expire(self, client, key, ttl):
318 |         ttl = int(ttl)
319 |         self.log(client, f"EXPIRE {key} {ttl}")
320 |         if key not in client.table:
321 |             return 0
322 |         self.timeouts[f"{client.db} {key}"] = time.time() + ttl
323 |         return 1
324 | 
325 |     def handle_expireat(self, client, key, when):
326 |         when = int(when)
327 |         self.log(client, f"EXPIREAT {key} {when}")
328 |         if key not in client.table:
329 |             return 0
330 |         # Store as a Unix timestamp
331 |         self.timeouts[f"{client.db} {key}"] = float(when)
332 |         return 1
333 | 
334 |     def handle_keys(self, client, pattern):
335 |         """Return all keys matching a pattern"""
336 |         # Escape regex metacharacters, then translate the Redis glob wildcards
337 |         regex_pattern = "^" + re.escape(pattern)
338 |         # Replace the (escaped) Redis wildcards with regex equivalents
339 |         regex_pattern = regex_pattern.replace(r"\*", ".*")
340 |         regex_pattern = regex_pattern.replace(r"\?", ".")
341 |         regex_pattern += "$"  # Anchor the end
342 | 
343 |         r = re.compile(regex_pattern)
344 |         self.log(client, f"KEYS {pattern}")
345 |         return [k for k in client.table.keys() if r.match(k)]
346 | 
347 |     # def handle_migrate(self, client, host, port, key, db, timeout, option):
348 | 
349 |     def handle_move(self, client, key, db):
350 |         db = int(db)  # the DB index arrives as a string over the wire
351 |         self.log(client, f"MOVE {key} {db}")
352 |         if key not in client.table:
353 |             return 0
354 |         self.handle_persist(client, key)
355 |         if db not in self.tables:
356 |             self.tables[db] = self.meta.get(db, {})
357 |         if key in self.tables[db]:
358 |             return 0
359 |         self.tables[db][key] = client.table[key]
360 |         del client.table[key]
361 |         return 1
362 |     # def handle_object(self, client, subcommand, *args)
363 | 
364 |     def handle_persist(self, client, key):
365 |         try:
366 |             del self.timeouts[f"{client.db} {key}"]
367 |         except KeyError:
368 |             pass
369 | 
370 |     def handle_pexpire(self, client, key, mttl):
371 |         mttl = int(mttl)
372 |         if key not in client.table:
373 |             return 0
374 |         self.timeouts[f"{client.db} {key}"] = time.time() + (mttl / 1000)
375 |         return 1
376 | 
377 |     def handle_pexpireat(self, client, key, mwhen):
378 |         mwhen = int(mwhen)
379 |         if key not in client.table:
380 |             return 0
381 |         self.timeouts[f"{client.db} {key}"] = mwhen / 1000
382 |         return 1
383 | 
384 |     def handle_pttl(self, client, key):
385 |         """Get the time to live for a key in milliseconds"""
386 |         self.log(client, f"PTTL {key}")
387 |         # First check if key exists at all
388 |         if key not in client.table:
389 |             return -2
390 | 
391 |         # Check if the key has a timeout
392 |         k = f"{client.db} {key}"
393 |         if k not in self.timeouts:
394 |             return -1
395 | 
396 |         # Check if key has expired
397 |         remaining_time = self.timeouts[k] - time.time()
398 |         if remaining_time <= 0:
399 |             # The key has expired, remove it and return -2
400 |             del client.table[key]
401 |             del self.timeouts[k]
402 |             return -2
403 | 
404 |         # Convert seconds to milliseconds
405 |         return int(remaining_time * 1000)
406 | 
407 |     def handle_randomkey(self, client):
408 |         self.log(client, "RANDOMKEY")
409 |         if len(client.table):
410 |             return choice(list(client.table.keys()))  # RANDOMKEY returns a key name
411 |         return EMPTY_SCALAR
412 | 
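    # TTL bookkeeping note: self.timeouts is keyed by the string f"{db} {key}",
    # so expirations are tracked per database. For example, EXPIRE foo 10 on
    # DB 0 stores {"0 foo": time.time() + 10}; RENAME must move that entry to
    # "0 newkey" (as handle_rename does below) or the TTL would be lost.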
413 |     def handle_rename(self, client, key, newkey):
414 |         client.table[newkey] = client.table[key]
415 |         k = f"{client.db} {key}"
416 |         # transfer TTL to the new name
417 |         if k in self.timeouts:
418 |             self.timeouts[f"{client.db} {newkey}"] = self.timeouts[k]
419 |             del self.timeouts[k]
420 |         del client.table[key]
421 |         self.log(client, f"RENAME {key} -> {newkey}")
422 |         return True
423 | 
424 |     def handle_renamenx(self, client, key, newkey):
425 |         self.log(client, f"RENAMENX {key} -> {newkey}")
426 |         if newkey not in client.table:
427 |             self.handle_rename(client, key, newkey)
428 |             return 1
429 |         return 0
430 | 
431 |     # def handle_sort(self, client, key, *args)
432 | 
433 |     def handle_ttl(self, client, key):
434 |         """Get the time to live for a key in seconds"""
435 |         # First check if key exists at all
436 |         if key not in client.table:
437 |             return -2
438 | 
439 |         # Check if key has expired
440 |         k = f"{client.db} {key}"
441 |         if k in self.timeouts:
442 |             remaining_time = self.timeouts[k] - time.time()
443 |             if remaining_time <= 0:
444 |                 # The key has expired, remove it and return -2
445 |                 del client.table[key]
446 |                 del self.timeouts[k]
447 |                 return -2
448 |             return int(remaining_time + 0.1)  # Add small buffer to avoid truncation issues
449 |         return -1  # Key exists but has no expiration
450 | 
451 |     def handle_type(self, client, key):
452 |         if key not in client.table:
453 |             return RedisMessage("none")
454 | 
455 |         data = client.table[key]
456 |         if isinstance(data, deque):
457 |             return RedisMessage("list")
458 |         elif isinstance(data, set):
459 |             return RedisMessage("set")
460 |         elif isinstance(data, dict):
461 |             return RedisMessage("hash")
462 |         elif isinstance(data, (str, int)):  # INCRBY stores plain ints
463 |             return RedisMessage("string")
464 |         else:
465 |             return RedisError("unknown data type")
466 | 
467 |     # def handle_scan(self, client, *args)
468 | 
469 |     # Strings
470 | 
471 |     def handle_append(self, client, key, value):
472 |         if key not in client.table:
473 |             self.handle_set(client, key, value)
474 |             return len(client.table[key])
475 |         data = client.table[key]
476 |         if isinstance(data, str):
477 |             self.handle_persist(client, key)
478 |             client.table[key] += value
479 |             self.log(client, f"APPEND {key} -> {len(client.table[key])}")
480 |             return len(client.table[key])
481 |         return BAD_VALUE
482 | 
483 |     # def handle_bitcount(self, client, key, start, end)
484 |     # def handle_bitop(self, client, *args)
485 | 
486 |     def handle_decr(self, client, key):
487 |         self.check_ttl(client, key)
488 |         return self.handle_decrby(client, key, 1)
489 | 
490 |     def handle_decrby(self, client, key, by):
491 |         self.check_ttl(client, key)
492 |         return self.handle_incrby(client, key, -int(by))
493 | 
494 |     def handle_get(self, client, key):
495 |         self.check_ttl(client, key)
496 |         data = client.table.get(key, None)
497 |         if isinstance(data, (deque, dict, set)):
498 |             return BAD_VALUE
499 |         if data is not None:
500 |             data = str(data)
501 |         else:
502 |             data = EMPTY_SCALAR
503 |         self.log(client, f"GET {key} -> {len(data) if data != EMPTY_SCALAR else 0}")
504 |         return data
505 | 
506 |     # def handle_getbit(self, client, key, offset):
507 |     # def handle_getrange(self, client, key, start, end):
508 | 
509 |     def handle_getset(self, client, key, data):
510 |         self.handle_persist(client, key)
511 |         old_data = client.table.get(key, None)
512 |         if isinstance(old_data, (deque, dict, set)):
513 |             return BAD_VALUE
514 |         if old_data is not None:
515 |             old_data = str(old_data)
516 |         else:
517 |             old_data = EMPTY_SCALAR
518 |         client.table[key] = data
519 |         self.log(client, f"GETSET {key} {data} -> {old_data}")
520 |         return old_data
521 | 
522 |     def handle_incr(self, client, key):
523 |         self.check_ttl(client, key)
524 |         return self.handle_incrby(client, key, 1)
525 | 
526 |     def 
handle_incrby(self, client, key, by): 527 | """Increment key by specified value, properly handling nonexistent keys""" 528 | self.check_ttl(client, key) 529 | try: 530 | by = int(by) # Make sure by is a valid integer 531 | 532 | # If key exists, make sure it's an integer 533 | if key in client.table: 534 | try: 535 | # Try to convert current value to int 536 | current_value = int(client.table[key]) 537 | except (TypeError, ValueError): 538 | # Current value is not an integer, raise proper Redis error 539 | return RedisError("value is not an integer") 540 | 541 | # Value is a valid integer, proceed with increment 542 | client.table[key] = current_value + by 543 | else: 544 | # Key doesn't exist, set to the increment value 545 | client.table[key] = by 546 | 547 | self.log(client, f"INCRBY {key} {by} -> {client.table[key]}") 548 | return client.table[key] 549 | 550 | except (TypeError, ValueError): 551 | # By argument is not a valid integer 552 | return RedisError("value is not an integer or out of range") 553 | 554 | # def handle_incrbyfloat(self, client, key, by): 555 | 556 | def handle_mget(self, client, *keys): 557 | """Multi-get - return values for multiple keys""" 558 | result = [] 559 | for k in keys: 560 | self.check_ttl(client, k) 561 | if k not in client.table: 562 | result.append(None) # Return None for non-existent keys 563 | continue 564 | 565 | data = client.table.get(k, None) 566 | # Skip non-string values, returning None 567 | if not isinstance(data, str): 568 | result.append(None) 569 | continue 570 | 571 | result.append(str(data)) 572 | 573 | self.log(client, f"MGET {keys} -> {len(result)} values") 574 | return result 575 | 576 | # def handle_mset(self, client, *args): 577 | # def handle_msetnx(self, client, *args): 578 | # def handle_psetex(self, client, key, ms, value): 579 | 580 | def handle_set(self, client, key, data): 581 | self.handle_persist(client, key) 582 | client.table[key] = data 583 | self.log(client, f"SET {key} -> {len(data)}") 584 | return True 585 | 586 | # def handle_setbit(self, client, key, offset, value) 587 | 588 | def handle_setex(self, client, key, seconds, data): 589 | """Set the value and expiration of a key""" 590 | # Validate seconds parameter is a valid integer 591 | try: 592 | seconds = int(seconds) 593 | if seconds <= 0: 594 | return RedisError("invalid expire time in 'setex' command") 595 | except (ValueError, TypeError): 596 | return RedisError("value is not an integer or out of range") 597 | 598 | # First set the key value 599 | self.handle_set(client, key, data) 600 | # Then set expiration 601 | self.handle_expire(client, key, seconds) 602 | # Return OK to match Redis standard 603 | return RedisMessage("OK") 604 | 605 | def handle_setnx(self, client, key, data): 606 | if key in client.table: 607 | self.log(client, f"SETNX {key} -> {len(data)} FAILED") 608 | return 0 609 | client.table[key] = data 610 | self.log(client, f"SETNX {key} -> {len(data)}") 611 | return 1 612 | 613 | # def handle_setrange(self, client, key, offset, value) 614 | # def handle_strlen(self, client, key) 615 | 616 | # Hashes 617 | 618 | # Lists 619 | 620 | # def handle_blpop(self, client, *args) 621 | # def handle_brpop(self, client, *args) 622 | # def handle_brpoplpush(self, client, *args) 623 | 624 | # def handle_lindex(self, client, key, index) 625 | # def handle_linsert(self, client, key, *args) 626 | 627 | def handle_llen(self, client, key): 628 | self.check_ttl(client, key) 629 | if key not in client.table: 630 | return 0 631 | if not 
isinstance(client.table[key], deque):
632 |             return BAD_VALUE
633 |         return len(client.table[key])
634 | 
635 |     def handle_lpop(self, client, key):
636 |         self.check_ttl(client, key)
637 |         if key not in client.table:
638 |             return EMPTY_SCALAR
639 |         if not isinstance(client.table[key], deque):
640 |             return BAD_VALUE
641 |         if len(client.table[key]) > 0:
642 |             data = client.table[key].popleft()
643 |         else:
644 |             data = EMPTY_SCALAR
645 |         self.log(client, f"LPOP {key} -> {data}")
646 |         return data
647 | 
648 |     def handle_lpush(self, client, key, data):
649 |         self.check_ttl(client, key)
650 |         if key not in client.table:
651 |             client.table[key] = deque()
652 |         elif not isinstance(client.table[key], deque):
653 |             return BAD_VALUE
654 |         client.table[key].appendleft(data)
655 |         self.log(client, f"LPUSH {key} {data}")
656 |         return True
657 | 
658 |     # def handle_lpushx(self, client, key, data):
659 | 
660 |     def handle_lrange(self, client, key, start, stop):
661 |         self.check_ttl(client, key)
662 |         start, stop = int(start), int(stop)
663 |         # Redis treats stop as inclusive; map it to Python's exclusive slice end
664 |         end = None if stop == -1 else stop + 1
665 |         if key not in client.table:
666 |             return EMPTY_LIST
667 |         if not isinstance(client.table[key], deque):
668 |             return BAD_VALUE
669 |         l = list(client.table[key])[start:end]
670 |         self.log(client, f"LRANGE {key} {start} {stop} -> {l}")
671 |         return l
672 | 
673 |     # def handle_lrem(self, client, key, start, stop):
674 |     # def handle_lset(self, client, key, index, value):
675 |     # def handle_ltrim(self, client, key, start, stop):
676 | 
677 |     def handle_rpop(self, client, key):
678 |         self.check_ttl(client, key)
679 |         if key not in client.table:
680 |             return EMPTY_SCALAR
681 |         if not isinstance(client.table[key], deque):
682 |             return BAD_VALUE
683 |         if len(client.table[key]) > 0:
684 |             data = client.table[key].pop()
685 |         else:
686 |             data = EMPTY_SCALAR
687 |         self.log(client, f"RPOP {key} -> {data}")
688 |         return data
689 | 
690 |     # def handle_rpoplpush(self, source, destination)
691 | 
692 |     def handle_rpush(self, client, key, data):
693 |         self.check_ttl(client, key)
694 |         if key not in client.table:
695 |             client.table[key] = deque()
696 |         elif not isinstance(client.table[key], deque):
697 |             return BAD_VALUE
698 |         client.table[key].append(data)
699 |         self.log(client, f"RPUSH {key} {data}")
700 |         return True
701 | 
702 |     # def handle_rpushx(self, client, key, data)
703 | 
704 |     # Hashes (TODO: add type checks)
705 | 
706 |     def handle_hdel(self, client, key, *keys):
707 |         if key not in client.table:
708 |             return 0
709 |         self.check_ttl(client, key)
710 |         count = 0
711 |         for field in keys:
712 |             if field in client.table[key]:
713 |                 del client.table[key][field]
714 |                 count += 1
715 |         return count
716 | 
717 |     def handle_hexists(self, client, key, field):
718 |         if key not in client.table:
719 |             return 0
720 |         self.check_ttl(client, key)
721 |         return 1 if field in client.table[key] else 0
722 | 
723 |     def handle_hget(self, client, key, field):
724 |         if key not in client.table:
725 |             return None  # Missing key reads as nil, matching the missing-field case below
726 |         self.check_ttl(client, key)
727 |         return client.table[key][field] if field in client.table[key] else None
728 | 
729 |     def handle_hgetall(self, client, key):
730 |         self.check_ttl(client, key)
731 |         try:
732 |             return client.table[key]
733 |         except KeyError:
734 |             return []
735 | 
736 |     def handle_hincrby(self, client, key, field, increment):
737 |         increment = int(increment)  # Convert to int for Python 3
738 |         if key not in client.table:
739 |             client.table[key] = {}
740 |         prev = int(client.table[key].get(field, "0"))
741 | 
742 |         client.table[key][field] = str(prev + increment)
743 |         return client.table[key][field]
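An aside on the hash handlers above: field values are kept as strings and HINCRBY round-trips through `int()`, mirroring Redis's everything-is-a-string model. A minimal standalone sketch (the `table` dict and `hincrby` helper here are illustrative, not part of server.py):

```python
# Illustrative only: hashes are plain dicts whose values stay strings
table = {}

def hincrby(table, key, field, increment):
    increment = int(increment)
    h = table.setdefault(key, {})
    # Read the current value as an int, then store the result back as a string
    h[field] = str(int(h.get(field, "0")) + increment)
    return h[field]

assert hincrby(table, "user:1", "visits", 1) == "1"
assert hincrby(table, "user:1", "visits", "41") == "42"
```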
744 | 
745 |     # def handle_hincrbyfloat(self, client, key, field, increment):
746 | 
747 |     def handle_hkeys(self, client, key):
748 |         if key not in client.table:
749 |             return []
750 |         return list(client.table[key].keys())  # Convert keys view to list
751 | 
752 |     def handle_hlen(self, client, key):
753 |         self.check_ttl(client, key)
754 |         return len(client.table[key])
755 | 
756 |     def handle_hmget(self, client, key, *fields):
757 |         self.check_ttl(client, key)
758 |         return [client.table[key].get(f) for f in fields]
759 | 
760 |     def handle_hmset(self, client, key, items):
761 |         self.check_ttl(client, key)
762 |         for k, v in items.items():
763 |             client.table[key][k] = v  # Use direct indexing
764 |         return True
765 | 
766 |     def handle_hset(self, client, key, field, value):
767 |         self.check_ttl(client, key)
768 |         if key not in client.table:
769 |             client.table[key] = {}
770 |         if field not in client.table[key]:
771 |             client.table[key][field] = value
772 |             return 1
773 |         client.table[key][field] = value
774 |         return 0
775 | 
776 |     # def handle_hsetnx(self, client, key, field, value)
777 | 
778 |     def handle_hvals(self, client, key):
779 |         if key not in client.table:
780 |             return []
781 |         return list(client.table[key].values())  # Convert values view to list
782 | 
783 |     # def hscan(self, client, key, cursor, *args)
784 | 
785 |     # Server
786 | 
787 |     def handle_bgsave(self, client):
788 |         if hasattr(os, "fork"):
789 |             if not os.fork():
790 |                 self.save()
791 |                 sys.exit(0)
792 |         else:
793 |             self.save()
794 |         self.log(client, "BGSAVE")
795 |         return RedisMessage("Background saving started")
796 | 
797 |     def handle_flushdb(self, client):
798 |         self.log(client, "FLUSHDB")
799 |         client.table.clear()
800 |         return True
801 | 
802 |     def handle_flushall(self, client):
803 |         self.log(client, "FLUSHALL")
804 |         for table in self.tables.values():  # Use values() instead of itervalues()
805 |             table.clear()
806 |         return True
807 | 
808 |     def handle_lastsave(self, client):
809 |         return self.lastsave
810 | 
811 |     def handle_ping(self, client):
812 |         self.log(client, "PING -> PONG")
813 |         return RedisMessage("PONG")
814 | 
815 |     def handle_quit(self, client):
816 |         try:
817 |             client.socket.shutdown(socket.SHUT_RDWR)
818 |         except OSError:
819 |             pass  # Socket might already be closed
820 |         client.socket.close()
821 |         self.log(client, "QUIT")
822 |         if client.socket in self.clients:
823 |             del self.clients[client.socket]
824 |         return False
825 | 
826 |     def handle_save(self, client):
827 |         self.save()
828 |         self.log(client, "SAVE")
829 |         return True
830 | 
831 |     def handle_select(self, client, db):
832 |         db = int(db)
833 |         self.select(client, db)
834 |         self.log(client, f"SELECT {db}")
835 |         return True
836 | 
837 |     # PubSub
838 | 
839 |     def handle_publish(self, client, channel, message):
840 |         count = 0
841 |         for p in self.channels.keys():
842 |             if re.match(p, channel):
843 |                 for c in self.channels[p]:
844 |                     c.wfile.write(b"*3\r\n")
845 |                     c.wfile.write(f'${len("message")}\r\n'.encode())
846 |                     c.wfile.write(b"message\r\n")
847 |                     c.wfile.write(f"${len(channel)}\r\n".encode())
848 |                     c.wfile.write(channel.encode() + b"\r\n")
849 |                     c.wfile.write(f"${len(message)}\r\n".encode())
850 |                     c.wfile.write(message.encode() + b"\r\n")
851 |                     count += 1
852 |         return count
853 | 
854 |     def handle_subscribe(self, client, *channels):
855 |         count = 0
856 |         for c in channels:
857 |             if c not in self.channels:
858 |                 self.channels[c] = []
859 |             self.channels[c].append(client)
860 |             count += 1
861 |         return count
862 | 
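An aside on `handle_publish` above: plain channels and patterns share one dict, and every subscription name is matched against the published channel with `re.match()`, which only anchors at the start (a prefix match). Real Redis uses glob-style patterns for PSUBSCRIBE, not regexes. A standalone sketch of that fan-out (the `channels` dict and `fanout` helper are illustrative only):

```python
import re

# Illustrative only: subscription name -> list of subscriber ids
channels = {"news": ["client_a"], "news.*": ["client_b"]}

def fanout(channels, channel):
    """Collect subscribers whose subscription name regex-matches the channel."""
    receivers = []
    for pattern, subscribers in channels.items():
        if re.match(pattern, channel):
            receivers.extend(subscribers)
    return receivers

# "news.*" also matches "news", since ".*" can match zero characters...
assert fanout(channels, "news") == ["client_a", "client_b"]
# ...and the plain "news" subscription prefix-matches "news.sports".
assert fanout(channels, "news.sports") == ["client_a", "client_b"]
```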
863 |     def handle_unsubscribe(self, client, *channels):
864 |         # If no channels provided, unsubscribe from all
865 |         if not channels:
866 |             channels_to_remove = []
867 |             for c, clients in self.channels.items():
868 |                 if client in clients:
869 |                     clients.remove(client)
870 |                 if not clients:
871 |                     channels_to_remove.append(c)
872 | 
873 |             # Remove empty channel entries
874 |             for c in channels_to_remove:
875 |                 del self.channels[c]
876 |             return True
877 | 
878 |         # Unsubscribe from specified channels
879 |         count = 0
880 |         for c in channels:
881 |             try:
882 |                 if c in self.channels and client in self.channels[c]:
883 |                     self.channels[c].remove(client)
884 |                     count += 1
885 |                     if not self.channels[c]:  # Clean up empty lists
886 |                         del self.channels[c]
887 |             except (KeyError, ValueError):
888 |                 pass
889 |         return count
890 | 
891 |     def handle_psubscribe(self, client, *patterns):
892 |         # Similar implementation to handle_subscribe but for pattern subscriptions
893 |         count = 0
894 |         for pattern in patterns:
895 |             if pattern not in self.channels:
896 |                 self.channels[pattern] = []
897 |             self.channels[pattern].append(client)
898 |             count += 1
899 |         return count
900 | 
901 |     def handle_punsubscribe(self, client, *patterns):
902 |         # If no patterns provided, unsubscribe from all patterns
903 |         if not patterns:
904 |             patterns_to_remove = []
905 |             for p, clients in self.channels.items():
906 |                 # We only want to remove from patterns, not regular channels.
907 |                 # This is a simplification; in real Redis this is more complex.
908 |                 if p.startswith("*") or p.startswith("?"):
909 |                     if client in clients:
910 |                         clients.remove(client)
911 |                     if not clients:
912 |                         patterns_to_remove.append(p)
913 | 
914 |             # Remove empty pattern entries
915 |             for p in patterns_to_remove:
916 |                 del self.channels[p]
917 |             return True
918 | 
919 |         # Unsubscribe from specified patterns
920 |         count = 0
921 |         for p in patterns:
922 |             try:
923 |                 if p in self.channels and client in self.channels[p]:
924 |                     self.channels[p].remove(client)
925 |                     count += 1
926 |                     if not self.channels[p]:  # Clean up empty lists
927 |                         del self.channels[p]
928 |             except (KeyError, ValueError):
929 |                 pass
930 |         return count
931 | 
932 |     def handle_shutdown(self, client):
933 |         self.log(client, "SHUTDOWN")
934 |         self.halt = True
935 |         self.save()
936 |         return self.handle_quit(client)
937 | 
938 | 
939 | class ThreadedRedisServer(RedisServer):
940 |     """
941 |     # for use in an accept() loop:
942 |     import threading, socket
943 |     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
944 |     sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
945 |     sock.bind(('127.0.0.1', 6379))
946 |     sock.listen(5)
947 |     t = ThreadedRedisServer()
948 |     while True:
949 |         clientsock, addr = sock.accept()
950 |         thread = threading.Thread(target=t.thread, args=(clientsock, addr))
951 |         thread.daemon = True
952 |         thread.start()
953 |     """
954 | 
955 |     def __init__(self, **kwargs):
956 |         super().__init__(**kwargs)
957 | 
958 |     def thread(self, sock, address):
959 |         client = RedisConnection(sock)
960 |         self.clients[sock] = client
961 |         self.log(client, "client connected")
962 |         self.select(client, 0)
963 |         while not self.halt:
964 |             try:
965 |                 self.handle(self.clients[sock])
966 |             except (socket.error, EOFError, ConnectionResetError) as e:
967 |                 self.log(client, f"client connection error in thread: {e}")
968 |                 break
969 |             except Exception as e:
970 |                 self.log(client, f"exception in thread: {e}")
971 |                 break
972 |         try:
973 |             if sock in self.clients:
974 |                 self.handle_quit(client)
975 |         except Exception as e:
976 |             log.debug(f"Error during quit in thread: {e}")
977 |         if sock in
self.clients: 978 | del self.clients[sock] 979 | try: 980 | sock.close() 981 | except socket.error: 982 | pass 983 | 984 | 985 | def fork(**kwargs): 986 | if not hasattr(os, "fork"): 987 | print( 988 | "Fork not supported on this OS. Consider using multiprocessing or threading.", 989 | file=sys.stderr, 990 | ) 991 | sys.exit(1) 992 | 993 | try: 994 | pid = os.fork() 995 | if pid > 0: 996 | return pid 997 | m = RedisServer(**kwargs) 998 | m.run() 999 | sys.exit(0) 1000 | except KeyboardInterrupt: 1001 | pass 1002 | except OSError as e: 1003 | print( 1004 | f"Failed to launch Redis subprocess: {e.errno} ({e.strerror})", 1005 | file=sys.stderr, 1006 | ) 1007 | sys.exit(1) 1008 | 1009 | 1010 | def main(args): 1011 | global m 1012 | m = None 1013 | 1014 | if os.name == "posix": 1015 | 1016 | def sigterm_handler(signum, frame): 1017 | if m: 1018 | m.stop() 1019 | sys.exit(0) 1020 | 1021 | def sighup_handler(signum, frame): 1022 | if m: 1023 | m.rotate() 1024 | 1025 | signal.signal(signal.SIGTERM, sigterm_handler) 1026 | signal.signal(signal.SIGHUP, sighup_handler) 1027 | 1028 | host, port, log_file, db_path = "127.0.0.1", 6379, None, "." 1029 | pid_file = None 1030 | try: 1031 | opts, args = getopt.getopt(args, "h:p:d:l:f:") 1032 | except getopt.GetoptError as err: 1033 | print(str(err), file=sys.stderr) 1034 | sys.exit(2) 1035 | 1036 | for o, a in opts: 1037 | if o == "-h": 1038 | host = a 1039 | elif o == "-p": 1040 | port = int(a) 1041 | elif o == "-l": 1042 | log_file = os.path.abspath(a) 1043 | elif o == "-d": 1044 | db_path = os.path.abspath(a) 1045 | elif o == "-f": 1046 | pid_file = os.path.abspath(a) 1047 | 1048 | log_level = logging.INFO 1049 | log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 1050 | if log_file: 1051 | logging.basicConfig(filename=log_file, level=log_level, format=log_format) 1052 | else: 1053 | logging.basicConfig(level=log_level, format=log_format) 1054 | 1055 | if pid_file: 1056 | try: 1057 | with open(pid_file, "w") as f: 1058 | f.write(f"{os.getpid()}\n") 1059 | except IOError as e: 1060 | log.error(f"Could not write PID file {pid_file}: {e}") 1061 | pid_file = None 1062 | 1063 | m = RedisServer(host=host, port=port, db_path=db_path) 1064 | log.info(f"Starting miniredis server on {host}:{port}, DB path: {db_path}") 1065 | try: 1066 | m.run() 1067 | except KeyboardInterrupt: 1068 | log.info("KeyboardInterrupt received, stopping server.") 1069 | if m: 1070 | m.stop() 1071 | except Exception as e: 1072 | log.exception("Unhandled exception in server run loop") 1073 | if m: 1074 | m.stop() 1075 | finally: 1076 | if pid_file and os.path.exists(pid_file): 1077 | try: 1078 | os.unlink(pid_file) 1079 | except OSError as e: 1080 | log.error(f"Could not remove PID file {pid_file}: {e}") 1081 | log.info("Server stopped.") 1082 | sys.exit(0) 1083 | 1084 | 1085 | if __name__ == "__main__": 1086 | main(sys.argv[1:]) 1087 | -------------------------------------------------------------------------------- /miniredis/sset.py: -------------------------------------------------------------------------------- 1 | from bisect import bisect_left, bisect_right 2 | from collections.abc import MutableSet # Use ABC for abstract base class 3 | 4 | 5 | class SortedSet(MutableSet): 6 | """ 7 | Redis-style SortedSet implementation using bisect. 8 | 9 | Maintains two internal data structures: 10 | 1. A sorted list of (score, member) pairs. 11 | 2. A dictionary from member to score. 12 | 13 | Note: Insertion and removal are O(N) due to list insertion/deletion. 
14 | For better performance (O(log N)), consider alternative structures 15 | like balanced trees or skip lists if N becomes large. 16 | """ 17 | def __init__(self, iterable=None): 18 | """ 19 | Create a sorted set. Optional iterable can initialize the set. 20 | Iterable should yield (member, score) pairs. 21 | """ 22 | # sorted list of (score, member) 23 | self._scores = [] 24 | # dictionary from member to score 25 | self._members = {} 26 | if iterable is not None: 27 | for member, score in iterable: 28 | self.add(member, score) 29 | 30 | # Required MutableSet abstract methods 31 | def add(self, member, score=0.0): 32 | """ 33 | Add member with score. If member is already present, 34 | update its score. Conforms to Set.add signature partially, 35 | but requires score. 36 | Returns True if member was added, False if updated. 37 | """ 38 | try: 39 | # Convert score to float for consistent comparison 40 | score = float(score) 41 | except (ValueError, TypeError): 42 | raise ValueError("Score must be a float or representable as a float") 43 | 44 | found = member in self._members 45 | if found: 46 | self._remove_existing(member) 47 | 48 | # Find insertion point and insert 49 | index = bisect_left(self._scores, (score, member)) 50 | self._scores.insert(index, (score, member)) 51 | self._members[member] = score 52 | return not found 53 | 54 | def discard(self, member): 55 | """ 56 | Remove member from the set if it is present. 57 | Returns True if member was removed, False otherwise. 58 | Conforms to Set.discard signature. 59 | """ 60 | if member not in self._members: 61 | return False 62 | self._remove_existing(member) 63 | return True 64 | 65 | def __contains__(self, member): 66 | return member in self._members 67 | 68 | def __len__(self): 69 | return len(self._members) 70 | 71 | def __iter__(self): 72 | """Iterate over members in score order.""" 73 | for _, member in self._scores: 74 | yield member 75 | 76 | # --- End Required MutableSet methods --- 77 | 78 | def _remove_existing(self, member): 79 | """Internal helper to remove an existing member.""" 80 | score = self._members[member] 81 | # Find the exact (score, member) pair to remove 82 | # bisect_left gives the insertion point, which is the index if found 83 | score_index = bisect_left(self._scores, (score, member)) 84 | # Verify we found the correct item before deleting 85 | if score_index < len(self._scores) and self._scores[score_index] == (score, member): 86 | del self._scores[score_index] 87 | del self._members[member] 88 | else: 89 | # This indicates an internal inconsistency, should not happen 90 | raise RuntimeError(f"Internal state inconsistency: could not find {member} with score {score} for removal") 91 | 92 | 93 | def clear(self): 94 | """ 95 | Remove all members and scores from the sorted set. 96 | """ 97 | self._scores = [] 98 | self._members = {} 99 | 100 | def __str__(self): 101 | # Represent as a set of members for clarity, though order is lost 102 | return "{{{}}}".format(", ".join(repr(m) for m in self)) 103 | 104 | def __repr__(self): 105 | # Show the internal structure for debugging 106 | return "SortedSet([{}])".format(", ".join(f"({s!r}, {m!r})" for s, m in self._scores)) 107 | 108 | # Redis-specific methods 109 | def zadd(self, score, member, *args): 110 | """ 111 | Adds members with scores. Handles multiple score-member pairs. 112 | Returns the number of elements added (not updated). 
113 | """ 114 | if len(args) % 2 != 0: 115 | raise ValueError("ZADD requires score-member pairs") 116 | 117 | added_count = 0 118 | pairs = [(score, member)] + list(zip(args[::2], args[1::2])) 119 | 120 | for s, m in pairs: 121 | if self.add(m, s): 122 | added_count += 1 123 | return added_count 124 | 125 | def zrem(self, *members): 126 | """ 127 | Removes members from the sorted set. 128 | Returns the number of members removed. 129 | """ 130 | removed_count = 0 131 | for member in members: 132 | if self.discard(member): 133 | removed_count += 1 134 | return removed_count 135 | 136 | def zscore(self, member): 137 | """ 138 | Get the score for a member. 139 | Returns score (as float) or None if member not found. 140 | """ 141 | return self._members.get(member) 142 | 143 | def zrank(self, member): 144 | """ 145 | Get the rank (0-based index) of a member, ordered by score (ascending). 146 | Returns rank or None if member not found. 147 | """ 148 | score = self._members.get(member) 149 | if score is None: 150 | return None 151 | # Find the first occurrence of this score 152 | index = bisect_left(self._scores, (score, member)) 153 | # Verify it's the correct member 154 | if index < len(self._scores) and self._scores[index] == (score, member): 155 | return index 156 | else: 157 | # Should not happen if member is in _members 158 | raise RuntimeError(f"Internal state inconsistency: could not find rank for {member}") 159 | 160 | def zrevrank(self, member): 161 | """ 162 | Get the rank (0-based index) of a member, ordered by score (descending). 163 | Returns rank or None if member not found. 164 | """ 165 | rank = self.zrank(member) 166 | return (len(self) - 1 - rank) if rank is not None else None 167 | 168 | def _parse_range_args(self, start, end): 169 | """Helper to parse Redis-style range arguments.""" 170 | try: 171 | start = int(start) 172 | end = int(end) 173 | except (ValueError, TypeError): 174 | raise ValueError("start and end must be integers") 175 | 176 | length = len(self) 177 | # Convert negative indices 178 | if start < 0: 179 | start = length + start 180 | if end < 0: 181 | end = length + end 182 | 183 | # Clamp indices to valid range [0, length-1] 184 | start = max(0, start) 185 | # end is inclusive, so clamp to length-1 186 | end = min(length - 1, end) 187 | 188 | return start, end, length 189 | 190 | def zrange(self, start, end, withscores=False, desc=False): 191 | """ 192 | Return members (and optionally scores) in the specified range of ranks. 193 | start and end are 0-based indices, inclusive. 194 | Negative indices count from the end (-1 is the last element). 195 | """ 196 | start, end, length = self._parse_range_args(start, end) 197 | 198 | if start > end or start >= length: 199 | return [] # Empty range 200 | 201 | # Python slice end index is exclusive, Redis is inclusive 202 | slice_end = end + 1 203 | 204 | if desc: 205 | # Calculate reversed indices for slicing 206 | rev_start = length - slice_end 207 | rev_end = length - start 208 | items = reversed(self._scores[rev_start:rev_end]) 209 | else: 210 | items = self._scores[start:slice_end] 211 | 212 | if withscores: 213 | # Return list of [member, score_str] 214 | return [[m, str(s)] for s, m in items] 215 | else: 216 | # Return list of members 217 | return [m for s, m in items] 218 | 219 | def zrevrange(self, start, end, withscores=False): 220 | """ 221 | Return members (and optionally scores) in the specified range of ranks, 222 | ordered from highest to lowest score. 
223 |         """
224 |         return self.zrange(start, end, withscores=withscores, desc=True)
225 | 
226 |     def _parse_score_range_args(self, min_score, max_score):
227 |         """Helper to parse score range arguments."""
228 |         min_inclusive = True
229 |         max_inclusive = True
230 | 
231 |         if isinstance(min_score, str) and min_score.startswith('('):
232 |             min_inclusive = False
233 |             min_score = min_score[1:]
234 | 
235 |         if isinstance(max_score, str) and max_score.startswith('('):
236 |             max_inclusive = False
237 |             max_score = max_score[1:]
238 | 
239 |         try:
240 |             if min_score == '-inf':
241 |                 min_f = float('-inf')
242 |             else:
243 |                 min_f = float(min_score)
244 | 
245 |             if max_score == '+inf':
246 |                 max_f = float('+inf')
247 |             else:
248 |                 max_f = float(max_score)
249 |         except (ValueError, TypeError):
250 |             raise ValueError("min and max scores must be floats or representable as floats")
251 | 
252 |         return min_f, max_f, min_inclusive, max_inclusive
253 | 
254 |     def zrangebyscore(self, min_score, max_score, withscores=False, limit=None):
255 |         """
256 |         Return members (and optionally scores) with scores between min_score and max_score.
257 |         min/max can be exclusive using '(' prefix (e.g., '(1.0').
258 |         '-inf' and '+inf' are valid.
259 |         limit is an optional (offset, count) tuple.
260 |         """
261 |         if not self:
262 |             return []
263 | 
264 |         min_f, max_f, min_incl, max_incl = self._parse_score_range_args(min_score, max_score)
265 | 
266 |         # Bisect on the score column only: comparing a (score, inf) sentinel
267 |         # tuple against (score, member) pairs raises TypeError for string
268 |         # members when scores tie, so extract the scores and search those.
269 |         scores = [s for s, _ in self._scores]
270 |         # Find the start index
271 |         if min_incl:
272 |             left = bisect_left(scores, min_f)  # first element >= min_f
273 |         else:
274 |             left = bisect_right(scores, min_f)  # first element > min_f
275 | 
276 |         # Find the end index
277 |         if max_incl:
278 |             right = bisect_right(scores, max_f)  # first element > max_f
279 |         else:
280 |             right = bisect_left(scores, max_f)  # first element >= max_f
281 | 
282 |         # Slice the relevant portion
283 |         items = self._scores[left:right]
284 | 
285 |         # Apply limit if provided
286 |         if limit is not None:
287 |             try:
288 |                 offset, count = map(int, limit)
289 |                 if offset < 0 or count <= 0:
290 |                     # Redis returns empty list for invalid limit
291 |                     # although some versions might error.
292 |                     # Let's return empty list for simplicity.
293 |                     items = []
294 |                 else:
295 |                     items = items[offset : offset + count]
296 |             except (ValueError, TypeError, IndexError):
297 |                 raise ValueError("limit requires two integer arguments: offset, count")
298 | 
299 |         if withscores:
300 |             return [[m, str(s)] for s, m in items]
301 |         else:
302 |             return [m for s, m in items]
303 | 
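To see why the inclusive/exclusive boundaries in `zrangebyscore` come out right, here is a small worked example over the score column (standalone, with illustrative data):

```python
from bisect import bisect_left, bisect_right

# A sorted (score, member) list as SortedSet keeps internally
pairs = [(1.0, "a"), (2.0, "b"), (2.0, "c"), (3.0, "d")]
scores = [s for s, _ in pairs]

# With tied scores at 2.0, the two bisect variants pick the two boundaries
assert bisect_left(scores, 2.0) == 1   # inclusive min: first score >= 2.0
assert bisect_right(scores, 2.0) == 3  # exclusive min: first score > 2.0

# ZRANGEBYSCORE 2 3 (both bounds inclusive) selects b, c, d
left, right = bisect_left(scores, 2.0), bisect_right(scores, 3.0)
assert [m for _, m in pairs[left:right]] == ["b", "c", "d"]
```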
304 |     def zrevrangebyscore(self, max_score, min_score, withscores=False, limit=None):
305 |         """
306 |         Return members (and optionally scores) with scores between max_score and min_score,
307 |         ordered from highest to lowest score.
308 |         """
309 |         # Get the range in ascending order first
310 |         ascending_items = self.zrangebyscore(min_score, max_score, withscores=True, limit=None)
311 | 
312 |         # Reverse the result
313 |         items = list(reversed(ascending_items))
314 | 
315 |         # Apply limit after reversing
316 |         if limit is not None:
317 |             try:
318 |                 offset, count = map(int, limit)
319 |                 if offset < 0 or count <= 0:
320 |                     items = []
321 |                 else:
322 |                     items = items[offset : offset + count]
323 |             except (ValueError, TypeError, IndexError):
324 |                 raise ValueError("limit requires two integer arguments: offset, count")
325 | 
326 |         # Format output based on withscores
327 |         if withscores:
328 |             return items  # Already in [member, score_str] format
329 |         else:
330 |             return [m for m, s in items]
331 | 
332 |     def zcard(self):
333 |         """
334 |         Return the number of elements in the sorted set.
335 |         """
336 |         return len(self)
337 | 
338 |     def zcount(self, min_score, max_score):
339 |         """
340 |         Return the number of elements with scores between min_score and max_score.
341 |         """
342 |         # Use zrangebyscore logic to find the range and return its length
343 |         if not self:
344 |             return 0
345 | 
346 |         min_f, max_f, min_incl, max_incl = self._parse_score_range_args(min_score, max_score)
347 | 
348 |         # Bisect on the score column only, for the same reason as in zrangebyscore
349 |         scores = [s for s, _ in self._scores]
350 |         if min_incl:
351 |             left = bisect_left(scores, min_f)
352 |         else:
353 |             left = bisect_right(scores, min_f)
354 |         if max_incl:
355 |             right = bisect_right(scores, max_f)
356 |         else:
357 |             right = bisect_left(scores, max_f)
358 | 
359 |         return max(0, right - left)
360 | 
361 |     # --- Methods below are less common or might need refinement ---
362 | 
363 |     def score(self, member):
364 |         """
365 |         Alias for zscore, potentially deprecated.
366 |         Get the score for a member.
367 |         """
368 |         return self.zscore(member)
369 | 
370 |     def rank(self, member):
371 |         """
372 |         Alias for zrank, potentially deprecated.
373 |         Get the rank (index of a member).
374 |         """
375 |         return self.zrank(member)
376 | 
377 |     def range(self, start, end, desc=False, withscores=False):
378 |         """
379 |         Alias for zrange/zrevrange, potentially deprecated.
380 |         Return members/scores between min and max ranks.
381 |         """
382 |         return self.zrange(start, end, desc=desc, withscores=withscores)
383 | 
384 |     def scorerange(self, start, end, withscores=False):
385 |         """
386 |         Alias for zrangebyscore, potentially deprecated.
387 |         Return members/scores between min and max scores.
387 | """ 388 | return self.zrangebyscore(start, end, withscores=withscores) 389 | 390 | def items(self): 391 | """Return an iterator over (score, member) pairs.""" 392 | return iter(self._scores) 393 | 394 | def min_score(self): 395 | """Return the minimum score in the set.""" 396 | if not self: raise IndexError("SortedSet is empty") 397 | return self._scores[0][0] 398 | 399 | def max_score(self): 400 | """Return the maximum score in the set.""" 401 | if not self: raise IndexError("SortedSet is empty") 402 | return self._scores[-1][0] 403 | 404 | # Example Usage: 405 | if __name__ == '__main__': 406 | zs = SortedSet() 407 | zs.zadd(1, 'one') 408 | zs.zadd(3, 'three') 409 | zs.zadd(2, 'two') 410 | zs.zadd(2, 'deux') # Add another member with score 2 411 | 412 | print(f"Set: {zs!r}") 413 | print(f"Members (iter): {list(zs)}") 414 | print(f"Length: {len(zs)}") 415 | print(f"Contains 'two': {'two' in zs}") 416 | print(f"Score of 'three': {zs.zscore('three')}") 417 | print(f"Rank of 'two': {zs.zrank('two')}") # Rank depends on member name for ties 418 | print(f"Rank of 'deux': {zs.zrank('deux')}") 419 | print(f"RevRank of 'one': {zs.zrevrank('one')}") 420 | 421 | print(f"Range 0-2: {zs.zrange(0, 2)}") 422 | print(f"Range 0-2 with scores: {zs.zrange(0, 2, withscores=True)}") 423 | print(f"RevRange 0-1: {zs.zrevrange(0, 1)}") 424 | 425 | print(f"Range by score 1.5-2.5: {zs.zrangebyscore(1.5, 2.5)}") 426 | print(f"Range by score (1-3: {zs.zrangebyscore('(1', '3')}") 427 | print(f"Range by score 1-3 with scores: {zs.zrangebyscore(1, 3, withscores=True)}") 428 | print(f"Count score 1-2: {zs.zcount(1, 2)}") 429 | 430 | zs.discard('two') 431 | print(f"After discarding 'two': {zs!r}") 432 | print(f"Length: {len(zs)}") 433 | 434 | 435 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Copyright (c) 2013, Rui Carmo 6 | Description: Experimental Cython compile script 7 | License: MIT (see LICENSE.md for details) 8 | """ 9 | 10 | import os, sys 11 | from setuptools import setup, find_packages # Use setuptools and find_packages 12 | 13 | setup( 14 | name="miniredis", 15 | version="0.1.0", 16 | packages=find_packages(exclude=["tests*"]), # Automatically find packages 17 | setup_requires=['pytest-runner'], # Update test runner 18 | tests_require=['pytest'], # Update test dependencies 19 | author="Rui Carmo", 20 | author_email="rui@example.com", # Placeholder email 21 | description="Pure Python Redis protocol subset implementation", 22 | long_description=open("README.md").read(), 23 | long_description_content_type="text/markdown", 24 | license="MIT", 25 | keywords="redis server mock test", 26 | url="https://github.com/rcarmo/miniredis", 27 | classifiers=[ 28 | "Development Status :: 4 - Beta", 29 | "Intended Audience :: Developers", 30 | "License :: OSI Approved :: MIT License", 31 | "Operating System :: OS Independent", 32 | "Programming Language :: Python :: 3", 33 | "Programming Language :: Python :: 3.7", 34 | "Programming Language :: Python :: 3.8", 35 | "Programming Language :: Python :: 3.9", 36 | "Programming Language :: Python :: 3.10", 37 | "Programming Language :: Python :: 3.11", 38 | "Topic :: Software Development :: Testing", 39 | "Topic :: Database :: Front-Ends", 40 | ], 41 | python_requires=">=3.7", 42 | ) 43 | -------------------------------------------------------------------------------- 
/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rcarmo/miniredis/3df6f2cd5bf27c40fc0da543c5bcc69a47021192/tests/__init__.py -------------------------------------------------------------------------------- /tests/helpers.py: -------------------------------------------------------------------------------- 1 | # tests/helpers.py 2 | import os 3 | import sys 4 | import signal 5 | import time 6 | import socket 7 | from multiprocessing import Process 8 | from typing import Tuple 9 | import miniredis.server 10 | from miniredis.client import RedisClient 11 | 12 | # Adjust path to import miniredis from the parent directory 13 | project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 14 | if project_root not in sys.path: 15 | sys.path.insert(0, project_root) 16 | 17 | 18 | def find_free_port() -> int: 19 | """Finds an available port on localhost.""" 20 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 21 | s.bind(('localhost', 0)) 22 | return s.getsockname()[1] 23 | 24 | 25 | # Moved to module level so it can be pickled for multiprocessing 26 | def run_server_process(port: int) -> None: 27 | """Runs the Redis server process.""" 28 | # Signal handler for graceful shutdown within the child process 29 | def sigterm_handler(signum, frame): 30 | sys.exit(0) 31 | signal.signal(signal.SIGTERM, sigterm_handler) 32 | 33 | server: miniredis.server.RedisServer | None = None 34 | try: 35 | server = miniredis.server.RedisServer(port=port) 36 | print(f"Test server starting on port {port}...") 37 | server.run() # This blocks until server stops 38 | except KeyboardInterrupt: 39 | print("Test server received KeyboardInterrupt.") 40 | except Exception as e: 41 | print(f"Error in server process: {e}") 42 | finally: 43 | if server: 44 | print("Stopping test server...") 45 | server.stop() 46 | print(f"Test server on port {port} stopped.") 47 | 48 | 49 | def start_server() -> Tuple[Process, int]: 50 | """Starts the Redis server in a background process on a free port.""" 51 | test_port: int = find_free_port() 52 | 53 | # Start the server process using the module-level function 54 | proc = Process(target=run_server_process, args=(test_port,), daemon=True) 55 | proc.start() 56 | 57 | # Wait for the server to be ready by trying to connect 58 | max_wait: float = 5.0 # seconds 59 | start_wait: float = time.monotonic() 60 | connected: bool = False 61 | conn: socket.socket | None = None 62 | while time.monotonic() - start_wait < max_wait: 63 | if not proc.is_alive(): 64 | raise RuntimeError(f"Server process {proc.pid} terminated unexpectedly.") 65 | try: 66 | # Try to establish a connection 67 | conn = socket.create_connection(('localhost', test_port), timeout=0.1) 68 | conn.close() 69 | connected = True 70 | print(f"Test server with pid {proc.pid} ready on port {test_port}.") 71 | break 72 | except (ConnectionRefusedError, socket.timeout): 73 | time.sleep(0.1) # Wait a bit before retrying 74 | finally: 75 | if conn: 76 | conn.close() 77 | 78 | if not connected: 79 | stop_server(proc) # Clean up the process if connection failed 80 | raise RuntimeError(f"Server process failed to start or become ready on port {test_port} within {max_wait}s.") 81 | 82 | return proc, test_port 83 | 84 | 85 | def stop_server(proc: Process | None) -> None: 86 | """Stops the Redis server process.""" 87 | if proc and proc.is_alive(): 88 | print(f"Terminating test server with pid {proc.pid}.") 89 | proc.terminate() # Send SIGTERM 90 | 
proc.join(timeout=2) # Wait for graceful shutdown 91 | if proc.is_alive(): 92 | print(f"Server process {proc.pid} did not terminate gracefully, killing.") 93 | proc.kill() # Force kill if terminate fails 94 | proc.join(timeout=1) # Wait briefly for kill 95 | print(f"Server process {proc.pid} stopped.") 96 | elif proc: 97 | print(f"Server process {proc.pid} already stopped (exitcode: {proc.exitcode}).") 98 | 99 | -------------------------------------------------------------------------------- /tests/helpers_async.py: -------------------------------------------------------------------------------- 1 | # tests/helpers_async.py 2 | import os 3 | import sys 4 | import signal 5 | import time 6 | import socket 7 | import asyncio 8 | from multiprocessing import Process 9 | from typing import Tuple 10 | import miniredis.aioserver 11 | from miniredis.client import RedisClient 12 | 13 | # Adjust path to import miniredis from the parent directory 14 | project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 15 | if project_root not in sys.path: 16 | sys.path.insert(0, project_root) 17 | 18 | 19 | def find_free_port() -> int: 20 | """Finds an available port on localhost.""" 21 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 22 | s.bind(('localhost', 0)) 23 | return s.getsockname()[1] 24 | 25 | 26 | # Main process function for running the async Redis server 27 | def run_async_server_process(port: int) -> None: 28 | """Runs the Redis server process.""" 29 | # Signal handler for graceful shutdown within the child process 30 | def sigterm_handler(signum, frame): 31 | if loop.is_running(): 32 | loop.stop() 33 | sys.exit(0) 34 | 35 | signal.signal(signal.SIGTERM, sigterm_handler) 36 | 37 | server = None 38 | loop = asyncio.new_event_loop() 39 | asyncio.set_event_loop(loop) 40 | 41 | try: 42 | server = miniredis.aioserver.AsyncRedisServer(port=port) 43 | print(f"Async test server starting on port {port}...") 44 | 45 | # Start the server 46 | async def start_server(): 47 | await server.start() 48 | 49 | loop.run_until_complete(start_server()) 50 | loop.run_forever() 51 | except KeyboardInterrupt: 52 | print("Async test server received KeyboardInterrupt.") 53 | except Exception as e: 54 | print(f"Error in async server process: {e}") 55 | finally: 56 | if server: 57 | print("Stopping async test server...") 58 | if loop.is_running(): 59 | loop.run_until_complete(server.stop()) 60 | else: 61 | # If loop is not running, create a new loop for cleanup 62 | temp_loop = asyncio.new_event_loop() 63 | temp_loop.run_until_complete(server.stop()) 64 | temp_loop.close() 65 | print(f"Async test server on port {port} stopped.") 66 | if loop.is_running(): 67 | loop.stop() 68 | loop.close() 69 | 70 | 71 | def start_async_server() -> Tuple[Process, int]: 72 | """Starts the Redis server in a background process on a free port.""" 73 | test_port: int = find_free_port() 74 | 75 | # Start the server process 76 | proc = Process(target=run_async_server_process, args=(test_port,), daemon=True) 77 | proc.start() 78 | 79 | # Wait for the server to be ready by trying to connect 80 | max_wait: float = 5.0 # seconds 81 | start_wait: float = time.monotonic() 82 | connected: bool = False 83 | conn: socket.socket | None = None 84 | while time.monotonic() - start_wait < max_wait: 85 | if not proc.is_alive(): 86 | raise RuntimeError(f"Async server process {proc.pid} terminated unexpectedly.") 87 | try: 88 | # Try to establish a connection 89 | conn = socket.create_connection(('localhost', test_port), 
timeout=0.1) 90 | conn.close() 91 | connected = True 92 | print(f"Async test server with pid {proc.pid} ready on port {test_port}.") 93 | break 94 | except (ConnectionRefusedError, socket.timeout): 95 | time.sleep(0.1) # Wait a bit before retrying 96 | finally: 97 | if conn: 98 | conn.close() 99 | 100 | if not connected: 101 | stop_async_server(proc) # Clean up the process if connection failed 102 | raise RuntimeError(f"Async server process failed to start or become ready on port {test_port} within {max_wait}s.") 103 | 104 | return proc, test_port 105 | 106 | 107 | def stop_async_server(proc: Process | None) -> None: 108 | """Stops the Redis server process.""" 109 | if proc and proc.is_alive(): 110 | print(f"Terminating async test server with pid {proc.pid}.") 111 | proc.terminate() # Send SIGTERM 112 | proc.join(timeout=2) # Wait for graceful shutdown 113 | if proc.is_alive(): 114 | print(f"Async server process {proc.pid} did not terminate gracefully, killing.") 115 | proc.kill() # Force kill if terminate fails 116 | proc.join(timeout=1) # Wait briefly for kill 117 | print(f"Async server process {proc.pid} stopped.") 118 | elif proc: 119 | print(f"Async server process {proc.pid} already stopped (exitcode: {proc.exitcode}).") -------------------------------------------------------------------------------- /tests/test_keys.py: -------------------------------------------------------------------------------- 1 | # vim :set ts=4 sw=4 sts=4 et : 2 | import os 3 | import sys 4 | import time 5 | import pytest 6 | from multiprocessing import Process 7 | from typing import Generator 8 | 9 | # Changed from relative import to absolute import 10 | from tests.helpers import start_server, stop_server 11 | from miniredis.client import RedisClient 12 | 13 | @pytest.fixture(scope="module") 14 | def redis_client() -> Generator[RedisClient, None, None]: 15 | """Pytest fixture to start/stop the miniredis server and provide a client.""" 16 | server_process: Process | None = None 17 | r_client: RedisClient | None = None 18 | try: 19 | server_process, test_port = start_server() 20 | r_client = RedisClient(port=test_port) 21 | r_client.flushdb() # Flush DB before tests start 22 | yield r_client # Provide the client to the tests 23 | except Exception as e: 24 | print(f"Error during fixture setup in test_keys: {e}") 25 | pytest.fail(f"Fixture setup failed: {e}") # Fail tests if fixture fails 26 | finally: 27 | # Teardown: Stop client and server 28 | print("Tearing down test_keys fixture...") 29 | if r_client: 30 | try: 31 | r_client.close() 32 | print("Redis client closed.") 33 | except Exception as e: 34 | print(f"Error closing redis client: {e}") 35 | if server_process: 36 | stop_server(server_process) 37 | print("Fixture teardown complete.") 38 | 39 | class TestKeysCommands: 40 | 41 | def test_put_get(self, redis_client: RedisClient): 42 | """Test basic SET and GET""" 43 | r = redis_client 44 | assert r.set("test:key", "value") == "OK" 45 | result = r.get("test:key") 46 | assert result == b"value" 47 | assert result.decode("utf-8") == "value" 48 | 49 | def test_get_nonexistent(self, redis_client: RedisClient): 50 | """Test GET on a non-existent key""" 51 | r = redis_client 52 | assert r.get("test:notakey") is None 53 | 54 | def test_del(self, redis_client: RedisClient): 55 | """Test DEL command""" 56 | r = redis_client 57 | r.set("test:keydel1", "value1") 58 | r.set("test:keydel2", "value2") 59 | r.set("test:keydel3", "value3") 60 | # single key 61 | assert r.delete("test:keydel1") == 1 62 | assert 
r.get("test:keydel1") is None 63 | # multiple keys 64 | assert r.delete("test:keydel2", "test:keydel3") == 2 65 | assert r.get("test:keydel2") is None 66 | assert r.get("test:keydel3") is None 67 | # non-existent key 68 | assert r.delete("test:notthere") == 0 69 | 70 | def test_exists(self, redis_client: RedisClient): 71 | """Test EXISTS command""" 72 | r = redis_client 73 | r.set("test:keyexists", "value") 74 | assert r.exists("test:keyexists") == 1 75 | assert r.exists("test:notthere") == 0 76 | 77 | def test_expire_ttl(self, redis_client: RedisClient): 78 | """Test EXPIRE and TTL commands""" 79 | r = redis_client 80 | r.set("test:keyexpire", "value") 81 | # missing key 82 | assert r.expire("test:notthere", 2) == 0 83 | # valid setting 84 | assert r.expire("test:keyexpire", 2) == 1 85 | # TTL should be close to 2 (allow for slight delay) 86 | ttl = r.ttl("test:keyexpire") 87 | assert 1 <= ttl <= 2, f"TTL ({ttl}) not within expected range [1, 2]" 88 | time.sleep(2.1) 89 | assert r.ttl("test:keyexpire") == -2 # Should be expired (-2) 90 | assert r.exists("test:keyexpire") == 0 91 | 92 | # reset ttl with SET 93 | r.set("test:keyexpire_reset", "value") 94 | assert r.expire("test:keyexpire_reset", 5) == 1 95 | assert r.ttl("test:keyexpire_reset") > 0 96 | assert r.set("test:keyexpire_reset", "newvalue") == "OK" 97 | assert r.ttl("test:keyexpire_reset") == -1 # SET should remove TTL 98 | 99 | def test_expireat_pttl(self, redis_client: RedisClient): 100 | """Test EXPIREAT and PTTL commands""" 101 | r = redis_client 102 | r.set("test:keyexpireat", "value") 103 | # missing key 104 | at_ts = int(time.time() + 2) 105 | assert r.expireat("test:notthere_at", at_ts) == 0 106 | # valid setting 107 | at_ts = int(time.time() + 2) 108 | assert r.expireat("test:keyexpireat", at_ts) == 1 109 | # PTTL should be close to 2000ms 110 | pttl = r.pttl("test:keyexpireat") 111 | assert ( 112 | 1000 <= pttl <= 2000 113 | ), f"PTTL ({pttl}) not within expected range [1000, 2000]" 114 | time.sleep(2.1) 115 | assert r.pttl("test:keyexpireat") == -2 # Should be expired (-2) 116 | assert r.exists("test:keyexpireat") == 0 117 | 118 | # reset ttl with SET 119 | r.set("test:keyexpireat_reset", "value") 120 | at_ts = int(time.time() + 5) 121 | assert r.expireat("test:keyexpireat_reset", at_ts) == 1 122 | assert r.pttl("test:keyexpireat_reset") > 0 123 | assert r.set("test:keyexpireat_reset", "newvalue") == "OK" 124 | assert r.pttl("test:keyexpireat_reset") == -1 # SET should remove TTL 125 | 126 | def test_keys(self, redis_client: RedisClient): 127 | """Test KEYS command""" 128 | r = redis_client 129 | # Clear previous keys potentially matching pattern 130 | r.flushdb() 131 | # place test keys 132 | r.set("test:keys:key1", "value") 133 | r.set("test:keys:key2", "value") 134 | r.set("other:keys:key3", "value") 135 | # KEYS returns list of bytes 136 | result_bytes = r.keys("test:keys:*") 137 | result_strings = sorted([k.decode("utf-8") for k in result_bytes]) 138 | assert result_strings == ["test:keys:key1", "test:keys:key2"] 139 | 140 | result_bytes = r.keys("*:key?") 141 | result_strings = sorted([k.decode("utf-8") for k in result_bytes]) 142 | assert result_strings == ["other:keys:key3", "test:keys:key1", "test:keys:key2"] 143 | 144 | result_bytes = r.keys("*nomatch*") 145 | assert result_bytes == [] 146 | -------------------------------------------------------------------------------- /tests/test_keys_async.py: -------------------------------------------------------------------------------- 1 | # vim :set ts=4 sw=4 sts=4 
et : 2 | import os 3 | import sys 4 | import time 5 | import pytest 6 | from multiprocessing import Process 7 | from typing import Generator 8 | 9 | # Import helpers for async server testing 10 | from tests.helpers_async import start_async_server, stop_async_server 11 | from miniredis.client import RedisClient 12 | 13 | @pytest.fixture(scope="module") 14 | def redis_client_async() -> Generator[RedisClient, None, None]: 15 | """Pytest fixture to start/stop the async miniredis server and provide a client.""" 16 | server_process: Process | None = None 17 | r_client: RedisClient | None = None 18 | try: 19 | server_process, test_port = start_async_server() 20 | r_client = RedisClient(port=test_port) 21 | r_client.flushdb() # Flush DB before tests start 22 | yield r_client # Provide the client to the tests 23 | except Exception as e: 24 | print(f"Error during fixture setup in test_keys_async: {e}") 25 | pytest.fail(f"Fixture setup failed: {e}") # Fail tests if fixture fails 26 | finally: 27 | # Teardown: Stop client and server 28 | print("Tearing down test_keys_async fixture...") 29 | if r_client: 30 | try: 31 | r_client.close() 32 | print("Redis client closed.") 33 | except Exception as e: 34 | print(f"Error closing redis client: {e}") 35 | if server_process: 36 | stop_async_server(server_process) 37 | print("Fixture teardown complete.") 38 | 39 | class TestAsyncKeysCommands: 40 | """Test Redis key commands with the async server implementation.""" 41 | 42 | def test_put_get(self, redis_client_async: RedisClient): 43 | """Test basic SET and GET""" 44 | r = redis_client_async 45 | assert r.set("test:key", "value") == "OK" 46 | result = r.get("test:key") 47 | assert result == b"value" 48 | assert result.decode("utf-8") == "value" 49 | 50 | def test_get_nonexistent(self, redis_client_async: RedisClient): 51 | """Test GET on a non-existent key""" 52 | r = redis_client_async 53 | assert r.get("test:notakey") is None 54 | 55 | def test_del(self, redis_client_async: RedisClient): 56 | """Test DEL command""" 57 | r = redis_client_async 58 | r.set("test:keydel1", "value1") 59 | r.set("test:keydel2", "value2") 60 | r.set("test:keydel3", "value3") 61 | # single key 62 | assert r.delete("test:keydel1") == 1 63 | assert r.get("test:keydel1") is None 64 | # multiple keys 65 | assert r.delete("test:keydel2", "test:keydel3") == 2 66 | assert r.get("test:keydel2") is None 67 | assert r.get("test:keydel3") is None 68 | # non-existent key 69 | assert r.delete("test:notthere") == 0 70 | 71 | def test_exists(self, redis_client_async: RedisClient): 72 | """Test EXISTS command""" 73 | r = redis_client_async 74 | r.set("test:keyexists", "value") 75 | assert r.exists("test:keyexists") == 1 76 | assert r.exists("test:notthere") == 0 77 | 78 | def test_expire_ttl(self, redis_client_async: RedisClient): 79 | """Test EXPIRE and TTL commands""" 80 | r = redis_client_async 81 | r.set("test:keyexpire", "value") 82 | # missing key 83 | assert r.expire("test:notthere", 2) == 0 84 | # valid setting 85 | assert r.expire("test:keyexpire", 2) == 1 86 | # TTL should be close to 2 (allow for slight delay) 87 | ttl = r.ttl("test:keyexpire") 88 | assert 1 <= ttl <= 2, f"TTL ({ttl}) not within expected range [1, 2]" 89 | time.sleep(2.1) 90 | assert r.ttl("test:keyexpire") == -2 # Should be expired (-2) 91 | assert r.exists("test:keyexpire") == 0 92 | 93 | # reset ttl with SET 94 | r.set("test:keyexpire_reset", "value") 95 | assert r.expire("test:keyexpire_reset", 5) == 1 96 | assert r.ttl("test:keyexpire_reset") > 0 97 | assert 
r.set("test:keyexpire_reset", "newvalue") == "OK" 98 | assert r.ttl("test:keyexpire_reset") == -1 # SET should remove TTL 99 | 100 | def test_expireat_pttl(self, redis_client_async: RedisClient): 101 | """Test EXPIREAT and PTTL commands""" 102 | r = redis_client_async 103 | r.set("test:keyexpireat", "value") 104 | # missing key 105 | at_ts = int(time.time() + 2) 106 | assert r.expireat("test:notthere_at", at_ts) == 0 107 | # valid setting 108 | at_ts = int(time.time() + 2) 109 | assert r.expireat("test:keyexpireat", at_ts) == 1 110 | # PTTL should be close to 2000ms 111 | pttl = r.pttl("test:keyexpireat") 112 | assert ( 113 | 1000 <= pttl <= 2000 114 | ), f"PTTL ({pttl}) not within expected range [1000, 2000]" 115 | time.sleep(2.1) 116 | assert r.pttl("test:keyexpireat") == -2 # Should be expired (-2) 117 | assert r.exists("test:keyexpireat") == 0 118 | 119 | # reset ttl with SET 120 | r.set("test:keyexpireat_reset", "value") 121 | at_ts = int(time.time() + 5) 122 | assert r.expireat("test:keyexpireat_reset", at_ts) == 1 123 | assert r.pttl("test:keyexpireat_reset") > 0 124 | assert r.set("test:keyexpireat_reset", "newvalue") == "OK" 125 | assert r.pttl("test:keyexpireat_reset") == -1 # SET should remove TTL 126 | 127 | def test_keys(self, redis_client_async: RedisClient): 128 | """Test KEYS command""" 129 | r = redis_client_async 130 | # Clear previous keys potentially matching pattern 131 | r.flushdb() 132 | # place test keys 133 | r.set("test:keys:key1", "value") 134 | r.set("test:keys:key2", "value") 135 | r.set("other:keys:key3", "value") 136 | # KEYS returns list of bytes 137 | result_bytes = r.keys("test:keys:*") 138 | result_strings = sorted([k.decode("utf-8") for k in result_bytes]) 139 | assert result_strings == ["test:keys:key1", "test:keys:key2"] 140 | 141 | result_bytes = r.keys("*:key?") 142 | result_strings = sorted([k.decode("utf-8") for k in result_bytes]) 143 | assert result_strings == ["other:keys:key3", "test:keys:key1", "test:keys:key2"] 144 | 145 | result_bytes = r.keys("*nomatch*") 146 | assert result_bytes == [] -------------------------------------------------------------------------------- /tests/test_strings.py: -------------------------------------------------------------------------------- 1 | # vim :set ts=4 sw=4 sts=4 et : 2 | import os 3 | import sys 4 | import time 5 | import pytest # Changed from unittest 6 | from multiprocessing import Process 7 | from typing import Generator, Any 8 | 9 | # Changed from relative import to absolute import 10 | from tests.helpers import start_server, stop_server 11 | from miniredis.client import RedisClient 12 | 13 | # Use the same fixture as test_keys.py if tests can share the same server instance 14 | # If they need independent servers, define a similar fixture here or adjust scope. 15 | # For simplicity, let's assume they can share the module-scoped server. 16 | # If you defined the fixture in conftest.py, you wouldn't need to import it here. 
17 | from tests.test_keys import redis_client # Changed to absolute import 18 | 19 | # No longer need unittest.TestCase 20 | class TestStringCommands: 21 | 22 | # No longer need setUp with module-scoped fixture 23 | # def setUp(self): 24 | # pass 25 | 26 | def test_append(self, redis_client: RedisClient): 27 | """Test APPEND command""" 28 | r = redis_client 29 | # Key exists 30 | assert r.set('test:append:key1', 'value') == 'OK' 31 | assert r.append('test:append:key1', 'more') == 9 # Returns length after append 32 | assert r.get('test:append:key1') == b'valuemore' 33 | 34 | # Key does not exist 35 | assert r.append('test:append:key2', 'newvalue') == 8 # Creates key, returns length 36 | assert r.get('test:append:key2') == b'newvalue' 37 | 38 | # Append to non-string (should fail or be handled by server) 39 | r.lpush('test:append:list', 'item') 40 | with pytest.raises(Exception, match="Operation against a key holding the wrong kind of value"): 41 | r.append('test:append:list', 'stuff') 42 | 43 | def test_incr_decr(self, redis_client: RedisClient): 44 | """Test INCR and DECR commands""" 45 | r = redis_client 46 | assert r.set('test:counter', '10') == 'OK' 47 | assert r.incr('test:counter') == 11 48 | assert r.get('test:counter') == b'11' 49 | assert r.incrby('test:counter', 5) == 16 50 | assert r.get('test:counter') == b'16' 51 | 52 | assert r.decr('test:counter') == 15 53 | assert r.get('test:counter') == b'15' 54 | assert r.decrby('test:counter', 5) == 10 55 | assert r.get('test:counter') == b'10' 56 | 57 | # Non-existent key 58 | assert r.incr('test:newcounter') == 1 59 | assert r.get('test:newcounter') == b'1' 60 | assert r.decr('test:newcounter2') == -1 61 | assert r.get('test:newcounter2') == b'-1' 62 | 63 | # Error cases 64 | r.set('test:notint', 'hello') 65 | with pytest.raises(Exception, match="value is not an integer"): 66 | r.incr('test:notint') 67 | 68 | with pytest.raises(Exception, match="value is not an integer"): 69 | r.decr('test:notint') 70 | 71 | def test_getset(self, redis_client: RedisClient): 72 | """Test GETSET command""" 73 | r = redis_client 74 | # Key exists 75 | r.set('test:getset:key1', 'oldvalue') 76 | old: bytes | None = r.getset('test:getset:key1', 'newvalue') 77 | assert old == b'oldvalue' 78 | assert r.get('test:getset:key1') == b'newvalue' 79 | 80 | # Key does not exist 81 | old = r.getset('test:getset:key2', 'firstvalue') 82 | assert old is None # Returns nil (None for client) when key didn't exist 83 | assert r.get('test:getset:key2') == b'firstvalue' 84 | 85 | # GETSET on non-string type 86 | r.lpush('test:getset:list', 'item') 87 | with pytest.raises(Exception, match="Operation against a key holding the wrong kind of value"): 88 | r.getset('test:getset:list', 'new') 89 | 90 | def test_mget(self, redis_client: RedisClient): 91 | """Test MGET command""" 92 | r = redis_client 93 | r.set('test:mget:key1', 'val1') 94 | r.set('test:mget:key2', 'val2') 95 | r.lpush('test:mget:list', 'item') # A non-string key 96 | 97 | results: list[bytes | None] = r.mget('test:mget:key1', 'test:mget:nonexistent', 'test:mget:key2', 'test:mget:list') 98 | expected: list[bytes | None] = [b'val1', None, b'val2', None] # Expect None for non-existent and wrong type 99 | assert results == expected 100 | 101 | # Empty list 102 | assert r.mget() == [] 103 | 104 | def test_setnx(self, redis_client: RedisClient): 105 | """Test SETNX command""" 106 | r = redis_client 107 | # Key does not exist 108 | assert r.setnx('test:setnx:key1', 'value1') == 1 109 | assert r.get('test:setnx:key1') == 
b'value1' 110 | 111 | # Key exists 112 | assert r.setnx('test:setnx:key1', 'value2') == 0 113 | assert r.get('test:setnx:key1') == b'value1' # Value should not change 114 | 115 | def test_setex(self, redis_client: RedisClient): 116 | """Test SETEX command""" 117 | r = redis_client 118 | assert r.setex('test:setex:key1', 2, 'value') == 'OK' 119 | assert r.get('test:setex:key1') == b'value' 120 | ttl = r.ttl('test:setex:key1') 121 | assert 1 <= ttl <= 2, f"TTL ({ttl}) not within expected range [1, 2]" 122 | time.sleep(2.1) 123 | assert r.get('test:setex:key1') is None 124 | assert r.ttl('test:setex:key1') == -2 125 | 126 | # Invalid TTL (non-integer) 127 | with pytest.raises(Exception, match="value is not an integer"): 128 | # The client might raise TypeError or similar before sending 129 | # or the server might return an error. 130 | # Adjust match based on actual behavior. 131 | r.setex('test:setex:key2', 'notanumber', 'value') # type: ignore 132 | 133 | # Invalid TTL (negative) 134 | with pytest.raises(Exception, match="invalid expire time"): 135 | # Check specific error if server provides one, otherwise generic Exception 136 | r.setex('test:setex:key3', -10, 'value') 137 | 138 | # No longer need nose specific execution block 139 | # if __name__ == '__main__': 140 | # import nose 141 | # nose.runmodule() 142 | -------------------------------------------------------------------------------- /tests/test_strings_async.py: -------------------------------------------------------------------------------- 1 | # vim :set ts=4 sw=4 sts=4 et : 2 | import os 3 | import sys 4 | import time 5 | import pytest 6 | from multiprocessing import Process 7 | from typing import Generator, Any 8 | 9 | # Import helpers for async server testing 10 | from tests.helpers_async import start_async_server, stop_async_server 11 | from miniredis.client import RedisClient 12 | 13 | @pytest.fixture(scope="module") 14 | def redis_client_async() -> Generator[RedisClient, None, None]: 15 | """Pytest fixture to start/stop the async miniredis server and provide a client.""" 16 | server_process: Process | None = None 17 | r_client: RedisClient | None = None 18 | try: 19 | server_process, test_port = start_async_server() 20 | r_client = RedisClient(port=test_port) 21 | r_client.flushdb() # Flush DB before tests start 22 | yield r_client # Provide the client to the tests 23 | except Exception as e: 24 | print(f"Error during fixture setup in test_strings_async: {e}") 25 | pytest.fail(f"Fixture setup failed: {e}") # Fail tests if fixture fails 26 | finally: 27 | # Teardown: Stop client and server 28 | print("Tearing down test_strings_async fixture...") 29 | if r_client: 30 | try: 31 | r_client.close() 32 | print("Redis client closed.") 33 | except Exception as e: 34 | print(f"Error closing redis client: {e}") 35 | if server_process: 36 | stop_async_server(server_process) 37 | print("Fixture teardown complete.") 38 | 39 | class TestAsyncStringCommands: 40 | """Test Redis string commands with the async server implementation.""" 41 | 42 | def test_append(self, redis_client_async: RedisClient): 43 | """Test APPEND command""" 44 | r = redis_client_async 45 | # Key exists 46 | assert r.set('test:append:key1', 'value') == 'OK' 47 | assert r.append('test:append:key1', 'more') == 9 # Returns length after append 48 | assert r.get('test:append:key1') == b'valuemore' 49 | 50 | # Key does not exist 51 | assert r.append('test:append:key2', 'newvalue') == 8 # Creates key, returns length 52 | assert r.get('test:append:key2') == b'newvalue' 53 | 54 
58 |         # Append to a non-string key (should raise a wrong-type error)
59 |         r.lpush('test:append:list', 'item')
60 |         with pytest.raises(Exception, match="Operation against a key holding the wrong kind of value"):
61 |             r.append('test:append:list', 'stuff')
62 | 
63 |     def test_incr_decr(self, redis_client_async: RedisClient):
64 |         """Test INCR and DECR commands"""
65 |         r = redis_client_async
66 |         assert r.set('test:counter', '10') == 'OK'
67 |         assert r.incr('test:counter') == 11
68 |         assert r.get('test:counter') == b'11'
69 |         assert r.incrby('test:counter', 5) == 16
70 |         assert r.get('test:counter') == b'16'
71 | 
72 |         assert r.decr('test:counter') == 15
73 |         assert r.get('test:counter') == b'15'
74 |         assert r.decrby('test:counter', 5) == 10
75 |         assert r.get('test:counter') == b'10'
76 | 
77 |         # Non-existent keys start from zero
78 |         assert r.incr('test:newcounter') == 1
79 |         assert r.get('test:newcounter') == b'1'
80 |         assert r.decr('test:newcounter2') == -1
81 |         assert r.get('test:newcounter2') == b'-1'
82 | 
83 |         # Error cases
84 |         r.set('test:notint', 'hello')
85 |         with pytest.raises(Exception, match="value is not an integer"):
86 |             r.incr('test:notint')
87 | 
88 |         with pytest.raises(Exception, match="value is not an integer"):
89 |             r.decr('test:notint')
90 | 
91 |     def test_getset(self, redis_client_async: RedisClient):
92 |         """Test GETSET command"""
93 |         r = redis_client_async
94 |         # Key exists
95 |         r.set('test:getset:key1', 'oldvalue')
96 |         old: bytes | None = r.getset('test:getset:key1', 'newvalue')
97 |         assert old == b'oldvalue'
98 |         assert r.get('test:getset:key1') == b'newvalue'
99 | 
100 |         # Key does not exist
101 |         old = r.getset('test:getset:key2', 'firstvalue')
102 |         assert old is None  # Returns nil (None in the client) when the key didn't exist
103 |         assert r.get('test:getset:key2') == b'firstvalue'
104 | 
105 |         # GETSET on a non-string type
106 |         r.lpush('test:getset:list', 'item')
107 |         with pytest.raises(Exception, match="Operation against a key holding the wrong kind of value"):
108 |             r.getset('test:getset:list', 'new')
109 | 
110 |     def test_mget(self, redis_client_async: RedisClient):
111 |         """Test MGET command"""
112 |         r = redis_client_async
113 |         r.set('test:mget:key1', 'val1')
114 |         r.set('test:mget:key2', 'val2')
115 |         r.lpush('test:mget:list', 'item')  # A non-string key
116 | 
117 |         results: list[bytes | None] = r.mget('test:mget:key1', 'test:mget:nonexistent', 'test:mget:key2', 'test:mget:list')
118 |         expected: list[bytes | None] = [b'val1', None, b'val2', None]  # None for non-existent and wrong-type keys
119 |         assert results == expected
120 | 
121 |         # No keys requested returns an empty list
122 |         assert r.mget() == []
123 | 
124 |     def test_setnx(self, redis_client_async: RedisClient):
125 |         """Test SETNX command"""
126 |         r = redis_client_async
127 |         # Key does not exist
128 |         assert r.setnx('test:setnx:key1', 'value1') == 1
129 |         assert r.get('test:setnx:key1') == b'value1'
130 | 
131 |         # Key exists
132 |         assert r.setnx('test:setnx:key1', 'value2') == 0
133 |         assert r.get('test:setnx:key1') == b'value1'  # Value should not change
134 | 
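135 |         # SETNX's "set only if absent" behavior is the classic building
136 |         # block for simple locks. A sketch (hypothetical lock key and owner
137 |         # id; assumes the client exposes DEL as delete()):
138 |         #   if r.setnx('lock:resource', 'owner-1') == 1:
139 |         #       try: ...           # critical section
140 |         #       finally: r.delete('lock:resource')
141 | 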
142 |     def test_setex(self, redis_client_async: RedisClient):
143 |         """Test SETEX command"""
144 |         r = redis_client_async
145 |         assert r.setex('test:setex:key1', 2, 'value') == 'OK'
146 |         assert r.get('test:setex:key1') == b'value'
147 |         ttl = r.ttl('test:setex:key1')
148 |         assert 1 <= ttl <= 2, f"TTL ({ttl}) not within expected range [1, 2]"
149 |         time.sleep(2.1)
150 |         assert r.get('test:setex:key1') is None
151 |         assert r.ttl('test:setex:key1') == -2
152 | 
153 |         # Invalid TTL (non-integer)
154 |         with pytest.raises(Exception, match="value is not an integer"):
155 |             # The client may reject the bad TTL before sending, or the
156 |             # server may return an error; the match covers the server reply.
157 |             r.setex('test:setex:key2', 'notanumber', 'value')  # type: ignore
158 | 
159 |         # Invalid TTL (negative)
160 |         with pytest.raises(Exception, match="invalid expire time"):
161 |             r.setex('test:setex:key3', -10, 'value')
--------------------------------------------------------------------------------