├── aiokeydb ├── version.py ├── types │ ├── indexes │ │ ├── __init__.py │ │ └── utils.py │ ├── __init__.py │ ├── static.py │ ├── serializer.py │ └── compat.py ├── v2 │ ├── types │ │ ├── __init__.py │ │ ├── indexes │ │ │ ├── __init__.py │ │ │ └── utils.py │ │ ├── serializer.py │ │ └── static.py │ ├── README.md │ ├── configs │ │ └── __init__.py │ ├── queues.py │ ├── backoff.py │ ├── utils │ │ ├── __init__.py │ │ ├── cron.py │ │ ├── logs.py │ │ ├── base.py │ │ └── helpers.py │ ├── exceptions.py │ ├── serializers │ │ ├── _msgpack.py │ │ ├── __init__.py │ │ ├── _pickle.py │ │ └── _json.py │ ├── commands │ │ └── __init__.py │ ├── typing.py │ └── __init__.py ├── v1 │ ├── client │ │ ├── schemas │ │ │ └── __init__.py │ │ ├── serializers │ │ │ ├── base.py │ │ │ ├── _msgpack.py │ │ │ ├── __init__.py │ │ │ ├── _pickle.py │ │ │ └── _json.py │ │ ├── __init__.py │ │ └── utils.py │ ├── queues │ │ ├── worker │ │ │ └── __init__.py │ │ ├── errors.py │ │ ├── __init__.py │ │ ├── imports │ │ │ └── cron.py │ │ ├── README.md │ │ └── utils.py │ ├── commands │ │ ├── graph │ │ │ ├── exceptions.py │ │ │ ├── path.py │ │ │ ├── node.py │ │ │ └── edge.py │ │ ├── json │ │ │ ├── _util.py │ │ │ ├── path.py │ │ │ ├── decoders.py │ │ │ └── __init__.py │ │ ├── search │ │ │ ├── _util.py │ │ │ ├── document.py │ │ │ ├── suggestion.py │ │ │ ├── result.py │ │ │ ├── indexDefinition.py │ │ │ ├── reducers.py │ │ │ └── field.py │ │ ├── __init__.py │ │ ├── timeseries │ │ │ ├── utils.py │ │ │ ├── info.py │ │ │ └── __init__.py │ │ ├── bf │ │ │ └── info.py │ │ ├── redismodules.py │ │ ├── sentinel.py │ │ └── helpers.py │ ├── compat.py │ ├── crc.py │ ├── credentials.py │ ├── asyncio │ │ ├── utils.py │ │ ├── client.py │ │ ├── __init__.py │ │ ├── retry.py │ │ └── parser.py │ ├── retry.py │ ├── typing.py │ ├── backoff.py │ ├── utils.py │ ├── __init__.py │ └── exceptions.py ├── queues.py ├── compat.py ├── backoff.py ├── utils │ ├── __init__.py │ ├── lazy.py │ ├── cron.py │ ├── logs.py │ └── base.py ├── configs │ └── __init__.py ├── crc.py ├── exceptions.py ├── serializers │ ├── _msgpack.py │ ├── __init__.py │ ├── _json.py │ └── _pickle.py ├── commands │ └── __init__.py ├── retry.py ├── typing.py └── __init__.py ├── .gitattributes ├── MANIFEST.in ├── tests ├── test_keydb.py ├── test_client.py ├── test_task_queue.py ├── test_multi_config.py ├── test_cron_v2.py ├── test_dict.py ├── test_v2_cachify.py ├── test_v3_cachify.py └── test_suite.py ├── .gitignore ├── .github └── workflows │ └── python-publish.yml ├── setup.py └── CHANGELOGS.md /aiokeydb/version.py: -------------------------------------------------------------------------------- 1 | VERSION = '0.2.1' -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include aiokeydb * 2 | recursive-exclude * __pycache__ 3 | recursive-exclude * *.py[co] 4 | -------------------------------------------------------------------------------- /aiokeydb/types/indexes/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .base import KDBIndex 3 | from .kdbdict import KDBDict, AsyncKDBDict -------------------------------------------------------------------------------- 
/aiokeydb/types/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .serializer import BaseSerializer 3 | from .indexes import * -------------------------------------------------------------------------------- /aiokeydb/v2/types/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .serializer import BaseSerializer 3 | from .indexes import * -------------------------------------------------------------------------------- /aiokeydb/v2/types/indexes/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .base import KDBIndex 3 | from .kdbdict import KDBDict, AsyncKDBDict -------------------------------------------------------------------------------- /aiokeydb/v1/client/schemas/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from aiokeydb.v1.client.schemas.session import KeyDBSession -------------------------------------------------------------------------------- /aiokeydb/v1/queues/worker/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from aiokeydb.v1.queues.worker.base import Worker, WorkerTasks 4 | -------------------------------------------------------------------------------- /aiokeydb/v1/commands/graph/exceptions.py: -------------------------------------------------------------------------------- 1 | class VersionMismatchException(Exception): 2 | def __init__(self, version): 3 | self.version = version 4 | -------------------------------------------------------------------------------- /aiokeydb/v1/commands/json/_util.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Union 2 | 3 | JsonType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]] 4 | -------------------------------------------------------------------------------- /aiokeydb/v2/README.md: -------------------------------------------------------------------------------- 1 | # aiokeydb - v2 2 | 3 | This version will soon replace the existing implementation, 4 | with a focus on maintaining upstream compatibility 5 | with `redis-py`.
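6 | 7 | A minimal usage sketch (illustrative: it assumes a KeyDB server at `localhost:6379` 8 | and a client that mirrors the `redis-py` style API exercised in `tests/test_keydb.py`): 9 | 10 | ```python 11 | import asyncio 12 | import aiokeydb 13 | 14 | def sync_example(): 15 | keydb = aiokeydb.KeyDB.from_url("keydb://localhost:6379/0") 16 | keydb.set("foo", "bar") 17 | print(keydb.get("foo")) 18 | 19 | async def async_example(): 20 | keydb = aiokeydb.AsyncKeyDB.from_url("keydb://localhost:6379/0") 21 | await keydb.set("foo", "bar") 22 | print(await keydb.get("foo")) 23 | 24 | sync_example() 25 | asyncio.run(async_example()) 26 | ```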
-------------------------------------------------------------------------------- /aiokeydb/v2/configs/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from aiokeydb.v2.configs.worker import KeyDBWorkerSettings 4 | from aiokeydb.v2.configs.core import KeyDBSettings 5 | 6 | settings = KeyDBSettings() -------------------------------------------------------------------------------- /aiokeydb/queues.py: -------------------------------------------------------------------------------- 1 | """ 2 | For previous migration 3 | """ 4 | 5 | from aiokeydb.types.jobs import Job, CronJob, JobStatus, TaskType 6 | from aiokeydb.types.task_queue import TaskQueue 7 | from aiokeydb.types.worker import Worker 8 | -------------------------------------------------------------------------------- /aiokeydb/v2/queues.py: -------------------------------------------------------------------------------- 1 | """ 2 | For previous migration 3 | """ 4 | 5 | from aiokeydb.v2.types.jobs import Job, CronJob, JobStatus, TaskType 6 | from aiokeydb.v2.types.task_queue import TaskQueue 7 | from aiokeydb.v2.types.worker import Worker 8 | -------------------------------------------------------------------------------- /aiokeydb/v1/commands/search/_util.py: -------------------------------------------------------------------------------- 1 | def to_string(s): 2 | if isinstance(s, str): 3 | return s 4 | elif isinstance(s, bytes): 5 | return s.decode("utf-8", "ignore") 6 | else: 7 | return s # Not a string we care about 8 | -------------------------------------------------------------------------------- /aiokeydb/compat.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | try: 3 | from typing import Literal, Protocol, TypedDict # lgtm [py/unused-import] 4 | except ImportError: 5 | from typing_extensions import ( # lgtm [py/unused-import] 6 | Literal, 7 | Protocol, 8 | TypedDict, 9 | ) 10 | -------------------------------------------------------------------------------- /aiokeydb/v1/compat.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | try: 3 | from typing import Literal, Protocol, TypedDict # lgtm [py/unused-import] 4 | except ImportError: 5 | from typing_extensions import ( # lgtm [py/unused-import] 6 | Literal, 7 | Protocol, 8 | TypedDict, 9 | ) 10 | -------------------------------------------------------------------------------- /aiokeydb/v1/queues/errors.py: -------------------------------------------------------------------------------- 1 | from aiokeydb.v1.queues.types import Job 2 | 3 | class JobError(Exception): 4 | def __init__(self, job: 'Job'): 5 | super().__init__( 6 | f"Job {job.id} {job.status}\n\nThe above job failed with the following error:\n\n{job.error}" 7 | ) 8 | self.job = job -------------------------------------------------------------------------------- /aiokeydb/types/indexes/utils.py: -------------------------------------------------------------------------------- 1 | 2 | def check_numeric(x) -> bool: 3 | """ 4 | Check if a value is numeric 5 | """ 6 | if isinstance(x, (int, float, complex)): return True 7 | if not isinstance(x, str): return False 8 | try: 9 | float(x) 10 | return True 11 | except Exception: 12 | return False 13 | -------------------------------------------------------------------------------- /aiokeydb/v2/types/indexes/utils.py:
-------------------------------------------------------------------------------- 1 | 2 | def check_numeric(x) -> bool: 3 | """ 4 | Check if a value is numeric 5 | """ 6 | if isinstance(x, (int, float, complex)): return True 7 | if not isinstance(x, str): return False 8 | try: 9 | float(x) 10 | return True 11 | except Exception: 12 | return False 13 | -------------------------------------------------------------------------------- /aiokeydb/backoff.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | from redis.backoff import ( 4 | AbstractBackoff, 5 | ConstantBackoff, 6 | ExponentialBackoff, 7 | FullJitterBackoff, 8 | NoBackoff, 9 | EqualJitterBackoff, 10 | DecorrelatedJitterBackoff, 11 | ) 12 | 13 | def default_backoff() -> AbstractBackoff: 14 | return EqualJitterBackoff() 15 | -------------------------------------------------------------------------------- /aiokeydb/v2/backoff.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | from redis.backoff import ( 4 | AbstractBackoff, 5 | ConstantBackoff, 6 | ExponentialBackoff, 7 | FullJitterBackoff, 8 | NoBackoff, 9 | EqualJitterBackoff, 10 | DecorrelatedJitterBackoff, 11 | ) 12 | 13 | def default_backoff() -> AbstractBackoff: 14 | return EqualJitterBackoff() 15 | -------------------------------------------------------------------------------- /aiokeydb/v1/queues/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from aiokeydb.v1.queues.types import Job, JobStatus, TaskType 4 | from aiokeydb.v1.queues.queue import TaskQueue 5 | from aiokeydb.v1.queues.worker.base import Worker, WorkerTasks 6 | 7 | if hasattr(Job, 'model_rebuild'): 8 | Job.model_rebuild() 9 | else: 10 | Job.update_forward_refs(TaskQueue=TaskQueue) -------------------------------------------------------------------------------- /aiokeydb/v1/commands/search/document.py: -------------------------------------------------------------------------------- 1 | class Document: 2 | """ 3 | Represents a single document in a result set 4 | """ 5 | 6 | def __init__(self, id, payload=None, **fields): 7 | self.id = id 8 | self.payload = payload 9 | for k, v in fields.items(): 10 | setattr(self, k, v) 11 | 12 | def __repr__(self): 13 | return f"Document {self.__dict__}" 14 | -------------------------------------------------------------------------------- /aiokeydb/v2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ( 2 | HIREDIS_AVAILABLE, 3 | HIREDIS_PACK_AVAILABLE, 4 | CRYPTOGRAPHY_AVAILABLE, 5 | str_if_bytes, 6 | safe_str, 7 | dict_merge, 8 | list_keys_to_dict, 9 | merge_result, 10 | from_url, 11 | pipeline, 12 | async_pipeline, 13 | get_ulimits, 14 | set_ulimits, 15 | full_name, 16 | args_to_key, 17 | import_string, 18 | ) -------------------------------------------------------------------------------- /aiokeydb/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ( 2 | HIREDIS_AVAILABLE, 3 | HIREDIS_PACK_AVAILABLE, 4 | CRYPTOGRAPHY_AVAILABLE, 5 | str_if_bytes, 6 | safe_str, 7 | dict_merge, 8 | list_keys_to_dict, 9 | merge_result, 10 | from_url, 11 | pipeline, 12 | async_pipeline, 13 | get_ulimits, 14 | set_ulimits, 15 | full_name, 16 | args_to_key, 17 | import_string, 18 | ) 19 | 20 | from .lazy import get_keydb_settings
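21 | 22 | # Illustrative usage sketch for the re-exported helpers (it assumes the top-level 23 | # implementations match the v1 versions shown in `aiokeydb/v1/client/utils.py` 24 | # and `aiokeydb/utils/lazy.py`; `some_func` is a placeholder): 25 | # 26 | # from aiokeydb.utils import full_name, args_to_key, get_keydb_settings 27 | # settings = get_keydb_settings() # lazily constructs the shared KeyDBSettings once 28 | # base = (full_name(some_func),) # "module.qualname" of the wrapped function 29 | # cache_key = args_to_key(base, args = (1, 2), kwargs = {'x': 3}) # hashed cache key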
-------------------------------------------------------------------------------- /aiokeydb/v1/commands/json/path.py: -------------------------------------------------------------------------------- 1 | class Path: 2 | """This class represents a path in a JSON value.""" 3 | 4 | strPath = "" 5 | 6 | @staticmethod 7 | def root_path(): 8 | """Return the root path's string representation.""" 9 | return "." 10 | 11 | def __init__(self, path): 12 | """Make a new path based on the string representation in `path`.""" 13 | self.strPath = path 14 | 15 | def __repr__(self): 16 | return self.strPath 17 | -------------------------------------------------------------------------------- /tests/test_keydb.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import aiokeydb 3 | 4 | 5 | keydb_uri = "keydb://localhost:6379/0" 6 | 7 | def sync_test(): 8 | keydb = aiokeydb.KeyDB.from_url(keydb_uri) 9 | keydb.set("foo", "bar") 10 | print(keydb.get("foo")) 11 | 12 | async def async_test(): 13 | keydb = aiokeydb.AsyncKeyDB.from_url(keydb_uri) 14 | await keydb.set("foo", "bar") 15 | print(await keydb.get("foo")) 16 | 17 | async def run_tests(): 18 | await async_test() 19 | sync_test() 20 | 21 | asyncio.run(run_tests()) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **cache** 2 | cache_** 3 | ***_cache 4 | ***_cache_* 5 | **cache/** 6 | *.DS_Store 7 | tests* 8 | __pycache__* 9 | *logs 10 | *dist 11 | *build 12 | **build.sh 13 | *test.py 14 | *.egg-info* 15 | *.vscode 16 | *test_ops 17 | **test 18 | **.ipynb** 19 | **__lazycls** 20 | **/authz 21 | mkdocs.yml 22 | *_docs* 23 | aiokeydb/lock/scripts.py 24 | aiokeydb/mutex/__init__.py 25 | aiokeydb/mutex/_keydb.py 26 | aiokeydb/mutex/errors.py 27 | aiokeydb/aiokeydb-py.code-workspace 28 | tests/private_* 29 | !tests/ 30 | aiokeydb/v2/types/dev.py 31 | tests/test_cron.py 32 | -------------------------------------------------------------------------------- /aiokeydb/configs/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from aiokeydb.configs.worker import KeyDBWorkerSettings 4 | from aiokeydb.configs.core import KeyDBSettings 5 | 6 | 7 | class ProxySettings: 8 | def __init__(self): 9 | self._settings = None 10 | 11 | def __getattr__(self, name): 12 | if self._settings is None: 13 | from aiokeydb.utils.lazy import get_keydb_settings 14 | self._settings = get_keydb_settings() 15 | return getattr(self._settings, name) 16 | 17 | settings: KeyDBSettings = ProxySettings() -------------------------------------------------------------------------------- /aiokeydb/v2/types/serializer.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Base Serializer Class that all other serializers should inherit from. 
4 | """ 5 | 6 | import typing 7 | 8 | class BaseSerializer: 9 | 10 | @staticmethod 11 | def dumps(obj: typing.Any, **kwargs) -> bytes: 12 | """ 13 | Serialize the object to bytes 14 | """ 15 | raise NotImplementedError 16 | 17 | @staticmethod 18 | def loads(data: typing.Union[str, bytes, typing.Any], **kwargs) -> typing.Any: 19 | """ 20 | Deserialize the object from bytes 21 | """ 22 | raise NotImplementedError 23 | 24 | -------------------------------------------------------------------------------- /aiokeydb/v1/client/serializers/base.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Base Serializer Class that all other serializers should inherit from. 4 | """ 5 | 6 | import typing 7 | 8 | class BaseSerializer: 9 | 10 | @staticmethod 11 | def dumps(obj: typing.Any, **kwargs) -> bytes: 12 | """ 13 | Serialize the object to bytes 14 | """ 15 | raise NotImplementedError 16 | 17 | @staticmethod 18 | def loads(data: typing.Union[str, bytes, typing.Any], **kwargs) -> typing.Any: 19 | """ 20 | Deserialize the object from bytes 21 | """ 22 | raise NotImplementedError 23 | 24 | -------------------------------------------------------------------------------- /tests/test_client.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from aiokeydb import KeyDBClient 3 | 4 | keydb_uri = "keydb://.." 5 | 6 | KeyDBClient.init_session( 7 | uri = keydb_uri, 8 | ) 9 | 10 | @KeyDBClient.cachify() 11 | async def async_fibonacci(number: int): 12 | if number == 0: return 0 13 | elif number == 1: return 1 14 | return await async_fibonacci(number - 1) + await async_fibonacci(number - 2) 15 | 16 | 17 | @KeyDBClient.cachify() 18 | def fibonacci(number: int): 19 | if number == 0: return 0 20 | elif number == 1: return 1 21 | return fibonacci(number - 1) + fibonacci(number - 2) 22 | 23 | async def run_tests(): 24 | print(fibonacci(100)) 25 | print(await async_fibonacci(100)) 26 | 27 | asyncio.run(run_tests()) -------------------------------------------------------------------------------- /tests/test_task_queue.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from aiokeydb import KeyDBClient 3 | from aiokeydb.queues import TaskQueue, Worker 4 | from lazyops.utils import logger 5 | 6 | keydb_uri = "keydb://.." 
7 | 8 | KeyDBClient.configure( 9 | url = keydb_uri, 10 | debug_enabled = True, 11 | ) 12 | 13 | @Worker.add_cronjob("*/1 * * * *") 14 | async def test_cron_task(*args, **kwargs): 15 | logger.info("Cron task ran") 16 | await asyncio.sleep(5) 17 | 18 | @Worker.add_function() 19 | async def test_task(*args, **kwargs): 20 | logger.info("Task ran") 21 | await asyncio.sleep(5) 22 | 23 | async def run_tests(): 24 | queue = TaskQueue("test_queue") 25 | worker = Worker(queue) 26 | await worker.start() 27 | 28 | asyncio.run(run_tests()) 29 | -------------------------------------------------------------------------------- /aiokeydb/v2/exceptions.py: -------------------------------------------------------------------------------- 1 | "Additional exceptions raised by the Redis client" 2 | 3 | from redis.exceptions import ( 4 | AuthenticationError, 5 | AuthenticationWrongNumberOfArgsError, 6 | BusyLoadingError, 7 | ChildDeadlockedError, 8 | ConnectionError, 9 | DataError, 10 | InvalidResponse, 11 | PubSubError, 12 | ReadOnlyError, 13 | RedisError, 14 | ResponseError, 15 | TimeoutError, 16 | WatchError, 17 | ) 18 | 19 | from typing import TYPE_CHECKING 20 | if TYPE_CHECKING: 21 | from aiokeydb.v2.types.jobs import Job 22 | 23 | class JobError(Exception): 24 | def __init__(self, job: 'Job'): 25 | super().__init__( 26 | f"Job {job.id} {job.status}\n\nThe above job failed with the following error:\n\n{job.error}" 27 | ) 28 | self.job = job -------------------------------------------------------------------------------- /aiokeydb/crc.py: -------------------------------------------------------------------------------- 1 | from binascii import crc_hqx 2 | 3 | from aiokeydb.typing import EncodedT 4 | 5 | # Redis Cluster's key space is divided into 16384 slots. 6 | # For more information see: https://github.com/redis/redis/issues/2576 7 | REDIS_CLUSTER_HASH_SLOTS = 16384 8 | 9 | __all__ = ["key_slot", "REDIS_CLUSTER_HASH_SLOTS"] 10 | 11 | 12 | def key_slot(key: EncodedT, bucket: int = REDIS_CLUSTER_HASH_SLOTS) -> int: 13 | """Calculate key slot for a given key. 14 | See Keys distribution model in https://redis.io/topics/cluster-spec 15 | :param key - bytes 16 | :param bucket - int 17 | """ 18 | start = key.find(b"{") 19 | if start > -1: 20 | end = key.find(b"}", start + 1) 21 | if end > -1 and end != start + 1: 22 | key = key[start + 1 : end] 23 | return crc_hqx(key, 0) % bucket 24 | -------------------------------------------------------------------------------- /aiokeydb/v1/crc.py: -------------------------------------------------------------------------------- 1 | from binascii import crc_hqx 2 | 3 | from aiokeydb.v1.typing import EncodedT 4 | 5 | # Redis Cluster's key space is divided into 16384 slots. 6 | # For more information see: https://github.com/redis/redis/issues/2576 7 | REDIS_CLUSTER_HASH_SLOTS = 16384 8 | 9 | __all__ = ["key_slot", "REDIS_CLUSTER_HASH_SLOTS"] 10 | 11 | 12 | def key_slot(key: EncodedT, bucket: int = REDIS_CLUSTER_HASH_SLOTS) -> int: 13 | """Calculate key slot for a given key. 
14 | See Keys distribution model in https://redis.io/topics/cluster-spec 15 | :param key - bytes 16 | :param bucket - int 17 | """ 18 | start = key.find(b"{") 19 | if start > -1: 20 | end = key.find(b"}", start + 1) 21 | if end > -1 and end != start + 1: 22 | key = key[start + 1 : end] 23 | return crc_hqx(key, 0) % bucket 24 | -------------------------------------------------------------------------------- /aiokeydb/types/static.py: -------------------------------------------------------------------------------- 1 | 2 | import enum 3 | 4 | class JobStatus(str, enum.Enum): 5 | NEW = "new" 6 | DEFERRED = "deferred" 7 | QUEUED = "queued" 8 | ACTIVE = "active" 9 | ABORTED = "aborted" 10 | FAILED = "failed" 11 | COMPLETE = "complete" 12 | 13 | INCOMPLETE_STATUSES = {JobStatus.NEW, JobStatus.DEFERRED, JobStatus.QUEUED, JobStatus.ACTIVE} 14 | TERMINAL_STATUSES = {JobStatus.COMPLETE, JobStatus.FAILED, JobStatus.ABORTED} 15 | UNSUCCESSFUL_TERMINAL_STATUSES = TERMINAL_STATUSES - {JobStatus.COMPLETE} 16 | 17 | class TaskType(str, enum.Enum): 18 | """ 19 | The Type of Task 20 | for the worker 21 | """ 22 | default = "default" 23 | function = "function" 24 | cronjob = "cronjob" 25 | dependency = "dependency" 26 | context = "context" 27 | startup = "startup" 28 | shutdown = "shutdown" -------------------------------------------------------------------------------- /aiokeydb/v2/types/static.py: -------------------------------------------------------------------------------- 1 | 2 | import enum 3 | 4 | class JobStatus(str, enum.Enum): 5 | NEW = "new" 6 | DEFERRED = "deferred" 7 | QUEUED = "queued" 8 | ACTIVE = "active" 9 | ABORTED = "aborted" 10 | FAILED = "failed" 11 | COMPLETE = "complete" 12 | 13 | INCOMPLETE_STATUSES = {JobStatus.NEW, JobStatus.DEFERRED, JobStatus.QUEUED, JobStatus.ACTIVE} 14 | TERMINAL_STATUSES = {JobStatus.COMPLETE, JobStatus.FAILED, JobStatus.ABORTED} 15 | UNSUCCESSFUL_TERMINAL_STATUSES = TERMINAL_STATUSES - {JobStatus.COMPLETE} 16 | 17 | class TaskType(str, enum.Enum): 18 | """ 19 | The Type of Task 20 | for the worker 21 | """ 22 | default = "default" 23 | function = "function" 24 | cronjob = "cronjob" 25 | dependency = "dependency" 26 | context = "context" 27 | startup = "startup" 28 | shutdown = "shutdown" -------------------------------------------------------------------------------- /aiokeydb/exceptions.py: -------------------------------------------------------------------------------- 1 | "Additional exceptions raised by the Redis client" 2 | 3 | from redis.exceptions import ( 4 | AuthorizationError, 5 | AuthenticationError, 6 | AuthenticationWrongNumberOfArgsError, 7 | BusyLoadingError, 8 | ChildDeadlockedError, 9 | ConnectionError, 10 | DataError, 11 | InvalidResponse, 12 | PubSubError, 13 | ReadOnlyError, 14 | RedisError, 15 | ResponseError, 16 | TimeoutError, 17 | WatchError, 18 | NoScriptError, 19 | ) 20 | 21 | from typing import TYPE_CHECKING 22 | if TYPE_CHECKING: 23 | from aiokeydb.types.jobs import Job 24 | 25 | class JobError(Exception): 26 | def __init__(self, job: 'Job'): 27 | super().__init__( 28 | f"Job {job.id} {job.status}\n\nThe above job failed with the following error:\n\n{job.error}" 29 | ) 30 | self.job = job -------------------------------------------------------------------------------- /aiokeydb/v1/credentials.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | 4 | class CredentialProvider: 5 | """ 6 | Credentials Provider. 
7 | """ 8 | 9 | def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]: 10 | raise NotImplementedError("get_credentials must be implemented") 11 | 12 | 13 | class UsernamePasswordCredentialProvider(CredentialProvider): 14 | """ 15 | Simple implementation of CredentialProvider that just wraps static 16 | username and password. 17 | """ 18 | 19 | def __init__(self, username: Optional[str] = None, password: Optional[str] = None): 20 | self.username = username or "" 21 | self.password = password or "" 22 | 23 | def get_credentials(self): 24 | if self.username: 25 | return self.username, self.password 26 | return (self.password,) -------------------------------------------------------------------------------- /tests/test_multi_config.py: -------------------------------------------------------------------------------- 1 | from aiokeydb import KeyDBClient 2 | from lazyops.utils import logger 3 | 4 | default_uri = 'keydb://public.host.com:6379/0' 5 | keydb_dbs = { 6 | 'cache': { 7 | 'db_id': 1, 8 | }, 9 | 'db': { 10 | 'uri': 'keydb://127.0.0.1:6379/0', 11 | }, 12 | } 13 | 14 | KeyDBClient.configure( 15 | url = default_uri, 16 | debug_enabled = True, 17 | queue_db = 1, 18 | ) 19 | 20 | # now any sessions that are initialized will use the global settings 21 | 22 | sessions = {} 23 | # these will now be initialized 24 | 25 | # Initialize the first default session 26 | KeyDBClient.init_session() 27 | 28 | for name, config in keydb_dbs.items(): 29 | sessions[name] = KeyDBClient.init_session( 30 | name = name, 31 | **config 32 | ) 33 | logger.info(f'Session {name}: uri: {sessions[name].uri}') 34 | -------------------------------------------------------------------------------- /aiokeydb/v1/asyncio/utils.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | if TYPE_CHECKING: 4 | from aiokeydb.v1.asyncio.core import AsyncPipeline, AsyncKeyDB 5 | 6 | 7 | def async_from_url(url, **kwargs): 8 | """ 9 | Returns an active AsyncKeyDB client generated from the given database URL. 10 | 11 | Will attempt to extract the database id from the path url fragment, if 12 | none is provided. 
13 | """ 14 | from aiokeydb.v1.asyncio.core import AsyncKeyDB 15 | 16 | return AsyncKeyDB.from_url(url, **kwargs) 17 | 18 | 19 | class async_pipeline: 20 | def __init__(self, keydb_obj: "AsyncKeyDB"): 21 | self.p: "AsyncPipeline" = keydb_obj.pipeline() 22 | 23 | async def __aenter__(self) -> "AsyncPipeline": 24 | return self.p 25 | 26 | async def __aexit__(self, exc_type, exc_value, traceback): 27 | await self.p.execute() 28 | del self.p 29 | -------------------------------------------------------------------------------- /aiokeydb/v1/commands/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from aiokeydb.v1.commands.cluster import READ_COMMANDS, AsyncKeyDBClusterCommands, KeyDBClusterCommands 4 | from aiokeydb.v1.commands.core import AsyncCoreCommands, CoreCommands 5 | from aiokeydb.v1.commands.helpers import list_or_args 6 | from aiokeydb.v1.commands.parser import CommandsParser 7 | from aiokeydb.v1.commands.redismodules import AsyncRedisModuleCommands, RedisModuleCommands 8 | from aiokeydb.v1.commands.sentinel import AsyncSentinelCommands, SentinelCommands 9 | 10 | __all__ = [ 11 | "AsyncCoreCommands", 12 | "AsyncKeyDBClusterCommands", 13 | "AsyncRedisModuleCommands", 14 | "AsyncSentinelCommands", 15 | "CommandsParser", 16 | "CoreCommands", 17 | "READ_COMMANDS", 18 | "KeyDBClusterCommands", 19 | "RedisModuleCommands", 20 | "SentinelCommands", 21 | "list_or_args", 22 | ] 23 | -------------------------------------------------------------------------------- /aiokeydb/types/serializer.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Base Serializer Class that all other serializers should inherit from. 4 | """ 5 | 6 | import typing 7 | from types import ModuleType 8 | 9 | class BaseSerializer: 10 | 11 | @staticmethod 12 | def dumps(obj: typing.Any, **kwargs) -> bytes: 13 | """ 14 | Serialize the object to bytes 15 | """ 16 | raise NotImplementedError 17 | 18 | @staticmethod 19 | def loads(data: typing.Union[str, bytes, typing.Any], **kwargs) -> typing.Any: 20 | """ 21 | Deserialize the object from bytes 22 | """ 23 | raise NotImplementedError 24 | 25 | 26 | @staticmethod 27 | def register_module(module: ModuleType): 28 | """ 29 | Dummy method that should be overridden by serializers that support 30 | """ 31 | return 32 | 33 | @staticmethod 34 | def unregister_module(module: ModuleType): 35 | """ 36 | Dummy method that should be overridden by serializers that support 37 | """ 38 | return -------------------------------------------------------------------------------- /aiokeydb/v1/queues/imports/cron.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Import Handler for croniter 4 | """ 5 | 6 | from lazyops.utils import resolve_missing, require_missing_wrapper 7 | 8 | 9 | try: 10 | from croniter import croniter 11 | _croniter_available = True 12 | except ImportError: 13 | croniter = object 14 | _croniter_available = False 15 | 16 | 17 | def resolve_croniter( 18 | required: bool = False, 19 | ): 20 | """ 21 | Ensures that `croniter` is available 22 | """ 23 | global croniter, _croniter_available 24 | if not _croniter_available: 25 | resolve_missing('croniter', required = required) 26 | from croniter import croniter 27 | _croniter_available = True 28 | 29 | 30 | def require_croniter( 31 | required: bool = False, 32 | ): 33 | """ 34 | Wrapper for `resolve_croniter` that can be used as a decorator 35 | 
""" 36 | def decorator(func): 37 | return require_missing_wrapper( 38 | resolver = resolve_croniter, 39 | func = func, 40 | required = required 41 | ) 42 | return decorator -------------------------------------------------------------------------------- /aiokeydb/serializers/_msgpack.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Implements the MsgPack class. 4 | 5 | - uses the `msgpack` module to serialize and deserialize data 6 | """ 7 | 8 | import typing 9 | 10 | from aiokeydb.types.serializer import BaseSerializer 11 | 12 | try: 13 | import msgpack 14 | _msgpack_avail = True 15 | except ImportError: 16 | msgpack = object 17 | _msgpack_avail = False 18 | 19 | if _msgpack_avail: 20 | 21 | class MsgPackSerializer(BaseSerializer): 22 | 23 | @staticmethod 24 | def dumps( 25 | obj: typing.Union[typing.Dict[typing.Any, typing.Any], typing.Any], 26 | *args, 27 | **kwargs 28 | ) -> typing.Union[bytes, str]: 29 | return msgpack.packb(obj, *args, **kwargs) 30 | 31 | @staticmethod 32 | def loads( 33 | data: typing.Union[str, bytes], 34 | *args, 35 | raw: bool = False, 36 | **kwargs 37 | ) -> typing.Any: 38 | return msgpack.unpackb(data, *args, raw = raw, **kwargs) 39 | 40 | else: 41 | # Fallback to JSON 42 | from aiokeydb.serializers._json import JsonSerializer 43 | 44 | MsgPackSerializer = JsonSerializer -------------------------------------------------------------------------------- /aiokeydb/v2/serializers/_msgpack.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Implements the MsgPack class. 4 | 5 | - uses the `msgpack` module to serialize and deserialize data 6 | """ 7 | 8 | import typing 9 | 10 | from aiokeydb.v2.types import BaseSerializer 11 | 12 | try: 13 | import msgpack 14 | _msgpack_avail = True 15 | except ImportError: 16 | msgpack = object 17 | _msgpack_avail = False 18 | 19 | if _msgpack_avail: 20 | 21 | class MsgPackSerializer(BaseSerializer): 22 | 23 | @staticmethod 24 | def dumps( 25 | obj: typing.Union[typing.Dict[typing.Any, typing.Any], typing.Any], 26 | *args, 27 | **kwargs 28 | ) -> typing.Union[bytes, str]: 29 | return msgpack.packb(obj, *args, **kwargs) 30 | 31 | @staticmethod 32 | def loads( 33 | data: typing.Union[str, bytes], 34 | *args, 35 | raw: bool = False, 36 | **kwargs 37 | ) -> typing.Any: 38 | return msgpack.unpackb(data, *args, raw = raw, **kwargs) 39 | 40 | else: 41 | # Fallback to JSON 42 | from aiokeydb.v2.serializers._json import JsonSerializer 43 | 44 | MsgPackSerializer = JsonSerializer -------------------------------------------------------------------------------- /aiokeydb/v1/client/serializers/_msgpack.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Implements the MsgPack class. 
4 | 5 | - uses the `msgpack` module to serialize and deserialize data 6 | """ 7 | 8 | import typing 9 | 10 | from aiokeydb.v1.client.serializers.base import BaseSerializer 11 | 12 | try: 13 | import msgpack 14 | _msgpack_avail = True 15 | except ImportError: 16 | msgpack = object 17 | _msgpack_avail = False 18 | 19 | if _msgpack_avail: 20 | 21 | class MsgPackSerializer(BaseSerializer): 22 | 23 | @staticmethod 24 | def dumps( 25 | obj: typing.Union[typing.Dict[typing.Any, typing.Any], typing.Any], 26 | *args, 27 | **kwargs 28 | ) -> typing.Union[bytes, str]: 29 | return msgpack.packb(obj, *args, **kwargs) 30 | 31 | @staticmethod 32 | def loads( 33 | data: typing.Union[str, bytes], 34 | *args, 35 | raw: bool = False, 36 | **kwargs 37 | ) -> typing.Any: 38 | return msgpack.unpackb(data, *args, raw = raw, **kwargs) 39 | 40 | else: 41 | # Fallback to JSON 42 | from aiokeydb.v1.client.serializers._json import JsonSerializer 43 | 44 | MsgPackSerializer = JsonSerializer -------------------------------------------------------------------------------- /aiokeydb/v1/asyncio/client.py: -------------------------------------------------------------------------------- 1 | # Resolving the imports from previous versions 2 | from aiokeydb.v1.asyncio.connection import ( 3 | AsyncConnection, 4 | AsyncConnectionPool, 5 | AsyncSSLConnection, 6 | AsyncUnixDomainSocketConnection, 7 | ) 8 | from aiokeydb.v1.asyncio.lock import AsyncLock 9 | from aiokeydb.v1.commands import ( 10 | AsyncCoreCommands, 11 | RedisModuleCommands, 12 | AsyncSentinelCommands, 13 | list_or_args, 14 | ) 15 | from aiokeydb.v1.exceptions import ( 16 | ConnectionError, 17 | ExecAbortError, 18 | PubSubError, 19 | KeyDBError, 20 | ResponseError, 21 | TimeoutError, 22 | WatchError, 23 | ) 24 | 25 | from aiokeydb.v1.asyncio.core import ( 26 | ResponseCallbackProtocol, 27 | AsyncResponseCallbackProtocol, 28 | AsyncKeyDB, 29 | StrictAsyncKeyDB, 30 | AsyncPubSub, 31 | MonitorCommandInfo, 32 | AsyncMonitor, 33 | PubsubWorkerExceptionHandler, 34 | AsyncPubsubWorkerExceptionHandler, 35 | AsyncPipeline, 36 | ) 37 | 38 | from aiokeydb.v1.client.serializers import SerializerType, BaseSerializer 39 | from aiokeydb.v1.client.config import KeyDBSettings 40 | from aiokeydb.v1.client.schemas.session import KeyDBSession 41 | from aiokeydb.v1.client.core import KeyDBClient 42 | 43 | -------------------------------------------------------------------------------- /aiokeydb/v1/client/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | """ 4 | New Submodule that allows for 5 | easily initializing AsyncKeyDB and KeyDB 6 | for global usage 7 | """ 8 | 9 | # Resolve previous imports 10 | from aiokeydb.v1.connection import ConnectionPool, SSLConnection, UnixDomainSocketConnection 11 | from aiokeydb.v1.exceptions import ( 12 | ConnectionError, 13 | ExecAbortError, 14 | ModuleError, 15 | PubSubError, 16 | KeyDBError, 17 | ResponseError, 18 | TimeoutError, 19 | WatchError, 20 | ) 21 | from aiokeydb.v1.lock import Lock 22 | from aiokeydb.v1.core import ( 23 | AbstractKeyDB, 24 | KeyDB, 25 | StrictKeyDB, 26 | Monitor, 27 | PubSub, 28 | Pipeline, 29 | EMPTY_RESPONSE, 30 | NEVER_DECODE, 31 | CaseInsensitiveDict, 32 | bool_ok, 33 | ) 34 | 35 | # Add top level asyncio module imports 36 | from aiokeydb.v1.asyncio.lock import AsyncLock 37 | from aiokeydb.v1.asyncio.core import AsyncKeyDB, AsyncPubSub, AsyncPipeline 38 | 39 | from aiokeydb.v1.client.serializers import SerializerType, 
BaseSerializer 40 | from aiokeydb.v1.client.config import KeyDBSettings 41 | from aiokeydb.v1.client.schemas.session import KeyDBSession 42 | from aiokeydb.v1.client.meta import KeyDBClient 43 | 44 | -------------------------------------------------------------------------------- /aiokeydb/utils/lazy.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, Dict, Optional, TYPE_CHECKING 4 | 5 | if TYPE_CHECKING: 6 | from aiokeydb.configs.core import KeyDBSettings 7 | 8 | _keydb_settings: Optional['KeyDBSettings'] = None 9 | 10 | def get_keydb_settings(**kwargs) -> KeyDBSettings: 11 | """ 12 | Get the current KeyDB settings 13 | """ 14 | global _keydb_settings 15 | if _keydb_settings is None: 16 | from aiokeydb.configs.core import KeyDBSettings 17 | _keydb_settings = KeyDBSettings(**kwargs) 18 | return _keydb_settings 19 | 20 | def get_default_job_timeout() -> int: 21 | """ 22 | Get the default job timeout 23 | """ 24 | return get_keydb_settings().get_default_job_timeout() 25 | 26 | def get_default_job_ttl() -> int: 27 | """ 28 | Get the default job ttl 29 | """ 30 | return get_keydb_settings().get_default_job_ttl() 31 | 32 | 33 | def get_default_job_retries() -> int: 34 | """ 35 | Get the default job retries 36 | """ 37 | return get_keydb_settings().get_default_job_retries() 38 | 39 | 40 | def get_default_job_retry_delay() -> int: 41 | """ 42 | Get the default job retry delay 43 | """ 44 | return get_keydb_settings().get_default_job_retry_delay() 45 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | push: 13 | paths: 14 | # - 'setup.py' 15 | - 'aiokeydb/version.py' 16 | branches: 17 | - main 18 | release: 19 | types: [created] 20 | 21 | jobs: 22 | deploy: 23 | 24 | runs-on: ubuntu-latest 25 | 26 | steps: 27 | - uses: actions/checkout@v2 28 | - name: Set up Python 29 | uses: actions/setup-python@v2 30 | with: 31 | python-version: '3.x' 32 | - name: Install dependencies 33 | run: | 34 | python -m pip install --upgrade pip 35 | pip install build 36 | - name: Build package 37 | run: python -m build 38 | - name: Publish package 39 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 40 | with: 41 | user: __token__ 42 | password: ${{ secrets.PYPI_API_TOKEN }} 43 | -------------------------------------------------------------------------------- /aiokeydb/v1/queues/README.md: -------------------------------------------------------------------------------- 1 | # KeyDB Worker Queues 2 | 3 | KeyDB Worker Queues is a simple, fast, and reliable queue system for KeyDB. It is designed to be used in a distributed environment, where multiple KeyDB instances are used to process jobs. It is also designed to be used in a single instance environment, where a single KeyDB instance is used to process jobs. 
4 | 5 | ```python 6 | import asyncio 7 | from aiokeydb import KeyDBClient 8 | from aiokeydb.queues import TaskQueue, Worker 9 | from lazyops.utils import logger 10 | 11 | 12 | # Configure the KeyDB Client - the default keydb client will use 13 | # db = 0, and queue uses 2 so that it doesn't conflict with other 14 | # by configuring it here, you can explicitly set the db to use 15 | keydb_uri = "keydb://127.0.0.1:6379/0" 16 | 17 | # Configure the Queue to use db = 1 instead of 2 18 | KeyDBClient.configure( 19 | url = keydb_uri, 20 | debug_enabled = True, 21 | queue_db = 1, 22 | ) 23 | 24 | @Worker.add_cronjob("*/1 * * * *") 25 | async def test_cron_task(*args, **kwargs): 26 | logger.info("Cron task ran") 27 | await asyncio.sleep(5) 28 | 29 | @Worker.add_function() 30 | async def test_task(*args, **kwargs): 31 | logger.info("Task ran") 32 | await asyncio.sleep(5) 33 | 34 | async def run_tests(): 35 | queue = TaskQueue("test_queue") 36 | worker = Worker(queue) 37 | await worker.start() 38 | 39 | asyncio.run(run_tests()) 40 | 41 | ``` -------------------------------------------------------------------------------- /aiokeydb/v2/serializers/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from enum import Enum 4 | from typing import Type 5 | from aiokeydb.v2.types import BaseSerializer 6 | from aiokeydb.v2.serializers._json import JsonSerializer, OrJsonSerializer 7 | from aiokeydb.v2.serializers._pickle import PickleSerializer, DillSerializer 8 | from aiokeydb.v2.serializers._msgpack import MsgPackSerializer 9 | 10 | class SerializerType(str, Enum): 11 | """ 12 | Enum for the available serializers 13 | """ 14 | json = 'json' 15 | orjson = 'orjson' 16 | pickle = 'pickle' 17 | dill = 'dill' 18 | msgpack = 'msgpack' 19 | default = 'default' 20 | 21 | def get_serializer(self) -> Type[BaseSerializer]: 22 | """ 23 | Default Serializer = Dill 24 | """ 25 | 26 | if self == SerializerType.json: 27 | return JsonSerializer 28 | elif self == SerializerType.orjson: 29 | return OrJsonSerializer 30 | elif self == SerializerType.pickle: 31 | return PickleSerializer 32 | elif self == SerializerType.dill: 33 | return DillSerializer 34 | elif self == SerializerType.msgpack: 35 | return MsgPackSerializer 36 | elif self == SerializerType.default: 37 | return DillSerializer 38 | else: 39 | raise ValueError(f'Invalid serializer type: {self}') 40 | 41 | 42 | -------------------------------------------------------------------------------- /aiokeydb/v1/client/serializers/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from enum import Enum 4 | from typing import Type 5 | from aiokeydb.v1.client.serializers.base import BaseSerializer 6 | from aiokeydb.v1.client.serializers._json import JsonSerializer, OrJsonSerializer 7 | from aiokeydb.v1.client.serializers._pickle import PickleSerializer, DillSerializer 8 | from aiokeydb.v1.client.serializers._msgpack import MsgPackSerializer 9 | 10 | class SerializerType(str, Enum): 11 | """ 12 | Enum for the available serializers 13 | """ 14 | json = 'json' 15 | orjson = 'orjson' 16 | pickle = 'pickle' 17 | dill = 'dill' 18 | msgpack = 'msgpack' 19 | default = 'default' 20 | 21 | def get_serializer(self) -> Type[BaseSerializer]: 22 | """ 23 | Default Serializer = Dill 24 | """ 25 | 26 | if self == SerializerType.json: 27 | return JsonSerializer 28 | elif self == SerializerType.orjson: 29 | return 
OrJsonSerializer 30 | elif self == SerializerType.pickle: 31 | return PickleSerializer 32 | elif self == SerializerType.dill: 33 | return DillSerializer 34 | elif self == SerializerType.msgpack: 35 | return MsgPackSerializer 36 | elif self == SerializerType.default: 37 | return DillSerializer 38 | else: 39 | raise ValueError(f'Invalid serializer type: {self}') 40 | 41 | 42 | -------------------------------------------------------------------------------- /aiokeydb/v1/commands/timeseries/utils.py: -------------------------------------------------------------------------------- 1 | from aiokeydb.v1.commands.helpers import nativestr 2 | 3 | 4 | def list_to_dict(aList): 5 | return {nativestr(aList[i][0]): nativestr(aList[i][1]) for i in range(len(aList))} 6 | 7 | 8 | def parse_range(response): 9 | """Parse range response. Used by TS.RANGE and TS.REVRANGE.""" 10 | return [tuple((r[0], float(r[1]))) for r in response] 11 | 12 | 13 | def parse_m_range(response): 14 | """Parse multi range response. Used by TS.MRANGE and TS.MREVRANGE.""" 15 | res = [] 16 | for item in response: 17 | res.append({nativestr(item[0]): [list_to_dict(item[1]), parse_range(item[2])]}) 18 | return sorted(res, key=lambda d: list(d.keys())) 19 | 20 | 21 | def parse_get(response): 22 | """Parse get response. Used by TS.GET.""" 23 | if not response: 24 | return None 25 | return int(response[0]), float(response[1]) 26 | 27 | 28 | def parse_m_get(response): 29 | """Parse multi get response. Used by TS.MGET.""" 30 | res = [] 31 | for item in response: 32 | if not item[2]: 33 | res.append({nativestr(item[0]): [list_to_dict(item[1]), None, None]}) 34 | else: 35 | res.append( 36 | { 37 | nativestr(item[0]): [ 38 | list_to_dict(item[1]), 39 | int(item[2][0]), 40 | float(item[2][1]), 41 | ] 42 | } 43 | ) 44 | return sorted(res, key=lambda d: list(d.keys())) 45 | -------------------------------------------------------------------------------- /aiokeydb/commands/__init__.py: -------------------------------------------------------------------------------- 1 | import redis.commands.cluster 2 | import redis.commands.redismodules 3 | 4 | import redis.commands.bf as bf 5 | import redis.commands.graph as graph 6 | import redis.commands.json as json 7 | import redis.commands.search as search 8 | import redis.commands.timeseries as timeseries 9 | 10 | 11 | from redis.commands.cluster import RedisClusterCommands 12 | from redis.commands.core import AsyncCoreCommands, CoreCommands, AsyncScript 13 | from redis.commands.helpers import list_or_args 14 | from redis.commands.parser import CommandsParser 15 | from redis.commands.redismodules import RedisModuleCommands 16 | from redis.commands.sentinel import AsyncSentinelCommands, SentinelCommands 17 | 18 | __all__ = [ 19 | "AsyncCoreCommands", 20 | # "AsyncRedisClusterCommands", 21 | # "AsyncRedisModuleCommands", 22 | "AsyncSentinelCommands", 23 | "CommandsParser", 24 | "CoreCommands", 25 | # "READ_COMMANDS", 26 | "RedisClusterCommands", 27 | "RedisModuleCommands", 28 | "SentinelCommands", 29 | "list_or_args", 30 | "AsyncScript", 31 | ] 32 | 33 | 34 | if hasattr(redis.commands.cluster, "READ_COMMANDS"): 35 | READ_COMMANDS = redis.commands.cluster.READ_COMMANDS 36 | __all__ += ["READ_COMMANDS"] 37 | if hasattr(redis.commands.cluster, "AsyncRedisClusterCommands"): 38 | AsyncRedisClusterCommands = redis.commands.cluster.AsyncRedisClusterCommands 39 | __all__ += ["AsyncRedisClusterCommands"] 40 | if hasattr(redis.commands.redismodules, "AsyncRedisModuleCommands"): 41 | AsyncRedisModuleCommands = 
redis.commands.redismodules.AsyncRedisModuleCommands 42 | __all__ += ["AsyncRedisModuleCommands"] 43 | -------------------------------------------------------------------------------- /aiokeydb/v2/commands/__init__.py: -------------------------------------------------------------------------------- 1 | import redis.commands.cluster 2 | import redis.commands.redismodules 3 | 4 | import redis.commands.bf as bf 5 | import redis.commands.graph as graph 6 | import redis.commands.json as json 7 | import redis.commands.search as search 8 | import redis.commands.timeseries as timeseries 9 | 10 | 11 | from redis.commands.cluster import RedisClusterCommands 12 | from redis.commands.core import AsyncCoreCommands, CoreCommands, AsyncScript 13 | from redis.commands.helpers import list_or_args 14 | from redis.commands.parser import CommandsParser 15 | from redis.commands.redismodules import RedisModuleCommands 16 | from redis.commands.sentinel import AsyncSentinelCommands, SentinelCommands 17 | 18 | __all__ = [ 19 | "AsyncCoreCommands", 20 | # "AsyncRedisClusterCommands", 21 | # "AsyncRedisModuleCommands", 22 | "AsyncSentinelCommands", 23 | "CommandsParser", 24 | "CoreCommands", 25 | # "READ_COMMANDS", 26 | "RedisClusterCommands", 27 | "RedisModuleCommands", 28 | "SentinelCommands", 29 | "list_or_args", 30 | "AsyncScript", 31 | ] 32 | 33 | 34 | if hasattr(redis.commands.cluster, "READ_COMMANDS"): 35 | READ_COMMANDS = redis.commands.cluster.READ_COMMANDS 36 | __all__ += ["READ_COMMANDS"] 37 | if hasattr(redis.commands.cluster, "AsyncRedisClusterCommands"): 38 | AsyncRedisClusterCommands = redis.commands.cluster.AsyncRedisClusterCommands 39 | __all__ += ["AsyncRedisClusterCommands"] 40 | if hasattr(redis.commands.redismodules, "AsyncRedisModuleCommands"): 41 | AsyncRedisModuleCommands = redis.commands.redismodules.AsyncRedisModuleCommands 42 | __all__ += ["AsyncRedisModuleCommands"] 43 | -------------------------------------------------------------------------------- /tests/test_cron_v2.py: -------------------------------------------------------------------------------- 1 | 2 | import pytz 3 | import random 4 | import datetime 5 | import time 6 | import croniter 7 | from aiokeydb.v2.utils.cron import validate_cron_schedule 8 | 9 | 10 | tz = pytz.timezone('US/Central') 11 | local_date = tz.localize(datetime.datetime.now()) 12 | 13 | 14 | test_patterns = [ 15 | 'every {n} minutes', 16 | 'every {n} minutes and 10 seconds', 17 | 'every {n} minutes, 10 seconds', 18 | '{n} minutes and 10 seconds', 19 | '{n} minutes, 10 seconds', 20 | 'every {n} hours and 10 minutes and 10 seconds', 21 | 'every {n} hours, 10 minutes, 10 seconds', 22 | 'every {n} hours, 10 minutes and 10 seconds', 23 | 'every {n} hours and 10 minutes', 24 | 'every {n} hours', 25 | 'every {n} days and 10 hours and 10 minutes and 10 seconds', 26 | 'every {n} seconds', 27 | '{n} seconds', 28 | 'every 30 seconds', 29 | 'every 2 minutes', 30 | '5 s', 31 | '10 min', 32 | '1 hr', 33 | ] 34 | 35 | 36 | for pattern in test_patterns: 37 | pattern = pattern.format(n=random.randint(1, 15)) 38 | try: 39 | cron_expression = validate_cron_schedule(pattern) 40 | now = time.time() 41 | is_valid = croniter.croniter.is_valid(cron_expression) 42 | next_time = croniter.croniter(cron_expression, now).get_next() 43 | next_date = croniter.croniter(cron_expression, local_date).get_next(datetime.datetime) 44 | print(f"Pattern: {pattern}\nCron Expression: {cron_expression}\nValid: {is_valid}\nNext Time: {next_time} ({next_time - now:2f} s)\nNext Date: {next_date}") 45 | 
except ValueError as e: 46 | print(f"Pattern: {pattern}\nError: {str(e)}\n") 47 | print() 48 | 49 | -------------------------------------------------------------------------------- /aiokeydb/v1/client/utils.py: -------------------------------------------------------------------------------- 1 | import hashlib, inspect 2 | from typing import Optional, List 3 | from aiokeydb.v1.client.types import ENOVAL 4 | 5 | __all__ = [ 6 | 'full_name', 7 | 'args_to_key' 8 | ] 9 | 10 | def full_name(func, follow_wrapper_chains=True): 11 | """ 12 | Return full name of `func` by adding the module and function name. 13 | 14 | If this function is decorated, attempt to unwrap it till the original function to use that 15 | function name by setting `follow_wrapper_chains` to True. 16 | """ 17 | if follow_wrapper_chains: func = inspect.unwrap(func) 18 | return f'{func.__module__}.{func.__qualname__}' 19 | 20 | def args_to_key( 21 | base, 22 | args: Optional[tuple] = None, 23 | kwargs: Optional[dict] = None, 24 | typed: bool = False, 25 | exclude: Optional[List[str]] = None 26 | ): 27 | """Create cache key out of function arguments. 28 | :param tuple base: base of key 29 | :param tuple args: function arguments 30 | :param dict kwargs: function keyword arguments 31 | :param bool typed: include types in cache key 32 | :return: cache key tuple 33 | """ 34 | key = base + args 35 | 36 | if kwargs: 37 | if exclude: kwargs = {k: v for k, v in kwargs.items() if k not in exclude} 38 | key += (ENOVAL,) 39 | sorted_items = sorted(kwargs.items()) 40 | 41 | for item in sorted_items: 42 | key += item 43 | 44 | if typed: 45 | key += tuple(type(arg) for arg in args) 46 | if kwargs: key += tuple(type(value) for _, value in sorted_items) 47 | 48 | cache_key = ':'.join(str(k) for k in key) 49 | return hashlib.md5(cache_key.encode()).hexdigest() 50 | -------------------------------------------------------------------------------- /aiokeydb/v2/serializers/_pickle.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Implements the PickleSerializer class. 
4 | 5 | - uses the pickle module to serialize and deserialize data 6 | - will use `dill` if it is installed 7 | """ 8 | import sys 9 | import pickle 10 | import typing 11 | import contextlib 12 | from aiokeydb.v2.types import BaseSerializer 13 | 14 | 15 | if sys.version_info.minor < 8: 16 | with contextlib.suppress(ImportError): 17 | import pickle5 as pickle 18 | 19 | try: 20 | import dill 21 | _dill_avail = True 22 | except ImportError: 23 | dill = object 24 | _dill_avail = False 25 | 26 | 27 | class DefaultProtocols: 28 | default: int = 4 29 | pickle: int = pickle.HIGHEST_PROTOCOL 30 | dill: int = getattr(dill, "HIGHEST_PROTOCOL", default) # fallback when `dill` is not installed 31 | 32 | class PickleSerializer(BaseSerializer): 33 | 34 | @staticmethod 35 | def dumps(obj: typing.Any, protocol: int = DefaultProtocols.pickle, *args, **kwargs) -> bytes: 36 | return pickle.dumps(obj, protocol = protocol, *args, **kwargs) 37 | 38 | @staticmethod 39 | def loads(data: typing.Union[str, bytes, typing.Any], *args, **kwargs) -> typing.Any: 40 | return pickle.loads(data, *args, **kwargs) 41 | 42 | if _dill_avail: 43 | class DillSerializer(BaseSerializer): 44 | 45 | @staticmethod 46 | def dumps(obj: typing.Any, protocol: int = DefaultProtocols.dill, *args, **kwargs) -> bytes: 47 | return dill.dumps(obj, protocol = protocol, *args, **kwargs) 48 | 49 | @staticmethod 50 | def loads(data: typing.Union[str, bytes, typing.Any], *args, **kwargs) -> typing.Any: 51 | return dill.loads(data, *args, **kwargs) 52 | 53 | else: 54 | DillSerializer = PickleSerializer 55 | -------------------------------------------------------------------------------- /aiokeydb/v1/commands/json/decoders.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import re 3 | 4 | from aiokeydb.v1.commands.helpers import nativestr 5 | 6 | 7 | def bulk_of_jsons(d): 8 | """Replace serialized JSON values with objects in a 9 | bulk array response (list). 10 | """ 11 | 12 | def _f(b): 13 | for index, item in enumerate(b): 14 | if item is not None: 15 | b[index] = d(item) 16 | return b 17 | 18 | return _f 19 | 20 | 21 | def decode_dict_keys(obj): 22 | """Decode the keys of the given dictionary with utf-8.""" 23 | newobj = copy.copy(obj) 24 | for k in obj.keys(): 25 | if isinstance(k, bytes): 26 | newobj[k.decode("utf-8")] = newobj[k] 27 | newobj.pop(k) 28 | return newobj 29 | 30 | 31 | def unstring(obj): 32 | """ 33 | Attempt to parse string to native integer formats. 34 | One can't simply call int/float in a try/catch because there is a 35 | semantic difference between (for example) 15.0 and 15. 36 | """ 37 | floatreg = "^\\d+.\\d+$" 38 | match = re.findall(floatreg, obj) 39 | if match != []: 40 | return float(match[0]) 41 | 42 | intreg = "^\\d+$" 43 | match = re.findall(intreg, obj) 44 | if match != []: 45 | return int(match[0]) 46 | return obj 47 | 48 | 49 | def decode_list(b): 50 | """ 51 | Given a non-deserializable object, make a best effort to 52 | return a useful set of results. 53 | """ 54 | if isinstance(b, list): 55 | return [nativestr(obj) for obj in b] 56 | elif isinstance(b, bytes): 57 | return unstring(nativestr(b)) 58 | elif isinstance(b, str): 59 | return unstring(b) 60 | return b 61 | -------------------------------------------------------------------------------- /aiokeydb/v1/client/serializers/_pickle.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Implements the PickleSerializer class.
4 | 5 | - uses the pickle module to serialize and deserialize data 6 | - will use `dill` if it is installed 7 | """ 8 | import sys 9 | import pickle 10 | import typing 11 | import contextlib 12 | 13 | if sys.version_info.minor < 8: 14 | with contextlib.suppress(ImportError): 15 | import pickle5 as pickle 16 | 17 | try: 18 | import dill 19 | _dill_avail = True 20 | except ImportError: 21 | dill = object 22 | _dill_avail = False 23 | 24 | from aiokeydb.v1.client.serializers.base import BaseSerializer 25 | 26 | class DefaultProtocols: 27 | default: int = 4 28 | pickle: int = pickle.HIGHEST_PROTOCOL 29 | dill: int = getattr(dill, "HIGHEST_PROTOCOL", default) # fallback when `dill` is not installed 30 | 31 | class PickleSerializer(BaseSerializer): 32 | 33 | @staticmethod 34 | def dumps(obj: typing.Any, protocol: int = DefaultProtocols.pickle, *args, **kwargs) -> bytes: 35 | return pickle.dumps(obj, protocol = protocol, *args, **kwargs) 36 | 37 | @staticmethod 38 | def loads(data: typing.Union[str, bytes, typing.Any], *args, **kwargs) -> typing.Any: 39 | return pickle.loads(data, *args, **kwargs) 40 | 41 | if _dill_avail: 42 | class DillSerializer(BaseSerializer): 43 | 44 | @staticmethod 45 | def dumps(obj: typing.Any, protocol: int = DefaultProtocols.dill, *args, **kwargs) -> bytes: 46 | return dill.dumps(obj, protocol = protocol, *args, **kwargs) 47 | 48 | @staticmethod 49 | def loads(data: typing.Union[str, bytes, typing.Any], *args, **kwargs) -> typing.Any: 50 | return dill.loads(data, *args, **kwargs) 51 | 52 | else: 53 | DillSerializer = PickleSerializer 54 | -------------------------------------------------------------------------------- /aiokeydb/v1/commands/search/suggestion.py: -------------------------------------------------------------------------------- 1 | from aiokeydb.v1.commands.search._util import to_string 2 | 3 | 4 | class Suggestion: 5 | """ 6 | Represents a single suggestion being sent or returned from the 7 | autocomplete server 8 | """ 9 | 10 | def __init__(self, string, score=1.0, payload=None): 11 | self.string = to_string(string) 12 | self.payload = to_string(payload) 13 | self.score = score 14 | 15 | def __repr__(self): 16 | return self.string 17 | 18 | 19 | class SuggestionParser: 20 | """ 21 | Internal class used to parse results from the `SUGGET` command.
22 | This needs to consume either 1, 2, or 3 values at a time from 23 | the return value depending on what objects were requested 24 | """ 25 | 26 | def __init__(self, with_scores, with_payloads, ret): 27 | self.with_scores = with_scores 28 | self.with_payloads = with_payloads 29 | 30 | if with_scores and with_payloads: 31 | self.sugsize = 3 32 | self._scoreidx = 1 33 | self._payloadidx = 2 34 | elif with_scores: 35 | self.sugsize = 2 36 | self._scoreidx = 1 37 | elif with_payloads: 38 | self.sugsize = 2 39 | self._payloadidx = 1 40 | else: 41 | self.sugsize = 1 42 | self._scoreidx = -1 43 | 44 | self._sugs = ret 45 | 46 | def __iter__(self): 47 | for i in range(0, len(self._sugs), self.sugsize): 48 | ss = self._sugs[i] 49 | score = float(self._sugs[i + self._scoreidx]) if self.with_scores else 1.0 50 | payload = self._sugs[i + self._payloadidx] if self.with_payloads else None 51 | yield Suggestion(ss, score, payload) 52 | -------------------------------------------------------------------------------- /aiokeydb/serializers/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from enum import Enum 4 | from typing import Type, TypeVar 5 | from aiokeydb.types import BaseSerializer 6 | from aiokeydb.serializers._json import JsonSerializer, OrJsonSerializer 7 | from aiokeydb.serializers._pickle import PickleSerializer, DillSerializer, DillSerializerv2, PickleSerializerv2, CloudPickleSerializer 8 | from aiokeydb.serializers._msgpack import MsgPackSerializer 9 | 10 | 11 | SerializerT = TypeVar('SerializerT', bound = Type[BaseSerializer]) 12 | 13 | class SerializerType(str, Enum): 14 | """ 15 | Enum for the available serializers 16 | """ 17 | json = 'json' 18 | orjson = 'orjson' 19 | pickle = 'pickle' 20 | dill = 'dill' 21 | msgpack = 'msgpack' 22 | default = 'default' 23 | 24 | picklev2 = 'picklev2' 25 | dillv2 = 'dillv2' 26 | 27 | cloudpickle = 'cloudpickle' 28 | 29 | def get_serializer(self) -> SerializerT: 30 | """ 31 | Default Serializer = Dill 32 | """ 33 | 34 | if self == SerializerType.json: 35 | return JsonSerializer 36 | elif self == SerializerType.orjson: 37 | return OrJsonSerializer 38 | elif self == SerializerType.pickle: 39 | return PickleSerializer 40 | elif self == SerializerType.dill: 41 | return DillSerializer 42 | elif self == SerializerType.picklev2: 43 | return PickleSerializerv2 44 | elif self == SerializerType.dillv2: 45 | return DillSerializerv2 46 | elif self == SerializerType.cloudpickle: 47 | return CloudPickleSerializer 48 | elif self == SerializerType.msgpack: 49 | return MsgPackSerializer 50 | elif self == SerializerType.default: 51 | return DillSerializer 52 | else: 53 | raise ValueError(f'Invalid serializer type: {self}') 54 | 55 | 56 | -------------------------------------------------------------------------------- /tests/test_dict.py: -------------------------------------------------------------------------------- 1 | 2 | import time 3 | import asyncio 4 | import uuid 5 | from aiokeydb import KeyDBClient 6 | from aiokeydb.types.cachify import CachifyKwargs 7 | from aiokeydb.types.jobs import FunctionTracker, Job 8 | from lazyops.types.models import BaseModel, Field 9 | from typing import ClassVar 10 | 11 | # The session can be explicitly initialized, or 12 | # will be lazily initialized on first use 13 | # through environment variables with all 14 | # params being prefixed with `KEYDB_` 15 | 16 | keydb_uri = "keydb://localhost:6379/0" 17 | 18 | # Initialize the Session 19 
| session = KeyDBClient.init_session(uri = keydb_uri) 20 | 21 | class DummyObject(BaseModel): 22 | key1: float # = Field(default_factory=time.time) 23 | 24 | data_dict = { 25 | "key1": "value1", 26 | "key2": 234, 27 | "key3": CachifyKwargs(), 28 | "key4": FunctionTracker(function = 'test'), 29 | "key5": Job(function = 'test'), 30 | # DummyObject(key1 = time.time()), 31 | } 32 | 33 | 34 | def test_dict(): 35 | 36 | for key, value in data_dict.items(): 37 | session[key] = value 38 | 39 | for key, value in data_dict.items(): 40 | assert session[key] == value 41 | assert key in session 42 | 43 | for key, value in data_dict.items(): 44 | del session[key] 45 | 46 | 47 | async def test_async_dict(): 48 | 49 | session.configure_dict_methods(async_enabled = True) 50 | 51 | for key, value in data_dict.items(): 52 | session[key] = value 53 | 54 | for key, value in data_dict.items(): 55 | stored_value = await session[key] 56 | assert stored_value == value 57 | assert key in session 58 | print(stored_value) 59 | 60 | for key, value in data_dict.items(): 61 | del session[key] 62 | 63 | async def run_tests(): 64 | 65 | test_dict() 66 | await test_async_dict() 67 | 68 | if __name__ == "__main__": 69 | asyncio.run(run_tests()) 70 | 71 | 72 | -------------------------------------------------------------------------------- /aiokeydb/retry.py: -------------------------------------------------------------------------------- 1 | import socket 2 | from time import sleep 3 | 4 | from aiokeydb.exceptions import ConnectionError, TimeoutError 5 | 6 | 7 | class Retry: 8 | """Retry a specific number of times after a failure""" 9 | 10 | def __init__( 11 | self, 12 | backoff, 13 | retries, 14 | supported_errors=(ConnectionError, TimeoutError, socket.timeout), 15 | ): 16 | """ 17 | Initialize a `Retry` object with a `Backoff` object 18 | that retries a maximum of `retries` times. 19 | `retries` can be negative to retry forever. 20 | You can specify the types of supported errors which trigger 21 | a retry with the `supported_errors` parameter. 22 | """ 23 | self._backoff = backoff 24 | self._retries = retries 25 | self._supported_errors = supported_errors 26 | 27 | def update_supported_errors(self, specified_errors: list): 28 | """ 29 | Updates the supported errors with the specified error types 30 | """ 31 | self._supported_errors = tuple( 32 | set(self._supported_errors + tuple(specified_errors)) 33 | ) 34 | 35 | def call_with_retry(self, do, fail): 36 | """ 37 | Execute an operation that might fail and returns its result, or 38 | raise the exception that was thrown depending on the `Backoff` object. 39 | `do`: the operation to call. Expects no argument. 
40 | `fail`: the failure handler, expects the last error that was thrown 41 | """ 42 | self._backoff.reset() 43 | failures = 0 44 | while True: 45 | try: 46 | return do() 47 | except self._supported_errors as error: 48 | failures += 1 49 | fail(error) 50 | if self._retries >= 0 and failures > self._retries: 51 | raise error 52 | backoff = self._backoff.compute(failures) 53 | if backoff > 0: 54 | sleep(backoff) 55 | -------------------------------------------------------------------------------- /aiokeydb/v1/retry.py: -------------------------------------------------------------------------------- 1 | import socket 2 | from time import sleep 3 | 4 | from aiokeydb.v1.exceptions import ConnectionError, TimeoutError 5 | 6 | 7 | class Retry: 8 | """Retry a specific number of times after a failure""" 9 | 10 | def __init__( 11 | self, 12 | backoff, 13 | retries, 14 | supported_errors=(ConnectionError, TimeoutError, socket.timeout), 15 | ): 16 | """ 17 | Initialize a `Retry` object with a `Backoff` object 18 | that retries a maximum of `retries` times. 19 | `retries` can be negative to retry forever. 20 | You can specify the types of supported errors which trigger 21 | a retry with the `supported_errors` parameter. 22 | """ 23 | self._backoff = backoff 24 | self._retries = retries 25 | self._supported_errors = supported_errors 26 | 27 | def update_supported_errors(self, specified_errors: list): 28 | """ 29 | Updates the supported errors with the specified error types 30 | """ 31 | self._supported_errors = tuple( 32 | set(self._supported_errors + tuple(specified_errors)) 33 | ) 34 | 35 | def call_with_retry(self, do, fail): 36 | """ 37 | Execute an operation that might fail and returns its result, or 38 | raise the exception that was thrown depending on the `Backoff` object. 39 | `do`: the operation to call. Expects no argument. 
40 | `fail`: the failure handler, expects the last error that was thrown 41 | """ 42 | self._backoff.reset() 43 | failures = 0 44 | while True: 45 | try: 46 | return do() 47 | except self._supported_errors as error: 48 | failures += 1 49 | fail(error) 50 | if self._retries >= 0 and failures > self._retries: 51 | raise error 52 | backoff = self._backoff.compute(failures) 53 | if backoff > 0: 54 | sleep(backoff) 55 | -------------------------------------------------------------------------------- /aiokeydb/v1/asyncio/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | # Keep naming convention to explicitly include Async 4 | # to avoid confusion with the builtin sync Clients / modules 5 | from aiokeydb.v1.asyncio.core import AsyncKeyDB, StrictAsyncKeyDB 6 | from aiokeydb.v1.asyncio.cluster import AsyncKeyDBCluster 7 | from aiokeydb.v1.asyncio.connection import ( 8 | AsyncBlockingConnectionPool, 9 | AsyncConnection, 10 | AsyncConnectionPool, 11 | AsyncSSLConnection, 12 | AsyncUnixDomainSocketConnection, 13 | ) 14 | 15 | from aiokeydb.v1.asyncio.parser import CommandsParser 16 | from aiokeydb.v1.asyncio.sentinel import ( 17 | AsyncSentinel, 18 | AsyncSentinelConnectionPool, 19 | AsyncSentinelManagedConnection, 20 | AsyncSentinelManagedSSLConnection, 21 | ) 22 | from aiokeydb.v1.asyncio.utils import async_from_url 23 | from aiokeydb.v1.exceptions import ( 24 | AuthenticationError, 25 | AuthenticationWrongNumberOfArgsError, 26 | BusyLoadingError, 27 | ChildDeadlockedError, 28 | ConnectionError, 29 | DataError, 30 | InvalidResponse, 31 | PubSubError, 32 | ReadOnlyError, 33 | KeyDBError, 34 | ResponseError, 35 | TimeoutError, 36 | WatchError, 37 | ) 38 | 39 | 40 | __all__ = [ 41 | "AuthenticationError", 42 | "AuthenticationWrongNumberOfArgsError", 43 | "AsyncBlockingConnectionPool", 44 | "BusyLoadingError", 45 | "ChildDeadlockedError", 46 | "CommandsParser", 47 | "AsyncConnection", 48 | "ConnectionError", 49 | "AsyncConnectionPool", 50 | "DataError", 51 | "async_from_url", 52 | "InvalidResponse", 53 | "PubSubError", 54 | "ReadOnlyError", 55 | "AsyncKeyDB", 56 | "AsyncKeyDBCluster", 57 | "KeyDBError", 58 | "ResponseError", 59 | "AsyncSentinel", 60 | "AsyncSentinelConnectionPool", 61 | "AsyncSentinelManagedConnection", 62 | "AsyncSentinelManagedSSLConnection", 63 | "AsyncSSLConnection", 64 | "StrictAsyncKeyDB", 65 | "TimeoutError", 66 | "AsyncUnixDomainSocketConnection", 67 | "WatchError", 68 | ] 69 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | from setuptools import setup, find_packages 4 | 5 | if sys.version_info.major != 3: 6 | raise RuntimeError("This package requires Python 3+") 7 | 8 | pkg_name = 'aiokeydb' 9 | gitrepo = 'trisongz/aiokeydb-py' 10 | root = Path(__file__).parent 11 | version = root.joinpath('aiokeydb/version.py').read_text().split('VERSION = ', 1)[-1].strip().replace('-', '').replace("'", '') 12 | 13 | 14 | requirements = [ 15 | "deprecated>=1.2.3", 16 | "packaging>=20.4", 17 | 'importlib-metadata >= 1.0; python_version < "3.8"', 18 | 'typing-extensions; python_version<"3.8"', 19 | "async-timeout>=4.0.2", 20 | 'lazyops>=0.2.61', 21 | 'redis<5.0.0', 22 | "pydantic", 23 | "anyio", 24 | "croniter", 25 | "tenacity", 26 | # "hiredis", 27 | ] 28 | 29 | args = { 30 | 'packages': find_packages(include=[ 31 | "aiokeydb", 32 
| "aiokeydb.asyncio", 33 | "aiokeydb.commands", 34 | "aiokeydb.commands.bf", 35 | "aiokeydb.commands.json", 36 | "aiokeydb.commands.search", 37 | "aiokeydb.commands.timeseries", 38 | "aiokeydb.commands.graph", 39 | ]), 40 | 'install_requires': requirements, 41 | 'include_package_data': True, 42 | 'long_description': root.joinpath('README.md').read_text(encoding='utf-8'), 43 | 'entry_points': {}, 44 | 'extras_require': { 45 | "hiredis": ["hiredis>=1.0.0"], 46 | "ocsp": ["cryptography>=36.0.1", "pyopenssl==20.0.1", "requests>=2.26.0"], 47 | }, 48 | } 49 | 50 | 51 | setup( 52 | name=pkg_name, 53 | version=version, 54 | url=f'https://github.com/{gitrepo}', 55 | license='MIT Style', 56 | description='Python client for KeyDB database and key-value store', 57 | author='Tri Songz', 58 | author_email='ts@growthengineai.com', 59 | long_description_content_type="text/markdown", 60 | classifiers=[ 61 | 'Intended Audience :: Developers', 62 | 'License :: OSI Approved :: MIT License', 63 | 'Programming Language :: Python :: 3.7', 64 | 'Topic :: Software Development :: Libraries', 65 | ], 66 | **args 67 | ) -------------------------------------------------------------------------------- /aiokeydb/typing.py: -------------------------------------------------------------------------------- 1 | # from __future__ import annotations 2 | 3 | from datetime import datetime, timedelta 4 | from typing import TYPE_CHECKING, Any, Awaitable, Iterable, TypeVar, Union 5 | 6 | from redis.compat import Protocol 7 | if TYPE_CHECKING: 8 | from aiokeydb.connection import ( 9 | ConnectionPool, 10 | Encoder, 11 | AsyncConnectionPool, 12 | AsyncEncoder 13 | ) 14 | 15 | Number = Union[int, float] 16 | EncodedT = Union[bytes, memoryview] 17 | DecodedT = Union[str, int, float] 18 | EncodableT = Union[EncodedT, DecodedT] 19 | AbsExpiryT = Union[int, datetime] 20 | ExpiryT = Union[float, timedelta] 21 | ZScoreBoundT = Union[float, str] # str allows for the [ or ( prefix 22 | BitfieldOffsetT = Union[int, str] # str allows for #x syntax 23 | _StringLikeT = Union[bytes, str, memoryview] 24 | KeyT = _StringLikeT # Main redis key space 25 | PatternT = _StringLikeT # Patterns matched against keys, fields etc 26 | FieldT = EncodableT # Fields within hash tables, streams and geo commands 27 | KeysT = Union[KeyT, Iterable[KeyT]] 28 | ChannelT = _StringLikeT 29 | GroupT = _StringLikeT # Consumer group 30 | ConsumerT = _StringLikeT # Consumer name 31 | StreamIdT = Union[int, _StringLikeT] 32 | ScriptTextT = _StringLikeT 33 | TimeoutSecT = Union[int, float, _StringLikeT] 34 | # Mapping is not covariant in the key type, which prevents 35 | # Mapping[_StringLikeT, X] from accepting arguments of type Dict[str, X]. Using 36 | # a TypeVar instead of a Union allows mappings with any of the permitted types 37 | # to be passed. Care is needed if there is more than one such mapping in a 38 | # type signature because they will all be required to be the same key type. 39 | AnyKeyT = TypeVar("AnyKeyT", bytes, str, memoryview) 40 | AnyFieldT = TypeVar("AnyFieldT", bytes, str, memoryview) 41 | AnyChannelT = TypeVar("AnyChannelT", bytes, str, memoryview) 42 | 43 | 44 | class CommandsProtocol(Protocol): 45 | connection_pool: Union["AsyncConnectionPool", "ConnectionPool"] 46 | 47 | def execute_command(self, *args, **options): 48 | ... 49 | 50 | 51 | class ClusterCommandsProtocol(CommandsProtocol): 52 | encoder: Union["AsyncEncoder", "Encoder"] 53 | 54 | def execute_command(self, *args, **options) -> Union[Any, Awaitable]: 55 | ... 
56 | -------------------------------------------------------------------------------- /aiokeydb/v2/typing.py: -------------------------------------------------------------------------------- 1 | # from __future__ import annotations 2 | 3 | from datetime import datetime, timedelta 4 | from typing import TYPE_CHECKING, Any, Awaitable, Iterable, TypeVar, Union 5 | 6 | from redis.compat import Protocol 7 | if TYPE_CHECKING: 8 | from aiokeydb.v2.connection import ( 9 | ConnectionPool, 10 | Encoder, 11 | AsyncConnectionPool, 12 | AsyncEncoder 13 | ) 14 | 15 | Number = Union[int, float] 16 | EncodedT = Union[bytes, memoryview] 17 | DecodedT = Union[str, int, float] 18 | EncodableT = Union[EncodedT, DecodedT] 19 | AbsExpiryT = Union[int, datetime] 20 | ExpiryT = Union[float, timedelta] 21 | ZScoreBoundT = Union[float, str] # str allows for the [ or ( prefix 22 | BitfieldOffsetT = Union[int, str] # str allows for #x syntax 23 | _StringLikeT = Union[bytes, str, memoryview] 24 | KeyT = _StringLikeT # Main redis key space 25 | PatternT = _StringLikeT # Patterns matched against keys, fields etc 26 | FieldT = EncodableT # Fields within hash tables, streams and geo commands 27 | KeysT = Union[KeyT, Iterable[KeyT]] 28 | ChannelT = _StringLikeT 29 | GroupT = _StringLikeT # Consumer group 30 | ConsumerT = _StringLikeT # Consumer name 31 | StreamIdT = Union[int, _StringLikeT] 32 | ScriptTextT = _StringLikeT 33 | TimeoutSecT = Union[int, float, _StringLikeT] 34 | # Mapping is not covariant in the key type, which prevents 35 | # Mapping[_StringLikeT, X] from accepting arguments of type Dict[str, X]. Using 36 | # a TypeVar instead of a Union allows mappings with any of the permitted types 37 | # to be passed. Care is needed if there is more than one such mapping in a 38 | # type signature because they will all be required to be the same key type. 39 | AnyKeyT = TypeVar("AnyKeyT", bytes, str, memoryview) 40 | AnyFieldT = TypeVar("AnyFieldT", bytes, str, memoryview) 41 | AnyChannelT = TypeVar("AnyChannelT", bytes, str, memoryview) 42 | 43 | 44 | class CommandsProtocol(Protocol): 45 | connection_pool: Union["AsyncConnectionPool", "ConnectionPool"] 46 | 47 | def execute_command(self, *args, **options): 48 | ... 49 | 50 | 51 | class ClusterCommandsProtocol(CommandsProtocol): 52 | encoder: Union["AsyncEncoder", "Encoder"] 53 | 54 | def execute_command(self, *args, **options) -> Union[Any, Awaitable]: 55 | ... 
56 | -------------------------------------------------------------------------------- /aiokeydb/v1/typing.py: -------------------------------------------------------------------------------- 1 | # from __future__ import annotations 2 | 3 | from datetime import datetime, timedelta 4 | from typing import TYPE_CHECKING, Any, Awaitable, Iterable, TypeVar, Union 5 | 6 | from aiokeydb.v1.compat import Protocol 7 | 8 | if TYPE_CHECKING: 9 | from aiokeydb.v1.asyncio.connection import AsyncConnectionPool 10 | from aiokeydb.v1.asyncio.connection import Encoder as AsyncEncoder 11 | from aiokeydb.v1.connection import ConnectionPool, Encoder 12 | 13 | 14 | Number = Union[int, float] 15 | EncodedT = Union[bytes, memoryview] 16 | DecodedT = Union[str, int, float] 17 | EncodableT = Union[EncodedT, DecodedT] 18 | AbsExpiryT = Union[int, datetime] 19 | ExpiryT = Union[float, timedelta] 20 | ZScoreBoundT = Union[float, str] # str allows for the [ or ( prefix 21 | BitfieldOffsetT = Union[int, str] # str allows for #x syntax 22 | _StringLikeT = Union[bytes, str, memoryview] 23 | KeyT = _StringLikeT # Main redis key space 24 | PatternT = _StringLikeT # Patterns matched against keys, fields etc 25 | FieldT = EncodableT # Fields within hash tables, streams and geo commands 26 | KeysT = Union[KeyT, Iterable[KeyT]] 27 | ChannelT = _StringLikeT 28 | GroupT = _StringLikeT # Consumer group 29 | ConsumerT = _StringLikeT # Consumer name 30 | StreamIdT = Union[int, _StringLikeT] 31 | ScriptTextT = _StringLikeT 32 | TimeoutSecT = Union[int, float, _StringLikeT] 33 | # Mapping is not covariant in the key type, which prevents 34 | # Mapping[_StringLikeT, X] from accepting arguments of type Dict[str, X]. Using 35 | # a TypeVar instead of a Union allows mappings with any of the permitted types 36 | # to be passed. Care is needed if there is more than one such mapping in a 37 | # type signature because they will all be required to be the same key type. 38 | AnyKeyT = TypeVar("AnyKeyT", bytes, str, memoryview) 39 | AnyFieldT = TypeVar("AnyFieldT", bytes, str, memoryview) 40 | AnyChannelT = TypeVar("AnyChannelT", bytes, str, memoryview) 41 | 42 | 43 | class CommandsProtocol(Protocol): 44 | connection_pool: Union["AsyncConnectionPool", "ConnectionPool"] 45 | 46 | def execute_command(self, *args, **options): 47 | ... 48 | 49 | 50 | class ClusterCommandsProtocol(CommandsProtocol): 51 | encoder: Union["AsyncEncoder", "Encoder"] 52 | 53 | def execute_command(self, *args, **options) -> Union[Any, Awaitable]: 54 | ... 
55 | -------------------------------------------------------------------------------- /aiokeydb/v1/commands/graph/path.py: -------------------------------------------------------------------------------- 1 | from aiokeydb.v1.commands.graph.edge import Edge 2 | from aiokeydb.v1.commands.graph.node import Node 3 | 4 | 5 | class Path: 6 | def __init__(self, nodes, edges): 7 | if not (isinstance(nodes, list) and isinstance(edges, list)): 8 | raise TypeError("nodes and edges must be list") 9 | 10 | self._nodes = nodes 11 | self._edges = edges 12 | self.append_type = Node 13 | 14 | @classmethod 15 | def new_empty_path(cls): 16 | return cls([], []) 17 | 18 | def nodes(self): 19 | return self._nodes 20 | 21 | def edges(self): 22 | return self._edges 23 | 24 | def get_node(self, index): 25 | return self._nodes[index] 26 | 27 | def get_relationship(self, index): 28 | return self._edges[index] 29 | 30 | def first_node(self): 31 | return self._nodes[0] 32 | 33 | def last_node(self): 34 | return self._nodes[-1] 35 | 36 | def edge_count(self): 37 | return len(self._edges) 38 | 39 | def nodes_count(self): 40 | return len(self._nodes) 41 | 42 | def add_node(self, node): 43 | if not isinstance(node, self.append_type): 44 | raise AssertionError("Add Edge before adding Node") 45 | self._nodes.append(node) 46 | self.append_type = Edge 47 | return self 48 | 49 | def add_edge(self, edge): 50 | if not isinstance(edge, self.append_type): 51 | raise AssertionError("Add Node before adding Edge") 52 | self._edges.append(edge) 53 | self.append_type = Node 54 | return self 55 | 56 | def __eq__(self, other): 57 | return self.nodes() == other.nodes() and self.edges() == other.edges() 58 | 59 | def __str__(self): 60 | res = "<" 61 | edge_count = self.edge_count() 62 | for i in range(0, edge_count): 63 | node_id = self.get_node(i).id 64 | res += "(" + str(node_id) + ")" 65 | edge = self.get_relationship(i) 66 | res += ( 67 | "-[" + str(int(edge.id)) + "]->" 68 | if edge.src_node == node_id 69 | else "<-[" + str(int(edge.id)) + "]-" 70 | ) 71 | node_id = self.get_node(edge_count).id 72 | res += "(" + str(node_id) + ")" 73 | res += ">" 74 | return res 75 | -------------------------------------------------------------------------------- /aiokeydb/utils/cron.py: -------------------------------------------------------------------------------- 1 | """ 2 | Cron Utils 3 | """ 4 | 5 | import re 6 | import croniter 7 | 8 | # '*/5 * * * *' # every 5 minutes 9 | # '*/10 * * * *' # every 10 minutes 10 | 11 | # _schedule_fmt = { 12 | # 'minutes': '*/{value} * * * *', 13 | # 'hours': '* * */{value} * * *', 14 | # 'days': '* * * */{value} * *', 15 | # 'weeks': '* * * * */{value} *', 16 | # } 17 | 18 | ## v3 19 | 20 | _time_aliases_groups = { 21 | 'seconds': ['s', 'sec', 'secs'], 22 | 'minutes': ['m', 'min', 'mins'], 23 | 'hours': ['h', 'hr', 'hrs'], 24 | 'days': ['d', 'day'], 25 | 'weeks': ['w', 'wk', 'wks'], 26 | 'months': ['mo', 'mon', 'mons'], 27 | } 28 | _time_aliases = {alias: unit for unit, aliases in _time_aliases_groups.items() for alias in aliases} 29 | _time_pattern = re.compile(r'(?:(?:every )?(\d+) (\w+))(?:, | and )?') 30 | 31 | def validate_cron_schedule(cron_schedule: str) -> str: 32 | """ 33 | Convert natural language to cron format using regex patterns 34 | """ 35 | if croniter.croniter.is_valid(cron_schedule): return cron_schedule 36 | time_units = { 37 | 'seconds': None, 38 | 'minutes': '*', 39 | 'hours': '*', 40 | 'days': '*', 41 | 'weeks': '*', 42 | 'months': '*' 43 | } 44 | match = 
_time_pattern.findall(cron_schedule) 45 | if not match: raise ValueError(f"Invalid cron expression: {cron_schedule}") 46 | 47 | for num, unit in match: 48 | if unit in _time_aliases: unit = _time_aliases[unit] 49 | if not unit.endswith('s'): unit += 's' 50 | if unit not in time_units: 51 | raise ValueError(f"Invalid time unit in cron expression: unit: {unit}, num: {num}") 52 | time_units[unit] = f'*/{num}' 53 | 54 | if time_units['hours'] != "*" and time_units['minutes'] == '*': 55 | time_units['minutes'] = 0 56 | if time_units['days'] != "*" and time_units['hours'] == '*': 57 | time_units['hours'] = 0 58 | if time_units['weeks'] != "*" and time_units['days'] == '*': 59 | time_units['days'] = 0 60 | if time_units['months'] != "*" and time_units['weeks'] == '*': 61 | time_units['weeks'] = 0 62 | 63 | cron_expression = f"{time_units['minutes']} {time_units['hours']} {time_units['days']} {time_units['months']} {time_units['weeks']}" 64 | if time_units['seconds']: 65 | cron_expression += f" {time_units['seconds']}" 66 | return cron_expression.strip() 67 | 68 | 69 | -------------------------------------------------------------------------------- /aiokeydb/v2/utils/cron.py: -------------------------------------------------------------------------------- 1 | """ 2 | Cron Utils 3 | """ 4 | 5 | import re 6 | import croniter 7 | 8 | # '*/5 * * * *' # every 5 minutes 9 | # '*/10 * * * *' # every 10 minutes 10 | 11 | # _schedule_fmt = { 12 | # 'minutes': '*/{value} * * * *', 13 | # 'hours': '* * */{value} * * *', 14 | # 'days': '* * * */{value} * *', 15 | # 'weeks': '* * * * */{value} *', 16 | # } 17 | 18 | ## v3 19 | 20 | _time_aliases_groups = { 21 | 'seconds': ['s', 'sec', 'secs'], 22 | 'minutes': ['m', 'min', 'mins'], 23 | 'hours': ['h', 'hr', 'hrs'], 24 | 'days': ['d', 'day'], 25 | 'weeks': ['w', 'wk', 'wks'], 26 | 'months': ['mo', 'mon', 'mons'], 27 | } 28 | _time_aliases = {alias: unit for unit, aliases in _time_aliases_groups.items() for alias in aliases} 29 | _time_pattern = re.compile(r'(?:(?:every )?(\d+) (\w+))(?:, | and )?') 30 | 31 | def validate_cron_schedule(cron_schedule: str) -> str: 32 | """ 33 | Convert natural language to cron format using regex patterns 34 | """ 35 | if croniter.croniter.is_valid(cron_schedule): return cron_schedule 36 | time_units = { 37 | 'seconds': None, 38 | 'minutes': '*', 39 | 'hours': '*', 40 | 'days': '*', 41 | 'weeks': '*', 42 | 'months': '*' 43 | } 44 | match = _time_pattern.findall(cron_schedule) 45 | if not match: raise ValueError(f"Invalid cron expression: {cron_schedule}") 46 | 47 | for num, unit in match: 48 | if unit in _time_aliases: unit = _time_aliases[unit] 49 | if not unit.endswith('s'): unit += 's' 50 | if unit not in time_units: 51 | raise ValueError(f"Invalid time unit in cron expression: unit: {unit}, num: {num}") 52 | time_units[unit] = f'*/{num}' 53 | 54 | if time_units['hours'] != "*" and time_units['minutes'] == '*': 55 | time_units['minutes'] = 0 56 | if time_units['days'] != "*" and time_units['hours'] == '*': 57 | time_units['hours'] = 0 58 | if time_units['weeks'] != "*" and time_units['days'] == '*': 59 | time_units['days'] = 0 60 | if time_units['months'] != "*" and time_units['weeks'] == '*': 61 | time_units['weeks'] = 0 62 | 63 | cron_expression = f"{time_units['minutes']} {time_units['hours']} {time_units['days']} {time_units['months']} {time_units['weeks']}" 64 | if time_units['seconds']: 65 | cron_expression += f" {time_units['seconds']}" 66 | return cron_expression.strip() 67 | 68 | 69 | 
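# Illustrative usage sketch (editorial addition, not part of the module), following the conversion rules above:
#
#   validate_cron_schedule("every 5 minutes")   # -> "*/5 * * * *"
#   validate_cron_schedule("every 2 hours")     # -> "0 */2 * * *"
#   validate_cron_schedule("*/10 * * * *")      # -> returned unchanged (already a valid cron string)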
-------------------------------------------------------------------------------- /aiokeydb/v1/asyncio/retry.py: -------------------------------------------------------------------------------- 1 | from asyncio import sleep 2 | from typing import TYPE_CHECKING, Any, Awaitable, Callable, Tuple, Type, TypeVar 3 | 4 | from aiokeydb.v1.exceptions import ConnectionError, KeyDBError, TimeoutError 5 | 6 | if TYPE_CHECKING: 7 | from aiokeydb.v1.backoff import AbstractBackoff 8 | 9 | 10 | T = TypeVar("T") 11 | 12 | 13 | class Retry: 14 | """Retry a specific number of times after a failure""" 15 | 16 | __slots__ = "_backoff", "_retries", "_supported_errors" 17 | 18 | def __init__( 19 | self, 20 | backoff: "AbstractBackoff", 21 | retries: int, 22 | supported_errors: Tuple[Type[KeyDBError], ...] = ( 23 | ConnectionError, 24 | TimeoutError, 25 | ), 26 | ): 27 | """ 28 | Initialize a `Retry` object with a `Backoff` object 29 | that retries a maximum of `retries` times. 30 | `retries` can be negative to retry forever. 31 | You can specify the types of supported errors which trigger 32 | a retry with the `supported_errors` parameter. 33 | """ 34 | self._backoff = backoff 35 | self._retries = retries 36 | self._supported_errors = supported_errors 37 | 38 | def update_supported_errors(self, specified_errors: list): 39 | """ 40 | Updates the supported errors with the specified error types 41 | """ 42 | self._supported_errors = tuple( 43 | set(self._supported_errors + tuple(specified_errors)) 44 | ) 45 | 46 | async def call_with_retry( 47 | self, do: Callable[[], Awaitable[T]], fail: Callable[[KeyDBError], Any] 48 | ) -> T: 49 | """ 50 | Execute an operation that might fail and returns its result, or 51 | raise the exception that was thrown depending on the `Backoff` object. 52 | `do`: the operation to call. Expects no argument. 
53 | `fail`: the failure handler, expects the last error that was thrown 54 | """ 55 | self._backoff.reset() 56 | failures = 0 57 | while True: 58 | try: 59 | return await do() 60 | except self._supported_errors as error: 61 | failures += 1 62 | await fail(error) 63 | if self._retries >= 0 and failures > self._retries: 64 | raise error 65 | backoff = self._backoff.compute(failures) 66 | if backoff > 0: 67 | await sleep(backoff) 68 | -------------------------------------------------------------------------------- /aiokeydb/v1/commands/search/result.py: -------------------------------------------------------------------------------- 1 | from aiokeydb.v1.commands.search._util import to_string 2 | from aiokeydb.v1.commands.search.document import Document 3 | 4 | 5 | class Result: 6 | """ 7 | Represents the result of a search query, and has an array of Document 8 | objects 9 | """ 10 | 11 | def __init__( 12 | self, res, hascontent, duration=0, has_payload=False, with_scores=False 13 | ): 14 | """ 15 | - **snippets**: An optional dictionary of the form 16 | {field: snippet_size} for snippet formatting 17 | """ 18 | 19 | self.total = res[0] 20 | self.duration = duration 21 | self.docs = [] 22 | 23 | step = 1 24 | if hascontent: 25 | step = step + 1 26 | if has_payload: 27 | step = step + 1 28 | if with_scores: 29 | step = step + 1 30 | 31 | offset = 2 if with_scores else 1 32 | 33 | for i in range(1, len(res), step): 34 | id = to_string(res[i]) 35 | payload = to_string(res[i + offset]) if has_payload else None 36 | # fields_offset = 2 if has_payload else 1 37 | fields_offset = offset + 1 if has_payload else offset 38 | score = float(res[i + 1]) if with_scores else None 39 | 40 | fields = {} 41 | if hascontent and res[i + fields_offset] is not None: 42 | fields = ( 43 | dict( 44 | dict( 45 | zip( 46 | map(to_string, res[i + fields_offset][::2]), 47 | map(to_string, res[i + fields_offset][1::2]), 48 | ) 49 | ) 50 | ) 51 | if hascontent 52 | else {} 53 | ) 54 | try: 55 | del fields["id"] 56 | except KeyError: 57 | pass 58 | 59 | try: 60 | fields["json"] = fields["$"] 61 | del fields["$"] 62 | except KeyError: 63 | pass 64 | 65 | doc = ( 66 | Document(id, score=score, payload=payload, **fields) 67 | if with_scores 68 | else Document(id, payload=payload, **fields) 69 | ) 70 | self.docs.append(doc) 71 | 72 | def __repr__(self): 73 | return f"Result{{{self.total} total, docs: {self.docs}}}" 74 | -------------------------------------------------------------------------------- /aiokeydb/v1/commands/graph/node.py: -------------------------------------------------------------------------------- 1 | from aiokeydb.v1.commands.helpers import quote_string 2 | 3 | 4 | class Node: 5 | """ 6 | A node within the graph. 7 | """ 8 | 9 | def __init__(self, node_id=None, alias=None, label=None, properties=None): 10 | """ 11 | Create a new node. 
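For example (illustrative, editorial addition): Node(alias="a", label="Person", properties={"name": "Ada"}) creates a node that __str__ renders roughly as (a:Person{name:"Ada"}).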
12 | """ 13 | self.id = node_id 14 | self.alias = alias 15 | if isinstance(label, list): 16 | label = [inner_label for inner_label in label if inner_label != ""] 17 | 18 | if ( 19 | label is None 20 | or label == "" 21 | or (isinstance(label, list) and len(label) == 0) 22 | ): 23 | self.label = None 24 | self.labels = None 25 | elif isinstance(label, str): 26 | self.label = label 27 | self.labels = [label] 28 | elif isinstance(label, list) and all( 29 | [isinstance(inner_label, str) for inner_label in label] 30 | ): 31 | self.label = label[0] 32 | self.labels = label 33 | else: 34 | raise AssertionError( 35 | "label should be either None, " "string or a list of strings" 36 | ) 37 | 38 | self.properties = properties or {} 39 | 40 | def to_string(self): 41 | res = "" 42 | if self.properties: 43 | props = ",".join( 44 | key + ":" + str(quote_string(val)) 45 | for key, val in sorted(self.properties.items()) 46 | ) 47 | res += "{" + props + "}" 48 | 49 | return res 50 | 51 | def __str__(self): 52 | res = "(" 53 | if self.alias: 54 | res += self.alias 55 | if self.labels: 56 | res += ":" + ":".join(self.labels) 57 | if self.properties: 58 | props = ",".join( 59 | key + ":" + str(quote_string(val)) 60 | for key, val in sorted(self.properties.items()) 61 | ) 62 | res += "{" + props + "}" 63 | res += ")" 64 | 65 | return res 66 | 67 | def __eq__(self, rhs): 68 | # Quick positive check, if both IDs are set. 69 | if self.id is not None and rhs.id is not None and self.id != rhs.id: 70 | return False 71 | 72 | # Label should match. 73 | if self.label != rhs.label: 74 | return False 75 | 76 | # Quick check for number of properties. 77 | if len(self.properties) != len(rhs.properties): 78 | return False 79 | 80 | # Compare properties. 81 | if self.properties != rhs.properties: 82 | return False 83 | 84 | return True 85 | -------------------------------------------------------------------------------- /aiokeydb/v1/commands/graph/edge.py: -------------------------------------------------------------------------------- 1 | from aiokeydb.v1.commands.helpers import quote_string 2 | from aiokeydb.v1.commands.graph.node import Node 3 | 4 | 5 | class Edge: 6 | """ 7 | An edge connecting two nodes. 8 | """ 9 | 10 | def __init__(self, src_node, relation, dest_node, edge_id=None, properties=None): 11 | """ 12 | Create a new edge. 13 | """ 14 | if src_node is None or dest_node is None: 15 | # NOTE(bors-42): It makes sense to change AssertionError to 16 | # ValueError here 17 | raise AssertionError("Both src_node & dest_node must be provided") 18 | 19 | self.id = edge_id 20 | self.relation = relation or "" 21 | self.properties = properties or {} 22 | self.src_node = src_node 23 | self.dest_node = dest_node 24 | 25 | def to_string(self): 26 | res = "" 27 | if self.properties: 28 | props = ",".join( 29 | key + ":" + str(quote_string(val)) 30 | for key, val in sorted(self.properties.items()) 31 | ) 32 | res += "{" + props + "}" 33 | 34 | return res 35 | 36 | def __str__(self): 37 | # Source node. 38 | if isinstance(self.src_node, Node): 39 | res = str(self.src_node) 40 | else: 41 | res = "()" 42 | 43 | # Edge 44 | res += "-[" 45 | if self.relation: 46 | res += ":" + self.relation 47 | if self.properties: 48 | props = ",".join( 49 | key + ":" + str(quote_string(val)) 50 | for key, val in sorted(self.properties.items()) 51 | ) 52 | res += "{" + props + "}" 53 | res += "]->" 54 | 55 | # Dest node. 
56 | if isinstance(self.dest_node, Node): 57 | res += str(self.dest_node) 58 | else: 59 | res += "()" 60 | 61 | return res 62 | 63 | def __eq__(self, rhs): 64 | # Quick positive check, if both IDs are set. 65 | if self.id is not None and rhs.id is not None and self.id == rhs.id: 66 | return True 67 | 68 | # Source and destination nodes should match. 69 | if self.src_node != rhs.src_node: 70 | return False 71 | 72 | if self.dest_node != rhs.dest_node: 73 | return False 74 | 75 | # Relation should match. 76 | if self.relation != rhs.relation: 77 | return False 78 | 79 | # Quick check for number of properties. 80 | if len(self.properties) != len(rhs.properties): 81 | return False 82 | 83 | # Compare properties. 84 | if self.properties != rhs.properties: 85 | return False 86 | 87 | return True 88 | -------------------------------------------------------------------------------- /aiokeydb/v1/commands/search/indexDefinition.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class IndexType(Enum): 5 | """Enum of the currently supported index types.""" 6 | 7 | HASH = 1 8 | JSON = 2 9 | 10 | 11 | class IndexDefinition: 12 | """IndexDefinition is used to define a index definition for automatic 13 | indexing on Hash or Json update.""" 14 | 15 | def __init__( 16 | self, 17 | prefix=[], 18 | filter=None, 19 | language_field=None, 20 | language=None, 21 | score_field=None, 22 | score=1.0, 23 | payload_field=None, 24 | index_type=None, 25 | ): 26 | self.args = [] 27 | self._append_index_type(index_type) 28 | self._append_prefix(prefix) 29 | self._append_filter(filter) 30 | self._append_language(language_field, language) 31 | self._append_score(score_field, score) 32 | self._append_payload(payload_field) 33 | 34 | def _append_index_type(self, index_type): 35 | """Append `ON HASH` or `ON JSON` according to the enum.""" 36 | if index_type is IndexType.HASH: 37 | self.args.extend(["ON", "HASH"]) 38 | elif index_type is IndexType.JSON: 39 | self.args.extend(["ON", "JSON"]) 40 | elif index_type is not None: 41 | raise RuntimeError(f"index_type must be one of {list(IndexType)}") 42 | 43 | def _append_prefix(self, prefix): 44 | """Append PREFIX.""" 45 | if len(prefix) > 0: 46 | self.args.append("PREFIX") 47 | self.args.append(len(prefix)) 48 | for p in prefix: 49 | self.args.append(p) 50 | 51 | def _append_filter(self, filter): 52 | """Append FILTER.""" 53 | if filter is not None: 54 | self.args.append("FILTER") 55 | self.args.append(filter) 56 | 57 | def _append_language(self, language_field, language): 58 | """Append LANGUAGE_FIELD and LANGUAGE.""" 59 | if language_field is not None: 60 | self.args.append("LANGUAGE_FIELD") 61 | self.args.append(language_field) 62 | if language is not None: 63 | self.args.append("LANGUAGE") 64 | self.args.append(language) 65 | 66 | def _append_score(self, score_field, score): 67 | """Append SCORE_FIELD and SCORE.""" 68 | if score_field is not None: 69 | self.args.append("SCORE_FIELD") 70 | self.args.append(score_field) 71 | if score is not None: 72 | self.args.append("SCORE") 73 | self.args.append(score) 74 | 75 | def _append_payload(self, payload_field): 76 | """Append PAYLOAD_FIELD.""" 77 | if payload_field is not None: 78 | self.args.append("PAYLOAD_FIELD") 79 | self.args.append(payload_field) 80 | -------------------------------------------------------------------------------- /aiokeydb/v1/commands/bf/info.py: -------------------------------------------------------------------------------- 1 | from 
aiokeydb.v1.commands.helpers import nativestr 2 | 3 | 4 | class BFInfo(object): 5 | capacity = None 6 | size = None 7 | filterNum = None 8 | insertedNum = None 9 | expansionRate = None 10 | 11 | def __init__(self, args): 12 | response = dict(zip(map(nativestr, args[::2]), args[1::2])) 13 | self.capacity = response["Capacity"] 14 | self.size = response["Size"] 15 | self.filterNum = response["Number of filters"] 16 | self.insertedNum = response["Number of items inserted"] 17 | self.expansionRate = response["Expansion rate"] 18 | 19 | 20 | class CFInfo(object): 21 | size = None 22 | bucketNum = None 23 | filterNum = None 24 | insertedNum = None 25 | deletedNum = None 26 | bucketSize = None 27 | expansionRate = None 28 | maxIteration = None 29 | 30 | def __init__(self, args): 31 | response = dict(zip(map(nativestr, args[::2]), args[1::2])) 32 | self.size = response["Size"] 33 | self.bucketNum = response["Number of buckets"] 34 | self.filterNum = response["Number of filters"] 35 | self.insertedNum = response["Number of items inserted"] 36 | self.deletedNum = response["Number of items deleted"] 37 | self.bucketSize = response["Bucket size"] 38 | self.expansionRate = response["Expansion rate"] 39 | self.maxIteration = response["Max iterations"] 40 | 41 | 42 | class CMSInfo(object): 43 | width = None 44 | depth = None 45 | count = None 46 | 47 | def __init__(self, args): 48 | response = dict(zip(map(nativestr, args[::2]), args[1::2])) 49 | self.width = response["width"] 50 | self.depth = response["depth"] 51 | self.count = response["count"] 52 | 53 | 54 | class TopKInfo(object): 55 | k = None 56 | width = None 57 | depth = None 58 | decay = None 59 | 60 | def __init__(self, args): 61 | response = dict(zip(map(nativestr, args[::2]), args[1::2])) 62 | self.k = response["k"] 63 | self.width = response["width"] 64 | self.depth = response["depth"] 65 | self.decay = response["decay"] 66 | 67 | 68 | class TDigestInfo(object): 69 | compression = None 70 | capacity = None 71 | mergedNodes = None 72 | unmergedNodes = None 73 | mergedWeight = None 74 | unmergedWeight = None 75 | totalCompressions = None 76 | 77 | def __init__(self, args): 78 | response = dict(zip(map(nativestr, args[::2]), args[1::2])) 79 | self.compression = response["Compression"] 80 | self.capacity = response["Capacity"] 81 | self.mergedNodes = response["Merged nodes"] 82 | self.unmergedNodes = response["Unmerged nodes"] 83 | self.mergedWeight = response["Merged weight"] 84 | self.unmergedWeight = response["Unmerged weight"] 85 | self.totalCompressions = response["Total compressions"] 86 | -------------------------------------------------------------------------------- /aiokeydb/v1/commands/redismodules.py: -------------------------------------------------------------------------------- 1 | from json import JSONDecoder, JSONEncoder 2 | 3 | 4 | class RedisModuleCommands: 5 | """This class contains the wrapper functions to bring supported redis 6 | modules into the command namespace. 
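For example (illustrative, assuming the same call signatures as upstream redis-py modules): client.json().set("doc", "$", {"hello": "world"}) writes a JSON document, and client.ft("idx").search("hello") runs a search query against the "idx" index.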
7 | """ 8 | 9 | def json(self, encoder=JSONEncoder(), decoder=JSONDecoder()): 10 | """Access the json namespace, providing support for redis json.""" 11 | 12 | from .json import JSON 13 | 14 | jj = JSON(client=self, encoder=encoder, decoder=decoder) 15 | return jj 16 | 17 | def ft(self, index_name="idx"): 18 | """Access the search namespace, providing support for redis search.""" 19 | 20 | from .search import Search 21 | 22 | s = Search(client=self, index_name=index_name) 23 | return s 24 | 25 | def ts(self): 26 | """Access the timeseries namespace, providing support for 27 | redis timeseries data. 28 | """ 29 | 30 | from .timeseries import TimeSeries 31 | 32 | s = TimeSeries(client=self) 33 | return s 34 | 35 | def bf(self): 36 | """Access the bloom namespace.""" 37 | 38 | from .bf import BFBloom 39 | 40 | bf = BFBloom(client=self) 41 | return bf 42 | 43 | def cf(self): 44 | """Access the bloom namespace.""" 45 | 46 | from .bf import CFBloom 47 | 48 | cf = CFBloom(client=self) 49 | return cf 50 | 51 | def cms(self): 52 | """Access the bloom namespace.""" 53 | 54 | from .bf import CMSBloom 55 | 56 | cms = CMSBloom(client=self) 57 | return cms 58 | 59 | def topk(self): 60 | """Access the bloom namespace.""" 61 | 62 | from .bf import TOPKBloom 63 | 64 | topk = TOPKBloom(client=self) 65 | return topk 66 | 67 | def tdigest(self): 68 | """Access the bloom namespace.""" 69 | 70 | from .bf import TDigestBloom 71 | 72 | tdigest = TDigestBloom(client=self) 73 | return tdigest 74 | 75 | def graph(self, index_name="idx"): 76 | """Access the graph namespace, providing support for 77 | redis graph data. 78 | """ 79 | 80 | from .graph import Graph 81 | 82 | g = Graph(client=self, name=index_name) 83 | return g 84 | 85 | 86 | class AsyncRedisModuleCommands(RedisModuleCommands): 87 | def ft(self, index_name="idx"): 88 | """Access the search namespace, providing support for redis search.""" 89 | 90 | from .search import AsyncSearch 91 | 92 | s = AsyncSearch(client=self, index_name=index_name) 93 | return s 94 | 95 | def graph(self, index_name="idx"): 96 | """Access the graph namespace, providing support for 97 | redis graph data. 98 | """ 99 | 100 | from .graph import AsyncGraph 101 | 102 | g = AsyncGraph(client=self, name=index_name) 103 | return g 104 | -------------------------------------------------------------------------------- /aiokeydb/v2/utils/logs.py: -------------------------------------------------------------------------------- 1 | from lazyops.utils.logs import get_logger, STATUS_COLOR, COLORED_MESSAGE_MAP, FALLBACK_STATUS_COLOR 2 | 3 | 4 | class ColorMap: 5 | green: str = '\033[0;32m' 6 | red: str = '\033[0;31m' 7 | yellow: str = '\033[0;33m' 8 | blue: str = '\033[0;34m' 9 | magenta: str = '\033[0;35m' 10 | cyan: str = '\033[0;36m' 11 | white: str = '\033[0;37m' 12 | bold: str = '\033[1m' 13 | reset: str = '\033[0m' 14 | 15 | 16 | 17 | 18 | class CustomizeLogger: 19 | 20 | @staticmethod 21 | def logger_formatter(record: dict) -> str: 22 | """ 23 | To add a custom format for a module, add another `elif` clause with code to determine `extra` and `level`. 24 | 25 | From that module and all submodules, call logger with `logger.bind(foo='bar').info(msg)`. 26 | Then you can access it with `record['extra'].get('foo')`. 
27 | """ 28 | extra = '{name}:{function}: ' 29 | 30 | if record.get('extra'): 31 | if record['extra'].get('request_id'): 32 | extra = '{name}:{function}:request_id: {extra[request_id]} ' 33 | 34 | elif record['extra'].get('job_id') and record['extra'].get('queue_name') and record['extra'].get('kind'): 35 | status = record['extra'].get('status') 36 | color = STATUS_COLOR.get(status, FALLBACK_STATUS_COLOR) 37 | kind_color = STATUS_COLOR.get(record.get('extra', {}).get('kind'), FALLBACK_STATUS_COLOR) 38 | if not record['extra'].get('worker_name'): 39 | record['extra']['worker_name'] = '' 40 | extra = '{extra[queue_name]}:{extra[worker_name]}:<' + kind_color + '>{extra[kind]:<9} <' + color + '>{extra[job_id]} ' 41 | # if record['extra'].get('cid'): 42 | # extra += '({extra[cid]}/{extra[conn]}) ' 43 | 44 | elif record['extra'].get('kind') and record['extra'].get('queue_name'): 45 | if not record['extra'].get('worker_name'): 46 | record['extra']['worker_name'] = '' 47 | kind_color = STATUS_COLOR.get(record.get('extra', {}).get('kind'), FALLBACK_STATUS_COLOR) 48 | extra = '{extra[queue_name]}:{extra[worker_name]}:<' + kind_color + '>{extra[kind]:<9} ' 49 | # if record['extra'].get('cid'): 50 | # extra += '({extra[cid]}/{extra[conn]}) ' 51 | 52 | 53 | if 'result=tensor([' not in str(record['message']): 54 | return "{level: <8} {time:YYYY-MM-DD HH:mm:ss.SSS}: "\ 55 | + extra + "{message}\n" 56 | msg = str(record['message'])[:100].replace('{', '(').replace('}', ')') 57 | return "{level: <8} {time:YYYY-MM-DD HH:mm:ss.SSS}: "\ 58 | + extra + "" + msg + f"{STATUS_COLOR['reset']}\n" 59 | 60 | 61 | logger = get_logger( 62 | format = CustomizeLogger.logger_formatter, 63 | ) 64 | -------------------------------------------------------------------------------- /tests/test_v2_cachify.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | import asyncio 4 | import uuid 5 | import multiprocessing 6 | from aiokeydb.v2.client import KeyDBClient 7 | from aiokeydb.v2.queues import Worker, TaskQueue 8 | from aiokeydb.v2.client import logger 9 | 10 | @KeyDBClient.cachify(cache_ttl=10) 11 | async def a_cached_func(number: int): 12 | return number + 1 13 | 14 | @KeyDBClient.worker.add_fallback_function() 15 | async def w_a_cached_func(ctx, number: int): 16 | 17 | @KeyDBClient.cachify(cache_ttl=10) 18 | async def inner(n: int): 19 | return n + 1 20 | 21 | return await inner(number) 22 | 23 | @KeyDBClient.cachify(cache_ttl=10) 24 | def cached_func(number: int): 25 | return number + 1 26 | 27 | @KeyDBClient.worker.add_fallback_function() 28 | def w_cached_func(ctx, number: int): 29 | 30 | @KeyDBClient.cachify(cache_ttl=10) 31 | def inner(n: int): 32 | return n + 1 33 | 34 | return inner(number) 35 | 36 | 37 | async def start_worker(): 38 | tq = TaskQueue("test_queue") 39 | KeyDBClient.worker.set_queue_func(tq) 40 | worker = Worker(tq) 41 | await worker.start() 42 | 43 | async def test_cached_func(n: int = 5, runs: int = 10): 44 | # Test that both results are the same. 
45 | sync_t, async_t = 0.0, 0.0 46 | 47 | for i in range(runs): 48 | t = time.time() 49 | print(f'[Async - {i}/{runs}] Result: {await a_cached_func(n)}') 50 | tt = time.time() - t 51 | print(f'[Async - {i}/{runs}] Time: {tt:.2f}s') 52 | async_t += tt 53 | print(f'[Async] Cache Average Time: {async_t / runs:.2f}s | Total Time: {async_t:.2f}s') 54 | 55 | for i in range(runs): 56 | t = time.time() 57 | print(f'[Sync - {i}/{runs}] Result: {cached_func(n)}') 58 | tt = time.time() - t 59 | print(f'[Sync - {i}/{runs}] Time: {tt:.2f}s') 60 | sync_t += tt 61 | print(f'[Sync] Cache Average Time: {sync_t / runs:.2f}s | Total Time: {sync_t:.2f}s') 62 | 63 | async def test_worker_func(n: int = 5, runs: int = 10): 64 | # test inner caching 65 | sync_t, async_t = 0.0, 0.0 66 | 67 | for i in range(runs): 68 | t = time.time() 69 | print(f'[Async Inner - {i}/{runs}] Result: {await w_a_cached_func(number = n)}') 70 | tt = time.time() - t 71 | print(f'[Async Inner - {i}/{runs}] Time: {tt:.2f}s') 72 | async_t += tt 73 | print(f'[Async Inner] Cache Average Time: {async_t / runs:.2f}s | Total Time: {async_t:.2f}s') 74 | 75 | for i in range(runs): 76 | t = time.time() 77 | print(f'[Sync Inner - {i}/{runs}] Result: {await w_cached_func(number = n)}') 78 | tt = time.time() - t 79 | print(f'[Sync Inner - {i}/{runs}] Time: {tt:.2f}s') 80 | sync_t += tt 81 | print(f'[Sync Inner] Cache Average Time: {sync_t / runs:.2f}s | Total Time: {sync_t:.2f}s') 82 | # task.join() 83 | sys.exit(0) 84 | 85 | # await worker.stop() 86 | # await asyncio.gather(task) 87 | # task.cancel() 88 | 89 | async def run_tests(): 90 | await test_cached_func() 91 | # proc = 92 | 93 | 94 | asyncio.run(test_cached_func()) 95 | 96 | # async def run_tests(fib_n: int = 15, fib_runs: int = 10, setget_runs: int = 10): 97 | -------------------------------------------------------------------------------- /aiokeydb/v1/backoff.py: -------------------------------------------------------------------------------- 1 | import random 2 | from abc import ABC, abstractmethod 3 | 4 | 5 | class AbstractBackoff(ABC): 6 | """Backoff interface""" 7 | 8 | def reset(self): 9 | """ 10 | Reset internal state before an operation. 
11 | `reset` is called once at the beginning of 12 | every call to `Retry.call_with_retry` 13 | """ 14 | pass 15 | 16 | @abstractmethod 17 | def compute(self, failures): 18 | """Compute backoff in seconds upon failure""" 19 | pass 20 | 21 | 22 | class ConstantBackoff(AbstractBackoff): 23 | """Constant backoff upon failure""" 24 | 25 | def __init__(self, backoff): 26 | """`backoff`: backoff time in seconds""" 27 | self._backoff = backoff 28 | 29 | def compute(self, failures): 30 | return self._backoff 31 | 32 | 33 | class NoBackoff(ConstantBackoff): 34 | """No backoff upon failure""" 35 | 36 | def __init__(self): 37 | super().__init__(0) 38 | 39 | 40 | class ExponentialBackoff(AbstractBackoff): 41 | """Exponential backoff upon failure""" 42 | 43 | def __init__(self, cap, base): 44 | """ 45 | `cap`: maximum backoff time in seconds 46 | `base`: base backoff time in seconds 47 | """ 48 | self._cap = cap 49 | self._base = base 50 | 51 | def compute(self, failures): 52 | return min(self._cap, self._base * 2**failures) 53 | 54 | 55 | class FullJitterBackoff(AbstractBackoff): 56 | """Full jitter backoff upon failure""" 57 | 58 | def __init__(self, cap, base): 59 | """ 60 | `cap`: maximum backoff time in seconds 61 | `base`: base backoff time in seconds 62 | """ 63 | self._cap = cap 64 | self._base = base 65 | 66 | def compute(self, failures): 67 | return random.uniform(0, min(self._cap, self._base * 2**failures)) 68 | 69 | 70 | class EqualJitterBackoff(AbstractBackoff): 71 | """Equal jitter backoff upon failure""" 72 | 73 | def __init__(self, cap, base): 74 | """ 75 | `cap`: maximum backoff time in seconds 76 | `base`: base backoff time in seconds 77 | """ 78 | self._cap = cap 79 | self._base = base 80 | 81 | def compute(self, failures): 82 | temp = min(self._cap, self._base * 2**failures) / 2 83 | return temp + random.uniform(0, temp) 84 | 85 | 86 | class DecorrelatedJitterBackoff(AbstractBackoff): 87 | """Decorrelated jitter backoff upon failure""" 88 | 89 | def __init__(self, cap, base): 90 | """ 91 | `cap`: maximum backoff time in seconds 92 | `base`: base backoff time in seconds 93 | """ 94 | self._cap = cap 95 | self._base = base 96 | self._previous_backoff = 0 97 | 98 | def reset(self): 99 | self._previous_backoff = 0 100 | 101 | def compute(self, failures): 102 | max_backoff = max(self._base, self._previous_backoff * 3) 103 | temp = random.uniform(self._base, max_backoff) 104 | self._previous_backoff = min(self._cap, temp) 105 | return self._previous_backoff 106 | -------------------------------------------------------------------------------- /tests/test_v3_cachify.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | import asyncio 4 | import uuid 5 | import multiprocessing 6 | from aiokeydb.client import KeyDBClient 7 | from aiokeydb.queues import Worker, TaskQueue 8 | from aiokeydb.client import logger 9 | 10 | session = KeyDBClient.init_session() 11 | 12 | @session.cachify_v2(ttl=10) 13 | async def a_cached_func(number: int): 14 | return number + 1 15 | 16 | @KeyDBClient.worker.add_fallback_function() 17 | async def w_a_cached_func(ctx, number: int): 18 | 19 | @session.cachify_v2(ttl=10) 20 | async def inner(n: int): 21 | return n + 1 22 | 23 | return await inner(number) 24 | 25 | @session.cachify_v2(cache_ttl=10) 26 | def cached_func(number: int): 27 | return number + 1 28 | 29 | @KeyDBClient.worker.add_fallback_function() 30 | def w_cached_func(ctx, number: int): 31 | 32 | @session.cachify_v2(cache_ttl=10) 33 | def 
inner(n: int): 34 | return n + 1 35 | 36 | return inner(number) 37 | 38 | 39 | async def start_worker(): 40 | tq = TaskQueue("test_queue") 41 | KeyDBClient.worker.set_queue_func(tq) 42 | worker = Worker(tq) 43 | await worker.start() 44 | 45 | async def test_cached_func(n: int = 5, runs: int = 10): 46 | # Test that both results are the same. 47 | sync_t, async_t = 0.0, 0.0 48 | 49 | for i in range(runs): 50 | t = time.time() 51 | print(f'[Async - {i}/{runs}] Result: {await a_cached_func(n)}') 52 | tt = time.time() - t 53 | print(f'[Async - {i}/{runs}] Time: {tt:.2f}s') 54 | async_t += tt 55 | print(f'[Async] Cache Average Time: {async_t / runs:.2f}s | Total Time: {async_t:.2f}s') 56 | 57 | for i in range(runs): 58 | t = time.time() 59 | print(f'[Sync - {i}/{runs}] Result: {cached_func(n)}') 60 | tt = time.time() - t 61 | print(f'[Sync - {i}/{runs}] Time: {tt:.2f}s') 62 | sync_t += tt 63 | print(f'[Sync] Cache Average Time: {sync_t / runs:.2f}s | Total Time: {sync_t:.2f}s') 64 | 65 | async def test_worker_func(n: int = 5, runs: int = 10): 66 | # test inner caching 67 | sync_t, async_t = 0.0, 0.0 68 | 69 | for i in range(runs): 70 | t = time.time() 71 | print(f'[Async Inner - {i}/{runs}] Result: {await w_a_cached_func(number = n)}') 72 | tt = time.time() - t 73 | print(f'[Async Inner - {i}/{runs}] Time: {tt:.2f}s') 74 | async_t += tt 75 | print(f'[Async Inner] Cache Average Time: {async_t / runs:.2f}s | Total Time: {async_t:.2f}s') 76 | 77 | for i in range(runs): 78 | t = time.time() 79 | print(f'[Sync Inner - {i}/{runs}] Result: {await w_cached_func(number = n)}') 80 | tt = time.time() - t 81 | print(f'[Sync Inner - {i}/{runs}] Time: {tt:.2f}s') 82 | sync_t += tt 83 | print(f'[Sync Inner] Cache Average Time: {sync_t / runs:.2f}s | Total Time: {sync_t:.2f}s') 84 | # task.join() 85 | sys.exit(0) 86 | 87 | # await worker.stop() 88 | # await asyncio.gather(task) 89 | # task.cancel() 90 | 91 | async def run_tests(): 92 | await test_cached_func() 93 | # proc = 94 | 95 | 96 | asyncio.run(test_cached_func()) 97 | 98 | # async def run_tests(fib_n: int = 15, fib_runs: int = 10, setget_runs: int = 10): 99 | -------------------------------------------------------------------------------- /CHANGELOGS.md: -------------------------------------------------------------------------------- 1 | # aiokeydb changelogs 2 | 3 | ## 0.2.0rc0 (2023-10-16) 4 | 5 | **Breaking Changes Inbound** 6 | 7 | - Rework entire module, deprecating previous implementation into `v1`, maintaining legacy `v2` namespace, while transforming the previous `v2` into the primary module. The final release version of `0.2.0` will not contain any modules, but rather reference the `v2` namespace as the primary module. 8 | 9 | - Rework `pydantic` dependencies to support both `v1/v2`. 10 | 11 | ### rc1 12 | 13 | - Rework `cachify` module 14 | 15 | 16 | 17 | ## 0.1.31 (2023-03-29) 18 | - Add new `add_fallback_function` method for Worker, which allows for a fallback function to be called when the worker fails to execute the task. (v2) 19 | 20 | 21 | ## 0.1.30 (2023-03-28) 22 | - Start of migration of library to maintain upstream compatibility 23 | with `redis-py`. 24 | - Usable in `aiokeydb.v2` namespace. Will complete full migration by v0.1.50 25 | - Attempts to maintain backwards compatibility with `aiokeydb` v0.1.19 26 | - All classes are now subclassed from `redis` and `redis.asyncio` rather than being explictly defined. 
27 | - Worker tasks functions are now callable via `KeyDBClient` and `KeyDBWorkerSettings` 28 | - Rework `ConnectionPool` and `AsyncConnectionPool` to handle resetting of the connection pool when the maximum number of connections are reached. 29 | 30 | 31 | ## 0.1.19 (2023-03-08) 32 | - Add utility to set ulimits when initializing the connection pool. 33 | 34 | ## 0.1.18 (2023-03-08) 35 | - Resolve ConnectionPools with reset capabilities 36 | - Refactor `KeyDBSession` to utilize the ConnectionPool initialized by `KeyDBClient` 37 | - Refactor `KeyDBClient` to initialize Sessions using shared connection pools for async and sync in order to avoid spawning a new connection pool per session. 38 | - Moved certain class vars to its own state for `KeyDBSesssion` 39 | - Reorder Worker Queue initialization to prevent overlapping event loops 40 | - Implement certain changes from `redis-py` 41 | - kept previous `KeyDBClient` that is accessible via `aiokeydb.client.core` vs `aiokeydb.client.meta` 42 | 43 | ## 0.1.7 (2023-02-01) 44 | - Resolve worker issues for startup ctx 45 | 46 | ## 0.1.4 (2022-12-22) 47 | - hotfix for locks. 48 | 49 | ## 0.1.3 (2022-12-21) 50 | - add `_lazy_init` param to `KeyDBClient.cachify` to postpone session initialization if the session has not already been configured 51 | - add `_cache_invalidator` param to `KeyDBClient.cachify` to allow for custom cache invalidation logic. If the result is True, the key will be deleted first prior to fetching from the cache. 52 | - add `debug_enabled` param to `KeyDBSettings` to enable debug logging for the `KeyDBClient` session. 53 | 54 | 55 | ## 0.1.2 (2022-12-20) 56 | - add `overwrite` option for `KeyDBClient.configure` to overwrite the default session. 57 | 58 | ## 0.0.12 (2022-12-08) 59 | 60 | - Migration of `aiokeydb.client` -> `aiokeydb.core` and `aiokeydb.asyncio.client` -> `aiokeydb.asyncio.core` 61 | - Unified API available through new `KeyDBClient` class that creates `sessions` which are `KeyDBSession` inherits from `KeyDB` and `AsyncKeyDBClient` class that inherits from `AsyncKeyDB` class 62 | - Implemented `.cachify` method that allows for a caching decorator to be created for a function 63 | -------------------------------------------------------------------------------- /aiokeydb/v1/commands/timeseries/info.py: -------------------------------------------------------------------------------- 1 | from aiokeydb.v1.commands.helpers import nativestr 2 | from aiokeydb.v1.commands.timeseries.utils import list_to_dict 3 | 4 | 5 | class TSInfo: 6 | """ 7 | Hold information and statistics on the time-series. 8 | Can be created using ``tsinfo`` command 9 | https://oss.redis.com/redistimeseries/commands/#tsinfo. 10 | """ 11 | 12 | rules = [] 13 | labels = [] 14 | sourceKey = None 15 | chunk_count = None 16 | memory_usage = None 17 | total_samples = None 18 | retention_msecs = None 19 | last_time_stamp = None 20 | first_time_stamp = None 21 | 22 | max_samples_per_chunk = None 23 | chunk_size = None 24 | duplicate_policy = None 25 | 26 | def __init__(self, args): 27 | """ 28 | Hold information and statistics on the time-series. 29 | 30 | The supported params that can be passed as args: 31 | 32 | rules: 33 | A list of compaction rules of the time series. 34 | sourceKey: 35 | Key name for source time series in case the current series 36 | is a target of a rule. 37 | chunkCount: 38 | Number of Memory Chunks used for the time series. 39 | memoryUsage: 40 | Total number of bytes allocated for the time series. 
41 | totalSamples: 42 | Total number of samples in the time series. 43 | labels: 44 | A list of label-value pairs that represent the metadata 45 | labels of the time series. 46 | retentionTime: 47 | Retention time, in milliseconds, for the time series. 48 | lastTimestamp: 49 | Last timestamp present in the time series. 50 | firstTimestamp: 51 | First timestamp present in the time series. 52 | maxSamplesPerChunk: 53 | Deprecated. 54 | chunkSize: 55 | Amount of memory, in bytes, allocated for data. 56 | duplicatePolicy: 57 | Policy that will define handling of duplicate samples. 58 | 59 | Can read more about on 60 | https://oss.redis.com/redistimeseries/configuration/#duplicate_policy 61 | """ 62 | response = dict(zip(map(nativestr, args[::2]), args[1::2])) 63 | self.rules = response.get("rules") 64 | self.source_key = response.get("sourceKey") 65 | self.chunk_count = response.get("chunkCount") 66 | self.memory_usage = response.get("memoryUsage") 67 | self.total_samples = response.get("totalSamples") 68 | self.labels = list_to_dict(response.get("labels")) 69 | self.retention_msecs = response.get("retentionTime") 70 | self.last_timestamp = response.get("lastTimestamp") 71 | self.first_timestamp = response.get("firstTimestamp") 72 | if "maxSamplesPerChunk" in response: 73 | self.max_samples_per_chunk = response["maxSamplesPerChunk"] 74 | self.chunk_size = ( 75 | self.max_samples_per_chunk * 16 76 | ) # backward compatible changes 77 | if "chunkSize" in response: 78 | self.chunk_size = response["chunkSize"] 79 | if "duplicatePolicy" in response: 80 | self.duplicate_policy = response["duplicatePolicy"] 81 | if type(self.duplicate_policy) == bytes: 82 | self.duplicate_policy = self.duplicate_policy.decode() 83 | -------------------------------------------------------------------------------- /aiokeydb/v2/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from aiokeydb.v2.core import KeyDB, AsyncKeyDB 4 | from aiokeydb.v2.cluster import KeyDBCluster, AsyncKeyDBCluster 5 | from aiokeydb.v2.connection import ( 6 | BlockingConnectionPool, 7 | Connection, 8 | ConnectionPool, 9 | SSLConnection, 10 | UnixDomainSocketConnection, 11 | AsyncBlockingConnectionPool, 12 | AsyncConnection, 13 | AsyncConnectionPool, 14 | AsyncSSLConnection, 15 | AsyncUnixDomainSocketConnection, 16 | ) 17 | from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider 18 | from aiokeydb.v2.exceptions import ( 19 | AuthenticationError, 20 | AuthenticationWrongNumberOfArgsError, 21 | BusyLoadingError, 22 | ChildDeadlockedError, 23 | ConnectionError, 24 | DataError, 25 | InvalidResponse, 26 | PubSubError, 27 | ReadOnlyError, 28 | ResponseError, 29 | TimeoutError, 30 | WatchError, 31 | JobError, 32 | ) 33 | from aiokeydb.v2.sentinel import ( 34 | Sentinel, 35 | SentinelConnectionPool, 36 | SentinelManagedConnection, 37 | AsyncSentinel, 38 | AsyncSentinelConnectionPool, 39 | AsyncSentinelManagedConnection, 40 | ) 41 | 42 | from aiokeydb.v2.utils import from_url 43 | 44 | # Handle Client 45 | from aiokeydb.v2.serializers import SerializerType 46 | from aiokeydb.v2.configs import KeyDBSettings, KeyDBWorkerSettings, settings 47 | from aiokeydb.v2.types.session import KeyDBSession 48 | from aiokeydb.v2.client import KeyDBClient 49 | 50 | # Handle Queues 51 | from aiokeydb.v2.types.jobs import Job, CronJob 52 | from aiokeydb.v2.types.task_queue import TaskQueue 53 | from aiokeydb.v2.types.worker import Worker 54 | 55 | # 
Add KeyDB Index Types 56 | from aiokeydb.v2.types.indexes import ( 57 | KDBIndex, 58 | KDBDict, 59 | AsyncKDBDict, 60 | ) 61 | 62 | # Job.update_forward_refs() 63 | 64 | __all__ = [ 65 | "AuthenticationError", 66 | "AuthenticationWrongNumberOfArgsError", 67 | "BlockingConnectionPool", 68 | "BusyLoadingError", 69 | "ChildDeadlockedError", 70 | "Connection", 71 | "ConnectionError", 72 | "ConnectionPool", 73 | "DataError", 74 | "from_url", 75 | "InvalidResponse", 76 | "PubSubError", 77 | "ReadOnlyError", 78 | "KeyDB", 79 | "KeyDBCluster", 80 | # "KeyDBError", 81 | "ResponseError", 82 | "Sentinel", 83 | "SentinelConnectionPool", 84 | "SentinelManagedConnection", 85 | "SentinelManagedSSLConnection", 86 | "SSLConnection", 87 | # "StrictKeyDB", 88 | "TimeoutError", 89 | "UnixDomainSocketConnection", 90 | "WatchError", 91 | "JobError", 92 | "CredentialProvider", 93 | "UsernamePasswordCredentialProvider", 94 | # Async 95 | "AsyncKeyDB", 96 | "AsyncKeyDBCluster", 97 | # "StrictAsyncKeyDB", 98 | "AsyncBlockingConnectionPool", 99 | "AsyncConnection", 100 | "AsyncConnectionPool", 101 | "AsyncSSLConnection", 102 | "AsyncUnixDomainSocketConnection", 103 | "AsyncSentinel", 104 | "AsyncSentinelConnectionPool", 105 | "AsyncSentinelManagedConnection", 106 | # "AsyncSentinelManagedSSLConnection", 107 | 108 | # Client 109 | "SerializerType", 110 | "KeyDBSettings", 111 | "KeyDBWorkerSettings", 112 | "KeyDBSession", 113 | "KeyDBClient", 114 | 115 | # Queues 116 | "TaskQueue", 117 | "Worker", 118 | ] 119 | -------------------------------------------------------------------------------- /aiokeydb/utils/logs.py: -------------------------------------------------------------------------------- 1 | from lazyops.utils.logs import Logger, get_logger 2 | 3 | # , STATUS_COLOR, COLORED_MESSAGE_MAP, FALLBACK_STATUS_COLOR 4 | 5 | 6 | class ColorMap: 7 | green: str = '\033[0;32m' 8 | red: str = '\033[0;31m' 9 | yellow: str = '\033[0;33m' 10 | blue: str = '\033[0;34m' 11 | magenta: str = '\033[0;35m' 12 | cyan: str = '\033[0;36m' 13 | white: str = '\033[0;37m' 14 | bold: str = '\033[1m' 15 | reset: str = '\033[0m' 16 | 17 | 18 | 19 | 20 | # class CustomizeLogger: 21 | 22 | # @classmethod 23 | # def worker_logger_formatter(cls, record: dict) -> str: 24 | # """ 25 | # Formats the log message for the worker. 
26 | # """ 27 | # extra = '' 28 | # if record['extra'].get('job_id') and record['extra'].get('queue_name') and record['extra'].get('kind'): 29 | # status = record['extra'].get('status') 30 | # color = STATUS_COLOR.get(status, FALLBACK_STATUS_COLOR) 31 | # kind_color = STATUS_COLOR.get(record.get('extra', {}).get('kind'), FALLBACK_STATUS_COLOR) 32 | # if not record['extra'].get('worker_name'): 33 | # record['extra']['worker_name'] = '' 34 | # extra = f'<{kind_color}>' + '{extra[kind]}:' 35 | # extra += '{extra[worker_name]}:{extra[queue_name]}:' 36 | # extra += '{extra[job_id]}' 37 | # if status: 38 | # extra += f':<{color}>' + '{extra[status]}' 39 | # extra += ': ' 40 | 41 | # elif record['extra'].get('kind') and record['extra'].get('queue_name'): 42 | # if not record['extra'].get('worker_name'): 43 | # record['extra']['worker_name'] = '' 44 | # kind_color = STATUS_COLOR.get(record.get('extra', {}).get('kind'), FALLBACK_STATUS_COLOR) 45 | # extra = f'<{kind_color}>' + '{extra[kind]}:' 46 | # extra += '{extra[worker_name]}:{extra[queue_name]:<18} ' 47 | 48 | # return extra 49 | 50 | 51 | # @classmethod 52 | # def logger_formatter(cls, record: dict) -> str: 53 | # """ 54 | # To add a custom format for a module, add another `elif` clause with code to determine `extra` and `level`. 55 | 56 | # From that module and all submodules, call logger with `logger.bind(foo='bar').info(msg)`. 57 | # Then you can access it with `record['extra'].get('foo')`. 58 | # """ 59 | # extra = '{name}:{function}: ' 60 | 61 | # if record.get('extra'): 62 | # if record['extra'].get('request_id'): 63 | # extra = '{name}:{function}:request_id: {extra[request_id]} ' 64 | 65 | # elif (record['extra'].get('queue_name') or record['extra'].get('worker_name')) and record['extra'].get('kind'): 66 | # extra = cls.worker_logger_formatter(record) 67 | 68 | # if 'result=tensor([' not in str(record['message']): 69 | # return "{level: <8} {time:YYYY-MM-DD HH:mm:ss.SSS}: "\ 70 | # + extra + "{message}\n" 71 | 72 | # msg = str(record['message'])[:100].replace('{', '(').replace('}', ')') 73 | # return "{level: <8} {time:YYYY-MM-DD HH:mm:ss.SSS}: "\ 74 | # + extra + "" + msg + f"{STATUS_COLOR['reset']}\n" 75 | 76 | 77 | logger: Logger = get_logger( 78 | __name__, 79 | # format = CustomizeLogger.logger_formatter, 80 | ) 81 | -------------------------------------------------------------------------------- /aiokeydb/v1/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from contextlib import contextmanager 3 | from typing import Any, Dict, Mapping, Union 4 | 5 | try: 6 | import hiredis # noqa 7 | 8 | HIREDIS_AVAILABLE = not hiredis.__version__.startswith("0.") 9 | HIREDIS_PACK_AVAILABLE = hasattr(hiredis, "pack_command") 10 | except ImportError: 11 | HIREDIS_AVAILABLE = False 12 | HIREDIS_PACK_AVAILABLE = False 13 | 14 | 15 | try: 16 | import cryptography # noqa 17 | 18 | CRYPTOGRAPHY_AVAILABLE = True 19 | except ImportError: 20 | CRYPTOGRAPHY_AVAILABLE = False 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | def from_url(url, asyncio: bool = False, **kwargs): 25 | """ 26 | Returns an active Redis client generated from the given database URL. 27 | 28 | Will attempt to extract the database id from the path url fragment, if 29 | none is provided. 
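    A minimal usage sketch (the URL below is illustrative, not a required value):

        sync_client = from_url("redis://localhost:6379/0")
        async_client = from_url("redis://localhost:6379/0", asyncio=True)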
30 | """ 31 | if asyncio: 32 | from aiokeydb.v1.asyncio.core import AsyncKeyDB 33 | return AsyncKeyDB.from_url(url, **kwargs) 34 | 35 | from aiokeydb.v1.core import KeyDB 36 | return KeyDB.from_url(url, **kwargs) 37 | 38 | 39 | @contextmanager 40 | def pipeline(keydb_obj): 41 | p = keydb_obj.pipeline() 42 | yield p 43 | p.execute() 44 | 45 | 46 | def str_if_bytes(value: Union[str, bytes]) -> str: 47 | return ( 48 | value.decode("utf-8", errors="replace") if isinstance(value, bytes) else value 49 | ) 50 | 51 | 52 | def safe_str(value): 53 | return str(str_if_bytes(value)) 54 | 55 | 56 | def dict_merge(*dicts: Mapping[str, Any]) -> Dict[str, Any]: 57 | """ 58 | Merge all provided dicts into 1 dict. 59 | *dicts : `dict` 60 | dictionaries to merge 61 | """ 62 | merged = {} 63 | 64 | for d in dicts: 65 | merged.update(d) 66 | 67 | return merged 68 | 69 | 70 | def list_keys_to_dict(key_list, callback): 71 | return dict.fromkeys(key_list, callback) 72 | 73 | 74 | def merge_result(command, res): 75 | """ 76 | Merge all items in `res` into a list. 77 | 78 | This command is used when sending a command to multiple nodes 79 | and the result from each node should be merged into a single list. 80 | 81 | res : 'dict' 82 | """ 83 | result = set() 84 | 85 | for v in res.values(): 86 | for value in v: 87 | result.add(value) 88 | 89 | return list(result) 90 | 91 | 92 | def get_ulimits(): 93 | import resource 94 | soft_limit, _ = resource.getrlimit(resource.RLIMIT_NOFILE) 95 | return soft_limit 96 | 97 | def set_ulimits(max_connections: int = 500): 98 | """ 99 | Sets the system ulimits 100 | to allow for the maximum number of open connections 101 | 102 | - if the current ulimit > max_connections, then it is ignored 103 | - if it is less, then we set it. 104 | """ 105 | import resource 106 | 107 | soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE) 108 | if soft_limit > max_connections: return 109 | if hard_limit < max_connections: 110 | logger.warning(f"The current hard limit ({hard_limit}) is less than max_connections ({max_connections}).") 111 | new_hard_limit = max(hard_limit, max_connections) 112 | logger.info(f"Setting new ulimits to ({soft_limit}, {hard_limit}) -> ({max_connections}, {new_hard_limit})") 113 | resource.setrlimit(resource.RLIMIT_NOFILE, (max_connections + 10, new_hard_limit)) 114 | new_soft, new_hard = resource.getrlimit(resource.RLIMIT_NOFILE) 115 | logger.info(f"New Limits: ({new_soft}, {new_hard})") 116 | 117 | -------------------------------------------------------------------------------- /aiokeydb/types/compat.py: -------------------------------------------------------------------------------- 1 | """ 2 | Resolver for Pydantic v1/v2 imports with additional helpers 3 | """ 4 | 5 | 6 | """ 7 | Resolver for Pydantic v1/v2 imports with additional helpers 8 | """ 9 | 10 | 11 | import typing 12 | from lazyops.utils.imports import resolve_missing 13 | 14 | 15 | # Handle v1/v2 of pydantic 16 | try: 17 | from pydantic import validator as _validator 18 | from pydantic import model_validator as base_root_validator 19 | 20 | PYD_VERSION = 2 21 | 22 | def root_validator(*args, **kwargs): 23 | """ 24 | v1 Compatible root validator 25 | """ 26 | def decorator(func): 27 | _pre_kw = kwargs.pop('pre', None) 28 | kwargs['mode'] = 'before' if _pre_kw is True else kwargs.get('mode', 'wrap') 29 | return base_root_validator(*args, **kwargs)(func) 30 | 31 | return decorator 32 | 33 | def pre_root_validator(*args, **kwargs): 34 | def decorator(func): 35 | return base_root_validator(*args, 
mode='before', **kwargs)(func) 36 | return decorator 37 | 38 | def validator(*args, **kwargs): 39 | def decorator(func): 40 | return _validator(*args, **kwargs)(classmethod(func)) 41 | return decorator 42 | 43 | except ImportError: 44 | from pydantic import root_validator, validator 45 | 46 | PYD_VERSION = 1 47 | 48 | def pre_root_validator(*args, **kwargs): 49 | def decorator(func): 50 | return root_validator(*args, pre=True, **kwargs)(func) 51 | return decorator 52 | 53 | 54 | try: 55 | from pydantic_settings import BaseSettings 56 | 57 | except ImportError: 58 | if PYD_VERSION == 2: 59 | resolve_missing('pydantic-settings', required = True) 60 | from pydantic_settings import BaseSettings 61 | else: 62 | from pydantic import BaseSettings 63 | 64 | 65 | from pydantic import BaseModel, Field 66 | from pydantic.fields import FieldInfo 67 | 68 | def get_pyd_dict(model: typing.Union[BaseModel, BaseSettings], **kwargs) -> typing.Dict[str, typing.Any]: 69 | """ 70 | Get a dict from a pydantic model 71 | """ 72 | if kwargs: kwargs = {k:v for k,v in kwargs.items() if v is not None} 73 | return model.model_dump(**kwargs) if PYD_VERSION == 2 else model.dict(**kwargs) 74 | 75 | def get_pyd_fields_dict(model: typing.Type[typing.Union[BaseModel, BaseSettings]]) -> typing.Dict[str, FieldInfo]: 76 | """ 77 | Get a dict of fields from a pydantic model 78 | """ 79 | return model.model_fields if PYD_VERSION == 2 else model.__fields__ 80 | 81 | def get_pyd_field_names(model: typing.Type[typing.Union[BaseModel, BaseSettings]]) -> typing.List[str]: 82 | """ 83 | Get a list of field names from a pydantic model 84 | """ 85 | return list(get_pyd_fields_dict(model).keys()) 86 | 87 | def get_pyd_fields(model: typing.Type[typing.Union[BaseModel, BaseSettings]]) -> typing.List[FieldInfo]: 88 | """ 89 | Get a list of fields from a pydantic model 90 | """ 91 | return list(get_pyd_fields_dict(model).values()) 92 | 93 | def pyd_parse_obj(model: typing.Type[typing.Union[BaseModel, BaseSettings]], obj: typing.Any, **kwargs) -> typing.Union[BaseModel, BaseSettings]: 94 | """ 95 | Parse an object into a pydantic model 96 | """ 97 | return model.model_validate(obj, **kwargs) if PYD_VERSION == 2 else model.parse_obj(obj) 98 | 99 | 100 | def get_pyd_schema(model: typing.Type[typing.Union[BaseModel, BaseSettings]], **kwargs) -> typing.Dict[str, typing.Any]: 101 | """ 102 | Get a pydantic schema 103 | """ 104 | return model.schema(**kwargs) if PYD_VERSION == 2 else model.model_json_schema(**kwargs) 105 | -------------------------------------------------------------------------------- /aiokeydb/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from aiokeydb.core import KeyDB, AsyncKeyDB 4 | from aiokeydb.cluster import KeyDBCluster, AsyncKeyDBCluster 5 | from aiokeydb.connection import ( 6 | BlockingConnectionPool, 7 | Connection, 8 | ConnectionPool, 9 | SSLConnection, 10 | UnixDomainSocketConnection, 11 | AsyncBlockingConnectionPool, 12 | AsyncConnection, 13 | AsyncConnectionPool, 14 | AsyncSSLConnection, 15 | AsyncUnixDomainSocketConnection, 16 | ) 17 | from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider 18 | from aiokeydb.exceptions import ( 19 | AuthenticationError, 20 | AuthenticationWrongNumberOfArgsError, 21 | BusyLoadingError, 22 | ChildDeadlockedError, 23 | ConnectionError, 24 | DataError, 25 | InvalidResponse, 26 | PubSubError, 27 | ReadOnlyError, 28 | ResponseError, 29 | TimeoutError, 30 | 
WatchError, 31 | JobError, 32 | ) 33 | from aiokeydb.sentinel import ( 34 | Sentinel, 35 | SentinelConnectionPool, 36 | SentinelManagedConnection, 37 | AsyncSentinel, 38 | AsyncSentinelConnectionPool, 39 | AsyncSentinelManagedConnection, 40 | ) 41 | 42 | from aiokeydb.utils.base import from_url 43 | from aiokeydb.utils.lazy import get_keydb_settings 44 | 45 | # Handle Client 46 | from aiokeydb.serializers import SerializerType 47 | from aiokeydb.configs import KeyDBSettings, KeyDBWorkerSettings, settings 48 | from aiokeydb.types.session import KeyDBSession 49 | from aiokeydb.client import KeyDBClient 50 | 51 | # Handle Queues 52 | from aiokeydb.types.jobs import Job, CronJob 53 | from aiokeydb.types.task_queue import TaskQueue 54 | from aiokeydb.types.worker import Worker 55 | 56 | # Add KeyDB Index Types 57 | from aiokeydb.types.indexes import ( 58 | KDBIndex, 59 | KDBDict, 60 | AsyncKDBDict, 61 | ) 62 | 63 | from aiokeydb.version import VERSION as __version__ 64 | 65 | def int_or_str(value): 66 | try: 67 | return int(value) 68 | except ValueError: 69 | return value 70 | 71 | 72 | VERSION = tuple(map(int_or_str, __version__.split("."))) 73 | 74 | # Job.update_forward_refs() 75 | 76 | __all__ = [ 77 | "AuthenticationError", 78 | "AuthenticationWrongNumberOfArgsError", 79 | "BlockingConnectionPool", 80 | "BusyLoadingError", 81 | "ChildDeadlockedError", 82 | "Connection", 83 | "ConnectionError", 84 | "ConnectionPool", 85 | "DataError", 86 | "from_url", 87 | "InvalidResponse", 88 | "PubSubError", 89 | "ReadOnlyError", 90 | "KeyDB", 91 | "KeyDBCluster", 92 | # "KeyDBError", 93 | "ResponseError", 94 | "Sentinel", 95 | "SentinelConnectionPool", 96 | "SentinelManagedConnection", 97 | "SentinelManagedSSLConnection", 98 | "SSLConnection", 99 | # "StrictKeyDB", 100 | "TimeoutError", 101 | "UnixDomainSocketConnection", 102 | "WatchError", 103 | "JobError", 104 | "CredentialProvider", 105 | "UsernamePasswordCredentialProvider", 106 | # Async 107 | "AsyncKeyDB", 108 | "AsyncKeyDBCluster", 109 | # "StrictAsyncKeyDB", 110 | "AsyncBlockingConnectionPool", 111 | "AsyncConnection", 112 | "AsyncConnectionPool", 113 | "AsyncSSLConnection", 114 | "AsyncUnixDomainSocketConnection", 115 | "AsyncSentinel", 116 | "AsyncSentinelConnectionPool", 117 | "AsyncSentinelManagedConnection", 118 | # "AsyncSentinelManagedSSLConnection", 119 | 120 | # Client 121 | "SerializerType", 122 | "KeyDBSettings", 123 | "KeyDBWorkerSettings", 124 | "KeyDBSession", 125 | "KeyDBClient", 126 | 127 | # Queues 128 | "TaskQueue", 129 | "Worker", 130 | ] 131 | -------------------------------------------------------------------------------- /aiokeydb/v1/commands/timeseries/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import aiokeydb 4 | 5 | from aiokeydb.v1.commands.helpers import parse_to_list 6 | from aiokeydb.v1.commands.timeseries.commands import ( 7 | ALTER_CMD, 8 | CREATE_CMD, 9 | CREATERULE_CMD, 10 | DEL_CMD, 11 | DELETERULE_CMD, 12 | GET_CMD, 13 | INFO_CMD, 14 | MGET_CMD, 15 | MRANGE_CMD, 16 | MREVRANGE_CMD, 17 | QUERYINDEX_CMD, 18 | RANGE_CMD, 19 | REVRANGE_CMD, 20 | TimeSeriesCommands, 21 | ) 22 | from aiokeydb.v1.commands.timeseries.info import TSInfo 23 | from aiokeydb.v1.commands.timeseries.utils import parse_get, parse_m_get, parse_m_range, parse_range 24 | 25 | 26 | class TimeSeries(TimeSeriesCommands): 27 | """ 28 | This class subclasses redis-py's `Redis` and implements RedisTimeSeries's 29 | commands (prefixed with "ts"). 
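    A short usage sketch (client construction, key name and values are illustrative):

        ts = KeyDB().ts()
        ts.create("temperature")
        ts.add("temperature", "*", 21.5)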
30 | The client allows to interact with RedisTimeSeries and use all of it's 31 | functionality. 32 | """ 33 | 34 | def __init__(self, client=None, **kwargs): 35 | """Create a new RedisTimeSeries client.""" 36 | # Set the module commands' callbacks 37 | self.MODULE_CALLBACKS = { 38 | CREATE_CMD: aiokeydb.v1.client.bool_ok, 39 | ALTER_CMD: aiokeydb.v1.client.bool_ok, 40 | CREATERULE_CMD: aiokeydb.v1.client.bool_ok, 41 | DEL_CMD: int, 42 | DELETERULE_CMD: aiokeydb.v1.client.bool_ok, 43 | RANGE_CMD: parse_range, 44 | REVRANGE_CMD: parse_range, 45 | MRANGE_CMD: parse_m_range, 46 | MREVRANGE_CMD: parse_m_range, 47 | GET_CMD: parse_get, 48 | MGET_CMD: parse_m_get, 49 | INFO_CMD: TSInfo, 50 | QUERYINDEX_CMD: parse_to_list, 51 | } 52 | 53 | self.client = client 54 | self.execute_command = client.execute_command 55 | 56 | for key, value in self.MODULE_CALLBACKS.items(): 57 | self.client.set_response_callback(key, value) 58 | 59 | def pipeline(self, transaction=True, shard_hint=None): 60 | """Creates a pipeline for the TimeSeries module, that can be used 61 | for executing only TimeSeries commands and core commands. 62 | 63 | Usage example: 64 | 65 | r = redis.Redis() 66 | pipe = r.ts().pipeline() 67 | for i in range(100): 68 | pipeline.add("with_pipeline", i, 1.1 * i) 69 | pipeline.execute() 70 | 71 | """ 72 | if isinstance(self.client, aiokeydb.v1.KeyDBCluster): 73 | p = ClusterPipeline( 74 | nodes_manager=self.client.nodes_manager, 75 | commands_parser=self.client.commands_parser, 76 | startup_nodes=self.client.nodes_manager.startup_nodes, 77 | result_callbacks=self.client.result_callbacks, 78 | cluster_response_callbacks=self.client.cluster_response_callbacks, 79 | cluster_error_retry_attempts=self.client.cluster_error_retry_attempts, 80 | read_from_replicas=self.client.read_from_replicas, 81 | reinitialize_steps=self.client.reinitialize_steps, 82 | lock=self.client._lock, 83 | ) 84 | 85 | else: 86 | p = Pipeline( 87 | connection_pool=self.client.connection_pool, 88 | response_callbacks=self.MODULE_CALLBACKS, 89 | transaction=transaction, 90 | shard_hint=shard_hint, 91 | ) 92 | return p 93 | 94 | 95 | class ClusterPipeline(TimeSeriesCommands, aiokeydb.v1.cluster.ClusterPipeline): 96 | """Cluster pipeline for the module.""" 97 | 98 | 99 | class Pipeline(TimeSeriesCommands, aiokeydb.v1.client.Pipeline): 100 | """Pipeline for the module.""" 101 | -------------------------------------------------------------------------------- /aiokeydb/v1/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Backwards compatible import for aiokeydb.v1.v1 3 | """ 4 | 5 | from __future__ import absolute_import 6 | 7 | import sys 8 | 9 | from aiokeydb.v1.core import KeyDB, StrictKeyDB 10 | from aiokeydb.v1.cluster import KeyDBCluster 11 | from aiokeydb.v1.connection import ( 12 | BlockingConnectionPool, 13 | Connection, 14 | ConnectionPool, 15 | SSLConnection, 16 | UnixDomainSocketConnection, 17 | ) 18 | from aiokeydb.v1.credentials import CredentialProvider, UsernamePasswordCredentialProvider 19 | from aiokeydb.v1.exceptions import ( 20 | AuthenticationError, 21 | AuthenticationWrongNumberOfArgsError, 22 | BusyLoadingError, 23 | ChildDeadlockedError, 24 | ConnectionError, 25 | DataError, 26 | InvalidResponse, 27 | PubSubError, 28 | ReadOnlyError, 29 | KeyDBError, 30 | ResponseError, 31 | TimeoutError, 32 | WatchError, 33 | ) 34 | from aiokeydb.v1.sentinel import ( 35 | Sentinel, 36 | SentinelConnectionPool, 37 | SentinelManagedConnection, 38 | 
SentinelManagedSSLConnection, 39 | ) 40 | from aiokeydb.v1.utils import from_url 41 | 42 | # Handle Async 43 | 44 | from aiokeydb.v1.asyncio import ( 45 | AsyncKeyDB, 46 | StrictAsyncKeyDB, 47 | AsyncBlockingConnectionPool, 48 | AsyncConnection, 49 | AsyncConnectionPool, 50 | AsyncSSLConnection, 51 | AsyncUnixDomainSocketConnection, 52 | AsyncSentinel, 53 | AsyncSentinelConnectionPool, 54 | AsyncSentinelManagedConnection, 55 | AsyncSentinelManagedSSLConnection, 56 | async_from_url 57 | ) 58 | 59 | # Handle Client 60 | 61 | from aiokeydb.v1.client.serializers import SerializerType 62 | from aiokeydb.v1.client.config import KeyDBSettings 63 | from aiokeydb.v1.client.schemas.session import KeyDBSession 64 | from aiokeydb.v1.client.meta import KeyDBClient 65 | 66 | if sys.version_info >= (3, 8): 67 | from importlib import metadata 68 | else: 69 | import importlib_metadata as metadata 70 | 71 | 72 | def int_or_str(value): 73 | try: 74 | return int(value) 75 | except ValueError: 76 | return value 77 | 78 | from aiokeydb.version import VERSION as __version__ 79 | 80 | 81 | # try: 82 | # __version__ = metadata.version("aiokeydb") 83 | # except metadata.PackageNotFoundError: 84 | # __version__ = "99.99.99" 85 | 86 | 87 | VERSION = tuple(map(int_or_str, __version__.split("."))) 88 | 89 | __all__ = [ 90 | "AuthenticationError", 91 | "AuthenticationWrongNumberOfArgsError", 92 | "BlockingConnectionPool", 93 | "BusyLoadingError", 94 | "ChildDeadlockedError", 95 | "Connection", 96 | "ConnectionError", 97 | "ConnectionPool", 98 | "DataError", 99 | "from_url", 100 | "InvalidResponse", 101 | "PubSubError", 102 | "ReadOnlyError", 103 | "KeyDB", 104 | "KeyDBCluster", 105 | "KeyDBError", 106 | "ResponseError", 107 | "Sentinel", 108 | "SentinelConnectionPool", 109 | "SentinelManagedConnection", 110 | "SentinelManagedSSLConnection", 111 | "SSLConnection", 112 | "StrictKeyDB", 113 | "TimeoutError", 114 | "UnixDomainSocketConnection", 115 | "WatchError", 116 | "CredentialProvider", 117 | "UsernamePasswordCredentialProvider", 118 | # Async 119 | "AsyncKeyDB", 120 | "StrictAsyncKeyDB", 121 | "AsyncBlockingConnectionPool", 122 | "AsyncConnection", 123 | "AsyncConnectionPool", 124 | "AsyncSSLConnection", 125 | "AsyncUnixDomainSocketConnection", 126 | "AsyncSentinel", 127 | "AsyncSentinelConnectionPool", 128 | "AsyncSentinelManagedConnection", 129 | "AsyncSentinelManagedSSLConnection", 130 | "async_from_url", 131 | 132 | # Client 133 | "SerializerType", 134 | "KeyDBSettings", 135 | "KeyDBSession", 136 | "KeyDBClient", 137 | ] 138 | -------------------------------------------------------------------------------- /aiokeydb/v1/asyncio/parser.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union 2 | 3 | from aiokeydb.v1.exceptions import KeyDBError, ResponseError 4 | 5 | if TYPE_CHECKING: 6 | from aiokeydb.v1.asyncio.cluster import AsyncClusterNode 7 | 8 | 9 | class CommandsParser: 10 | """ 11 | Parses Redis commands to get command keys. 12 | 13 | COMMAND output is used to determine key locations. 14 | Commands that do not have a predefined key location are flagged with 'movablekeys', 15 | and these commands' keys are determined by the command 'COMMAND GETKEYS'. 16 | 17 | NOTE: Due to a bug in redis<7.0, this does not work properly 18 | for EVAL or EVALSHA when the `numkeys` arg is 0. 
19 | - issue: https://github.com/redis/redis/issues/9493 20 | - fix: https://github.com/redis/redis/pull/9733 21 | 22 | So, don't use this with EVAL or EVALSHA. 23 | """ 24 | 25 | __slots__ = ("commands", "node") 26 | 27 | def __init__(self) -> None: 28 | self.commands: Dict[str, Union[int, Dict[str, Any]]] = {} 29 | 30 | async def initialize(self, node: Optional["AsyncClusterNode"] = None) -> None: 31 | if node: 32 | self.node = node 33 | 34 | commands = await self.node.execute_command("COMMAND") 35 | for cmd, command in commands.items(): 36 | if "movablekeys" in command["flags"]: 37 | commands[cmd] = -1 38 | elif command["first_key_pos"] == 0 and command["last_key_pos"] == 0: 39 | commands[cmd] = 0 40 | elif command["first_key_pos"] == 1 and command["last_key_pos"] == 1: 41 | commands[cmd] = 1 42 | self.commands = {cmd.upper(): command for cmd, command in commands.items()} 43 | 44 | # As soon as this PR is merged into Redis, we should reimplement 45 | # our logic to use COMMAND INFO changes to determine the key positions 46 | # https://github.com/redis/redis/pull/8324 47 | async def get_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: 48 | if len(args) < 2: 49 | # The command has no keys in it 50 | return None 51 | 52 | try: 53 | command = self.commands[args[0]] 54 | except KeyError: 55 | # try to split the command name and to take only the main command 56 | # e.g. 'memory' for 'memory usage' 57 | args = args[0].split() + list(args[1:]) 58 | cmd_name = args[0].upper() 59 | if cmd_name not in self.commands: 60 | # We'll try to reinitialize the commands cache, if the engine 61 | # version has changed, the commands may not be current 62 | await self.initialize() 63 | if cmd_name not in self.commands: 64 | raise KeyDBError( 65 | f"{cmd_name} command doesn't exist in Redis commands" 66 | ) 67 | 68 | command = self.commands[cmd_name] 69 | 70 | if command == 1: 71 | return (args[1],) 72 | if command == 0: 73 | return None 74 | if command == -1: 75 | return await self._get_moveable_keys(*args) 76 | 77 | last_key_pos = command["last_key_pos"] 78 | if last_key_pos < 0: 79 | last_key_pos = len(args) + last_key_pos 80 | return args[command["first_key_pos"] : last_key_pos + 1 : command["step_count"]] 81 | 82 | async def _get_moveable_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: 83 | try: 84 | keys = await self.node.execute_command("COMMAND GETKEYS", *args) 85 | except ResponseError as e: 86 | message = e.__str__() 87 | if ( 88 | "Invalid arguments" in message 89 | or "The command has no key arguments" in message 90 | ): 91 | return None 92 | else: 93 | raise e 94 | return keys 95 | -------------------------------------------------------------------------------- /aiokeydb/v1/queues/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import time 3 | import uuid 4 | import random 5 | import typing 6 | import traceback 7 | import contextlib 8 | import contextvars 9 | import asyncio 10 | import functools 11 | from concurrent import futures 12 | 13 | if typing.TYPE_CHECKING: 14 | from aiokeydb.v1.client.config import KeyDBSettings 15 | 16 | from lazyops.utils.logs import default_logger as logger 17 | 18 | _QueueSettings: 'KeyDBSettings' = None 19 | _NodeName: str = None 20 | _ThreadPool: futures.ThreadPoolExecutor = None 21 | 22 | def get_settings() -> 'KeyDBSettings': 23 | """ 24 | Lazily initialize the worker settings 25 | """ 26 | global _QueueSettings 27 | if _QueueSettings is None: 28 | from aiokeydb.v1.client.core import KeyDBClient 29 | 
_QueueSettings = KeyDBClient.get_settings() 30 | return _QueueSettings 31 | 32 | def get_hostname() -> str: 33 | """ 34 | Lazily initialize the worker node name 35 | """ 36 | global _NodeName 37 | if _NodeName is None: 38 | import socket 39 | while _NodeName is None: 40 | with contextlib.suppress(Exception): 41 | _NodeName = socket.gethostname() 42 | time.sleep(0.5) 43 | return _NodeName 44 | 45 | def get_thread_pool( 46 | n_workers: int = None 47 | ) -> futures.ThreadPoolExecutor: 48 | """ 49 | Lazily initialize the worker thread pool 50 | """ 51 | global _ThreadPool 52 | if _ThreadPool is None: 53 | if n_workers is None: 54 | settings = get_settings() 55 | n_workers = settings.worker.threadpool_size 56 | _ThreadPool = futures.ThreadPoolExecutor(max_workers = n_workers) 57 | return _ThreadPool 58 | 59 | 60 | 61 | async def run_in_executor(ctx: typing.Dict[str, typing.Any], func: typing.Callable, *args, **kwargs): 62 | blocking = functools.partial(func, *args, **kwargs) 63 | loop = asyncio.get_running_loop() 64 | return await loop.run_in_executor(ctx['pool'], blocking) 65 | 66 | 67 | def get_and_log_exc(): 68 | error = traceback.format_exc() 69 | logger.error(f'node={get_hostname()}, {error}') 70 | return error 71 | 72 | 73 | def now(): 74 | return int(time.time() * 1000) 75 | 76 | def uuid1(): 77 | return str(uuid.uuid1()) 78 | 79 | def uuid4(): 80 | return str(uuid.uuid4()) 81 | 82 | def millis(s): 83 | return s * 1000 84 | 85 | def seconds(ms): 86 | return ms / 1000 87 | 88 | 89 | def exponential_backoff( 90 | attempts, 91 | base_delay, 92 | max_delay=None, 93 | jitter=True, 94 | ): 95 | """ 96 | Get the next delay for retries in exponential backoff. 97 | 98 | attempts: Number of attempts so far 99 | base_delay: Base delay, in seconds 100 | max_delay: Max delay, in seconds. If None (default), there is no max. 101 | jitter: If True, add a random jitter to the delay 102 | """ 103 | if max_delay is None: 104 | max_delay = float("inf") 105 | backoff = min(max_delay, base_delay * 2 ** max(attempts - 1, 0)) 106 | if jitter: 107 | backoff = backoff * random.random() 108 | return backoff 109 | 110 | 111 | _JobKeyMethod = { 112 | 'uuid1': uuid1, 113 | 'uuid4': uuid4, 114 | } 115 | 116 | _JobKeyFunc: typing.Callable = None 117 | 118 | def _get_jobkey_func(): 119 | global _JobKeyFunc 120 | if _JobKeyFunc is None: 121 | _JobKeyFunc = _JobKeyMethod[get_settings().worker.job_key_method] 122 | return _JobKeyFunc() 123 | 124 | def get_default_job_key(): 125 | return _get_jobkey_func() 126 | 127 | 128 | def ensure_coroutine_function(func): 129 | if asyncio.iscoroutinefunction(func): 130 | return func 131 | 132 | async def wrapped(*args, **kwargs): 133 | loop = asyncio.get_running_loop() 134 | ctx = contextvars.copy_context() 135 | return await loop.run_in_executor( 136 | executor = None, func = lambda: ctx.run(func, *args, **kwargs) 137 | ) 138 | 139 | return wrapped -------------------------------------------------------------------------------- /aiokeydb/v1/commands/sentinel.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | class SentinelCommands: 5 | """ 6 | A class containing the commands specific to redis sentinel. This class is 7 | to be used as a mixin. 
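    A rough usage sketch (host, port and service name are illustrative); these
    methods become available on any client that mixes this class in:

        client = KeyDB(host="localhost", port=26379)
        client.sentinel_masters()
        client.sentinel_get_master_addr_by_name("mymaster")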
8 | """ 9 | 10 | def sentinel(self, *args): 11 | """Redis Sentinel's SENTINEL command.""" 12 | warnings.warn(DeprecationWarning("Use the individual sentinel_* methods")) 13 | 14 | def sentinel_get_master_addr_by_name(self, service_name): 15 | """Returns a (host, port) pair for the given ``service_name``""" 16 | return self.execute_command("SENTINEL GET-MASTER-ADDR-BY-NAME", service_name) 17 | 18 | def sentinel_master(self, service_name): 19 | """Returns a dictionary containing the specified masters state.""" 20 | return self.execute_command("SENTINEL MASTER", service_name) 21 | 22 | def sentinel_masters(self): 23 | """Returns a list of dictionaries containing each master's state.""" 24 | return self.execute_command("SENTINEL MASTERS") 25 | 26 | def sentinel_monitor(self, name, ip, port, quorum): 27 | """Add a new master to Sentinel to be monitored""" 28 | return self.execute_command("SENTINEL MONITOR", name, ip, port, quorum) 29 | 30 | def sentinel_remove(self, name): 31 | """Remove a master from Sentinel's monitoring""" 32 | return self.execute_command("SENTINEL REMOVE", name) 33 | 34 | def sentinel_sentinels(self, service_name): 35 | """Returns a list of sentinels for ``service_name``""" 36 | return self.execute_command("SENTINEL SENTINELS", service_name) 37 | 38 | def sentinel_set(self, name, option, value): 39 | """Set Sentinel monitoring parameters for a given master""" 40 | return self.execute_command("SENTINEL SET", name, option, value) 41 | 42 | def sentinel_slaves(self, service_name): 43 | """Returns a list of slaves for ``service_name``""" 44 | return self.execute_command("SENTINEL SLAVES", service_name) 45 | 46 | def sentinel_reset(self, pattern): 47 | """ 48 | This command will reset all the masters with matching name. 49 | The pattern argument is a glob-style pattern. 50 | 51 | The reset process clears any previous state in a master (including a 52 | failover in progress), and removes every slave and sentinel already 53 | discovered and associated with the master. 54 | """ 55 | return self.execute_command("SENTINEL RESET", pattern, once=True) 56 | 57 | def sentinel_failover(self, new_master_name): 58 | """ 59 | Force a failover as if the master was not reachable, and without 60 | asking for agreement to other Sentinels (however a new version of the 61 | configuration will be published so that the other Sentinels will 62 | update their configurations). 63 | """ 64 | return self.execute_command("SENTINEL FAILOVER", new_master_name) 65 | 66 | def sentinel_ckquorum(self, new_master_name): 67 | """ 68 | Check if the current Sentinel configuration is able to reach the 69 | quorum needed to failover a master, and the majority needed to 70 | authorize the failover. 71 | 72 | This command should be used in monitoring systems to check if a 73 | Sentinel deployment is ok. 74 | """ 75 | return self.execute_command("SENTINEL CKQUORUM", new_master_name, once=True) 76 | 77 | def sentinel_flushconfig(self): 78 | """ 79 | Force Sentinel to rewrite its configuration on disk, including the 80 | current Sentinel state. 81 | 82 | Normally Sentinel rewrites the configuration every time something 83 | changes in its state (in the context of the subset of the state which 84 | is persisted on disk across restart). 85 | However sometimes it is possible that the configuration file is lost 86 | because of operation errors, disk failures, package upgrade scripts or 87 | configuration managers. In those cases a way to to force Sentinel to 88 | rewrite the configuration file is handy. 
89 | 90 | This command works even if the previous configuration file is 91 | completely missing. 92 | """ 93 | return self.execute_command("SENTINEL FLUSHCONFIG") 94 | 95 | 96 | class AsyncSentinelCommands(SentinelCommands): 97 | async def sentinel(self, *args) -> None: 98 | """Redis Sentinel's SENTINEL command.""" 99 | super().sentinel(*args) 100 | -------------------------------------------------------------------------------- /aiokeydb/v2/serializers/_json.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Implements the JsonSerializer class. 4 | 5 | - uses the json module to serialize and deserialize data 6 | """ 7 | 8 | 9 | import json 10 | import typing 11 | import datetime 12 | import dataclasses 13 | import contextlib 14 | 15 | from aiokeydb.v2.types import BaseSerializer 16 | 17 | try: 18 | import numpy as np 19 | except ImportError: 20 | np = None 21 | 22 | 23 | try: 24 | import orjson 25 | _orjson_avail = True 26 | except ImportError: 27 | orjson = object 28 | _orjson_avail = False 29 | 30 | def object_serializer(obj: typing.Any) -> typing.Any: 31 | if isinstance(obj, dict): 32 | return {k: object_serializer(v) for k, v in obj.items()} 33 | 34 | if isinstance(obj, bytes): 35 | return obj.decode('utf-8') 36 | 37 | if isinstance(obj, (str, list, dict, int, float, bool, type(None))): 38 | return obj 39 | 40 | if dataclasses.is_dataclass(obj): 41 | return dataclasses.asdict(obj) 42 | 43 | if hasattr(obj, 'dict'): # test for pydantic models 44 | return obj.dict() 45 | 46 | if hasattr(obj, 'get_secret_value'): 47 | return obj.get_secret_value() 48 | 49 | if hasattr(obj, 'as_posix'): 50 | return obj.as_posix() 51 | 52 | if hasattr(obj, "numpy"): # Checks for TF tensors without needing the import 53 | return obj.numpy().tolist() 54 | 55 | if hasattr(obj, 'tolist'): # Checks for torch tensors without importing 56 | return obj.tolist() 57 | 58 | if isinstance(obj, (datetime.date, datetime.datetime)): 59 | return obj.isoformat() 60 | 61 | if np is not None: 62 | if isinstance(obj, (np.int_, np.intc, np.intp, np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64)): 63 | return int(obj) 64 | 65 | if isinstance(obj, (np.float_, np.float16, np.float32, np.float64)): 66 | return float(obj) 67 | 68 | else: 69 | # Try to convert to a primitive type 70 | with contextlib.suppress(Exception): 71 | return int(obj) 72 | with contextlib.suppress(Exception): 73 | return float(obj) 74 | 75 | raise TypeError(f"Object of type {type(obj)} is not JSON serializable") 76 | 77 | 78 | class ObjectEncoder(json.JSONEncoder): 79 | 80 | def default(self, obj: typing.Any): # pylint: disable=arguments-differ,method-hidden 81 | with contextlib.suppress(Exception): 82 | return object_serializer(obj) 83 | return json.JSONEncoder.default(self, obj) 84 | 85 | 86 | class JsonSerializer(BaseSerializer): 87 | 88 | @staticmethod 89 | def dumps( 90 | obj: typing.Union[typing.Dict[typing.Any, typing.Any], typing.Any], 91 | *args, 92 | default: typing.Union[typing.Dict[typing.Any, typing.Any], typing.Any] = None, 93 | cls: typing.Type[json.JSONEncoder] = ObjectEncoder, 94 | **kwargs 95 | ) -> str: 96 | return json.dumps(obj, *args, default = default, cls = cls, **kwargs) 97 | 98 | @staticmethod 99 | def loads( 100 | data: typing.Union[str, bytes], 101 | *args, 102 | **kwargs 103 | ) -> typing.Union[typing.Dict[typing.Any, typing.Any], typing.List[str], typing.Any]: 104 | return json.loads(data, *args, **kwargs) 105 | 106 | if _orjson_avail: 107 | 108 | 
class OrJsonSerializer(BaseSerializer): 109 | 110 | @staticmethod 111 | def dumps( 112 | obj: typing.Union[typing.Dict[typing.Any, typing.Any], typing.Any], 113 | *args, 114 | default: typing.Union[typing.Dict[typing.Any, typing.Any], typing.Any] = None, 115 | _serializer: typing.Callable = object_serializer, 116 | **kwargs 117 | ) -> str: 118 | """ 119 | We encode the data first using the object_serializer function 120 | """ 121 | if _serializer: # pragma: no cover 122 | obj = _serializer(obj) 123 | return orjson.dumps(obj, default=default, *args, **kwargs).decode() 124 | 125 | @staticmethod 126 | def loads( 127 | data: typing.Union[str, bytes], 128 | *args, 129 | **kwargs 130 | ) -> typing.Union[typing.Dict[typing.Any, typing.Any], typing.List[str], typing.Any]: 131 | return orjson.loads(data, *args, **kwargs) 132 | 133 | else: 134 | OrJsonSerializer = JsonSerializer 135 | 136 | 137 | -------------------------------------------------------------------------------- /aiokeydb/serializers/_json.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Implements the JsonSerializer class. 4 | 5 | - uses the json module to serialize and deserialize data 6 | """ 7 | 8 | 9 | import json 10 | import typing 11 | import datetime 12 | import dataclasses 13 | import contextlib 14 | 15 | from aiokeydb.types.serializer import BaseSerializer 16 | 17 | try: 18 | import numpy as np 19 | except ImportError: 20 | np = None 21 | 22 | 23 | try: 24 | import orjson 25 | _orjson_avail = True 26 | except ImportError: 27 | orjson = object 28 | _orjson_avail = False 29 | 30 | def object_serializer(obj: typing.Any) -> typing.Any: 31 | if isinstance(obj, dict): 32 | return {k: object_serializer(v) for k, v in obj.items()} 33 | 34 | if isinstance(obj, bytes): 35 | return obj.decode('utf-8') 36 | 37 | if isinstance(obj, (str, list, dict, int, float, bool, type(None))): 38 | return obj 39 | 40 | if dataclasses.is_dataclass(obj): 41 | return dataclasses.asdict(obj) 42 | 43 | if hasattr(obj, 'dict'): # test for pydantic models 44 | return obj.dict() 45 | 46 | if hasattr(obj, 'get_secret_value'): 47 | return obj.get_secret_value() 48 | 49 | if hasattr(obj, 'as_posix'): 50 | return obj.as_posix() 51 | 52 | if hasattr(obj, "numpy"): # Checks for TF tensors without needing the import 53 | return obj.numpy().tolist() 54 | 55 | if hasattr(obj, 'tolist'): # Checks for torch tensors without importing 56 | return obj.tolist() 57 | 58 | if isinstance(obj, (datetime.date, datetime.datetime)): 59 | return obj.isoformat() 60 | 61 | if np is not None: 62 | if isinstance(obj, (np.int_, np.intc, np.intp, np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64)): 63 | return int(obj) 64 | 65 | if isinstance(obj, (np.float_, np.float16, np.float32, np.float64)): 66 | return float(obj) 67 | 68 | else: 69 | # Try to convert to a primitive type 70 | with contextlib.suppress(Exception): 71 | return int(obj) 72 | with contextlib.suppress(Exception): 73 | return float(obj) 74 | 75 | raise TypeError(f"Object of type {type(obj)} is not JSON serializable") 76 | 77 | 78 | class ObjectEncoder(json.JSONEncoder): 79 | 80 | def default(self, obj: typing.Any): # pylint: disable=arguments-differ,method-hidden 81 | with contextlib.suppress(Exception): 82 | return object_serializer(obj) 83 | return json.JSONEncoder.default(self, obj) 84 | 85 | 86 | class JsonSerializer(BaseSerializer): 87 | 88 | @staticmethod 89 | def dumps( 90 | obj: typing.Union[typing.Dict[typing.Any, typing.Any], 
typing.Any], 91 | *args, 92 | default: typing.Union[typing.Dict[typing.Any, typing.Any], typing.Any] = None, 93 | cls: typing.Type[json.JSONEncoder] = ObjectEncoder, 94 | **kwargs 95 | ) -> str: 96 | return json.dumps(obj, *args, default = default, cls = cls, **kwargs) 97 | 98 | @staticmethod 99 | def loads( 100 | data: typing.Union[str, bytes], 101 | *args, 102 | **kwargs 103 | ) -> typing.Union[typing.Dict[typing.Any, typing.Any], typing.List[str], typing.Any]: 104 | return json.loads(data, *args, **kwargs) 105 | 106 | if _orjson_avail: 107 | 108 | class OrJsonSerializer(BaseSerializer): 109 | 110 | @staticmethod 111 | def dumps( 112 | obj: typing.Union[typing.Dict[typing.Any, typing.Any], typing.Any], 113 | *args, 114 | default: typing.Union[typing.Dict[typing.Any, typing.Any], typing.Any] = None, 115 | _serializer: typing.Callable = object_serializer, 116 | **kwargs 117 | ) -> str: 118 | """ 119 | We encode the data first using the object_serializer function 120 | """ 121 | if _serializer: # pragma: no cover 122 | obj = _serializer(obj) 123 | return orjson.dumps(obj, default=default, *args, **kwargs).decode() 124 | 125 | @staticmethod 126 | def loads( 127 | data: typing.Union[str, bytes], 128 | *args, 129 | **kwargs 130 | ) -> typing.Union[typing.Dict[typing.Any, typing.Any], typing.List[str], typing.Any]: 131 | return orjson.loads(data, *args, **kwargs) 132 | 133 | else: 134 | OrJsonSerializer = JsonSerializer 135 | 136 | 137 | -------------------------------------------------------------------------------- /aiokeydb/v1/client/serializers/_json.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Implements the JsonSerializer class. 4 | 5 | - uses the json module to serialize and deserialize data 6 | """ 7 | 8 | 9 | import json 10 | import typing 11 | import datetime 12 | import dataclasses 13 | import contextlib 14 | 15 | from aiokeydb.v1.client.serializers.base import BaseSerializer 16 | 17 | try: 18 | import numpy as np 19 | except ImportError: 20 | np = None 21 | 22 | 23 | try: 24 | import orjson 25 | _orjson_avail = True 26 | except ImportError: 27 | orjson = object 28 | _orjson_avail = False 29 | 30 | def object_serializer(obj: typing.Any) -> typing.Any: 31 | if isinstance(obj, dict): 32 | return {k: object_serializer(v) for k, v in obj.items()} 33 | 34 | if isinstance(obj, bytes): 35 | return obj.decode('utf-8') 36 | 37 | if isinstance(obj, (str, list, dict, int, float, bool, type(None))): 38 | return obj 39 | 40 | if dataclasses.is_dataclass(obj): 41 | return dataclasses.asdict(obj) 42 | 43 | if hasattr(obj, 'dict'): # test for pydantic models 44 | return obj.dict() 45 | 46 | if hasattr(obj, 'get_secret_value'): 47 | return obj.get_secret_value() 48 | 49 | if hasattr(obj, 'as_posix'): 50 | return obj.as_posix() 51 | 52 | if hasattr(obj, "numpy"): # Checks for TF tensors without needing the import 53 | return obj.numpy().tolist() 54 | 55 | if hasattr(obj, 'tolist'): # Checks for torch tensors without importing 56 | return obj.tolist() 57 | 58 | if isinstance(obj, (datetime.date, datetime.datetime)): 59 | return obj.isoformat() 60 | 61 | if np is not None: 62 | if isinstance(obj, (np.int_, np.intc, np.intp, np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64)): 63 | return int(obj) 64 | 65 | if isinstance(obj, (np.float_, np.float16, np.float32, np.float64)): 66 | return float(obj) 67 | 68 | else: 69 | # Try to convert to a primitive type 70 | with contextlib.suppress(Exception): 71 | return int(obj) 72 
| with contextlib.suppress(Exception): 73 | return float(obj) 74 | 75 | raise TypeError(f"Object of type {type(obj)} is not JSON serializable") 76 | 77 | 78 | class ObjectEncoder(json.JSONEncoder): 79 | 80 | def default(self, obj: typing.Any): # pylint: disable=arguments-differ,method-hidden 81 | with contextlib.suppress(Exception): 82 | return object_serializer(obj) 83 | return json.JSONEncoder.default(self, obj) 84 | 85 | 86 | class JsonSerializer(BaseSerializer): 87 | 88 | @staticmethod 89 | def dumps( 90 | obj: typing.Union[typing.Dict[typing.Any, typing.Any], typing.Any], 91 | *args, 92 | default: typing.Union[typing.Dict[typing.Any, typing.Any], typing.Any] = None, 93 | cls: typing.Type[json.JSONEncoder] = ObjectEncoder, 94 | **kwargs 95 | ) -> str: 96 | return json.dumps(obj, *args, default = default, cls = cls, **kwargs) 97 | 98 | @staticmethod 99 | def loads( 100 | data: typing.Union[str, bytes], 101 | *args, 102 | **kwargs 103 | ) -> typing.Union[typing.Dict[typing.Any, typing.Any], typing.List[str], typing.Any]: 104 | return json.loads(data, *args, **kwargs) 105 | 106 | if _orjson_avail: 107 | 108 | class OrJsonSerializer(BaseSerializer): 109 | 110 | @staticmethod 111 | def dumps( 112 | obj: typing.Union[typing.Dict[typing.Any, typing.Any], typing.Any], 113 | *args, 114 | default: typing.Union[typing.Dict[typing.Any, typing.Any], typing.Any] = None, 115 | _serializer: typing.Callable = object_serializer, 116 | **kwargs 117 | ) -> str: 118 | """ 119 | We encode the data first using the object_serializer function 120 | """ 121 | if _serializer: # pragma: no cover 122 | obj = _serializer(obj) 123 | return orjson.dumps(obj, default=default, *args, **kwargs).decode() 124 | 125 | @staticmethod 126 | def loads( 127 | data: typing.Union[str, bytes], 128 | *args, 129 | **kwargs 130 | ) -> typing.Union[typing.Dict[typing.Any, typing.Any], typing.List[str], typing.Any]: 131 | return orjson.loads(data, *args, **kwargs) 132 | 133 | else: 134 | OrJsonSerializer = JsonSerializer 135 | 136 | 137 | -------------------------------------------------------------------------------- /aiokeydb/serializers/_pickle.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Implements the PickleSerializer class. 
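    A round-trip sketch (the payload is illustrative):

        data = PickleSerializer.dumps({"hello": "world"})
        assert PickleSerializer.loads(data) == {"hello": "world"}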
4 | 5 | - uses the pickle module to serialize and deserialize data 6 | - will use `dill` if it is installed 7 | """ 8 | import sys 9 | import pickle 10 | import typing 11 | import binascii 12 | import contextlib 13 | from io import BytesIO 14 | from aiokeydb.types.serializer import BaseSerializer 15 | from pickle import DEFAULT_PROTOCOL, Pickler, Unpickler 16 | 17 | 18 | if sys.version_info.minor < 8: 19 | with contextlib.suppress(ImportError): 20 | import pickle5 as pickle 21 | 22 | try: 23 | import dill 24 | from dill import DEFAULT_PROTOCOL as DILL_DEFAULT_PROTOCOL 25 | _dill_avail = True 26 | except ImportError: 27 | dill = object 28 | _dill_avail = False 29 | 30 | 31 | class DefaultProtocols: 32 | default: int = 4 33 | pickle: int = pickle.HIGHEST_PROTOCOL 34 | dill: int = dill.HIGHEST_PROTOCOL 35 | 36 | class PickleSerializer(BaseSerializer): 37 | 38 | @staticmethod 39 | def dumps(obj: typing.Any, protocol: int = DefaultProtocols.pickle, *args, **kwargs) -> bytes: 40 | return pickle.dumps(obj, protocol = protocol, *args, **kwargs) 41 | 42 | @staticmethod 43 | def loads(data: typing.Union[str, bytes, typing.Any], *args, **kwargs) -> typing.Any: 44 | return pickle.loads(data, *args, **kwargs) 45 | 46 | class PickleSerializerv2(BaseSerializer): 47 | 48 | @staticmethod 49 | def dumps(obj: typing.Any, protocol: int = DEFAULT_PROTOCOL, *args, **kwargs) -> str: 50 | """ 51 | v2 Encoding 52 | """ 53 | f = BytesIO() 54 | p = Pickler(f, protocol = protocol) 55 | p.dump(obj) 56 | return f.getvalue().hex() 57 | 58 | @staticmethod 59 | def loads(data: typing.Union[str, typing.Any], *args, **kwargs) -> typing.Any: 60 | """ 61 | V2 Decoding 62 | """ 63 | return Unpickler(BytesIO(binascii.unhexlify(data))).load() 64 | 65 | if _dill_avail: 66 | from dill import Pickler as DillPickler, Unpickler as DillUnpickler 67 | 68 | class DillSerializer(BaseSerializer): 69 | 70 | @staticmethod 71 | def dumps(obj: typing.Any, protocol: int = DefaultProtocols.dill, *args, **kwargs) -> bytes: 72 | return dill.dumps(obj, protocol = protocol, *args, **kwargs) 73 | 74 | @staticmethod 75 | def loads(data: typing.Union[str, bytes, typing.Any], *args, **kwargs) -> typing.Any: 76 | return dill.loads(data, *args, **kwargs) 77 | 78 | class DillSerializerv2(BaseSerializer): 79 | 80 | @staticmethod 81 | def dumps(obj: typing.Any, protocol: int = DILL_DEFAULT_PROTOCOL, *args, **kwargs) -> str: 82 | """ 83 | v2 Encoding 84 | """ 85 | f = BytesIO() 86 | p = DillPickler(f, protocol = protocol) 87 | p.dump(obj) 88 | return f.getvalue().hex() 89 | 90 | @staticmethod 91 | def loads(data: typing.Union[str, typing.Any], *args, **kwargs) -> typing.Any: 92 | """ 93 | V2 Decoding 94 | """ 95 | return DillUnpickler(BytesIO(binascii.unhexlify(data))).load() 96 | 97 | else: 98 | DillSerializer = PickleSerializer 99 | DillSerializerv2 = PickleSerializerv2 100 | 101 | try: 102 | import cloudpickle 103 | from types import ModuleType 104 | 105 | class CloudPickleSerializer(BaseSerializer): 106 | 107 | @staticmethod 108 | def dumps(obj: typing.Any, protocol: int = cloudpickle.DEFAULT_PROTOCOL, *args, **kwargs) -> bytes: 109 | """ 110 | Dumps an object to bytes 111 | """ 112 | return cloudpickle.dumps(obj, protocol = protocol, *args, **kwargs) 113 | 114 | @staticmethod 115 | def loads(data: typing.Union[str, bytes, typing.Any], *args, **kwargs) -> typing.Any: 116 | """ 117 | Loads an object from bytes 118 | """ 119 | return cloudpickle.loads(data, *args, **kwargs) 120 | 121 | @staticmethod 122 | def register_module(module: ModuleType): 123 | """ 
/aiokeydb/v1/commands/search/reducers.py:
--------------------------------------------------------------------------------
1 | from aiokeydb.v1.commands.search.aggregation import Reducer, SortDirection
2 | 
3 | 
4 | class FieldOnlyReducer(Reducer):
5 |     def __init__(self, field):
6 |         super().__init__(field)
7 |         self._field = field
8 | 
9 | 
10 | class count(Reducer):
11 |     """
12 |     Counts the number of results in the group
13 |     """
14 | 
15 |     NAME = "COUNT"
16 | 
17 |     def __init__(self):
18 |         super().__init__()
19 | 
20 | 
21 | class sum(FieldOnlyReducer):
22 |     """
23 |     Calculates the sum of all the values in the given fields within the group
24 |     """
25 | 
26 |     NAME = "SUM"
27 | 
28 |     def __init__(self, field):
29 |         super().__init__(field)
30 | 
31 | 
32 | class min(FieldOnlyReducer):
33 |     """
34 |     Calculates the smallest value in the given field within the group
35 |     """
36 | 
37 |     NAME = "MIN"
38 | 
39 |     def __init__(self, field):
40 |         super().__init__(field)
41 | 
42 | 
43 | class max(FieldOnlyReducer):
44 |     """
45 |     Calculates the largest value in the given field within the group
46 |     """
47 | 
48 |     NAME = "MAX"
49 | 
50 |     def __init__(self, field):
51 |         super().__init__(field)
52 | 
53 | 
54 | class avg(FieldOnlyReducer):
55 |     """
56 |     Calculates the mean value in the given field within the group
57 |     """
58 | 
59 |     NAME = "AVG"
60 | 
61 |     def __init__(self, field):
62 |         super().__init__(field)
63 | 
64 | 
65 | class tolist(FieldOnlyReducer):
66 |     """
67 |     Returns all the matched properties in a list
68 |     """
69 | 
70 |     NAME = "TOLIST"
71 | 
72 |     def __init__(self, field):
73 |         super().__init__(field)
74 | 
75 | 
76 | class count_distinct(FieldOnlyReducer):
77 |     """
78 |     Calculate the number of distinct values contained in all the results in
79 |     the group for the given field
80 |     """
81 | 
82 |     NAME = "COUNT_DISTINCT"
83 | 
84 |     def __init__(self, field):
85 |         super().__init__(field)
86 | 
87 | 
88 | class count_distinctish(FieldOnlyReducer):
89 |     """
90 |     Calculate the number of distinct values contained in all the results in the
91 |     group for the given field. This uses a faster algorithm than
92 |     `count_distinct` but is less accurate
93 |     """
94 | 
95 |     NAME = "COUNT_DISTINCTISH"
96 | 
97 | 
98 | class quantile(Reducer):
99 |     """
100 |     Return the value for the nth percentile within the range of values for the
101 |     field within the group.
102 |     """
103 | 
104 |     NAME = "QUANTILE"
105 | 
106 |     def __init__(self, field, pct):
107 |         super().__init__(field, str(pct))
108 |         self._field = field
109 | 
110 | 
111 | class stddev(FieldOnlyReducer):
112 |     """
113 |     Return the standard deviation for the values within the group
114 |     """
115 | 
116 |     NAME = "STDDEV"
117 | 
118 |     def __init__(self, field):
119 |         super().__init__(field)
120 | 
121 | 
122 | class first_value(Reducer):
123 |     """
124 |     Selects the first value within the group according to sorting parameters
125 |     """
126 | 
127 |     NAME = "FIRST_VALUE"
128 | 
129 |     def __init__(self, field, *byfields):
130 |         """
131 |         Selects the first value of the given field within the group.
132 | 133 | ### Parameter 134 | 135 | - **field**: Source field used for the value 136 | - **byfields**: How to sort the results. This can be either the 137 | *class* of `aggregation.Asc` or `aggregation.Desc` in which 138 | case the field `field` is also used as the sort input. 139 | 140 | `byfields` can also be one or more *instances* of `Asc` or `Desc` 141 | indicating the sort order for these fields 142 | """ 143 | 144 | fieldstrs = [] 145 | if ( 146 | len(byfields) == 1 147 | and isinstance(byfields[0], type) 148 | and issubclass(byfields[0], SortDirection) 149 | ): 150 | byfields = [byfields[0](field)] 151 | 152 | for f in byfields: 153 | fieldstrs += [f.field, f.DIRSTRING] 154 | 155 | args = [field] 156 | if fieldstrs: 157 | args += ["BY"] + fieldstrs 158 | super().__init__(*args) 159 | self._field = field 160 | 161 | 162 | class random_sample(Reducer): 163 | """ 164 | Returns a random sample of items from the dataset, from the given property 165 | """ 166 | 167 | NAME = "RANDOM_SAMPLE" 168 | 169 | def __init__(self, field, size): 170 | """ 171 | ### Parameter 172 | 173 | **field**: Field to sample from 174 | **size**: Return this many items (can be less) 175 | """ 176 | args = [field, str(size)] 177 | super().__init__(*args) 178 | self._field = field 179 | -------------------------------------------------------------------------------- /aiokeydb/v1/commands/helpers.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | import string 4 | from typing import List, Tuple 5 | 6 | from aiokeydb.v1.typing import KeysT, KeyT 7 | 8 | 9 | def list_or_args(keys: KeysT, args: Tuple[KeyT, ...]) -> List[KeyT]: 10 | # returns a single new list combining keys and args 11 | try: 12 | iter(keys) 13 | # a string or bytes instance can be iterated, but indicates 14 | # keys wasn't passed as a list 15 | if isinstance(keys, (bytes, str)): 16 | keys = [keys] 17 | else: 18 | keys = list(keys) 19 | except TypeError: 20 | keys = [keys] 21 | if args: 22 | keys.extend(args) 23 | return keys 24 | 25 | 26 | def nativestr(x): 27 | """Return the decoded binary string, or a string, depending on type.""" 28 | r = x.decode("utf-8", "replace") if isinstance(x, bytes) else x 29 | if r == "null": 30 | return 31 | return r 32 | 33 | 34 | def delist(x): 35 | """Given a list of binaries, return the stringified version.""" 36 | if x is None: 37 | return x 38 | return [nativestr(obj) for obj in x] 39 | 40 | 41 | def parse_to_list(response): 42 | """Optimistically parse the response to a list.""" 43 | res = [] 44 | 45 | if response is None: 46 | return res 47 | 48 | for item in response: 49 | try: 50 | res.append(int(item)) 51 | except ValueError: 52 | try: 53 | res.append(float(item)) 54 | except ValueError: 55 | res.append(nativestr(item)) 56 | except TypeError: 57 | res.append(None) 58 | return res 59 | 60 | 61 | def parse_list_to_dict(response): 62 | res = {} 63 | for i in range(0, len(response), 2): 64 | if isinstance(response[i], list): 65 | res["Child iterators"].append(parse_list_to_dict(response[i])) 66 | elif isinstance(response[i + 1], list): 67 | res["Child iterators"] = [parse_list_to_dict(response[i + 1])] 68 | else: 69 | try: 70 | res[response[i]] = float(response[i + 1]) 71 | except (TypeError, ValueError): 72 | res[response[i]] = response[i + 1] 73 | return res 74 | 75 | 76 | def parse_to_dict(response): 77 | if response is None: 78 | return {} 79 | 80 | res = {} 81 | for det in response: 82 | if isinstance(det[1], list): 83 | res[det[0]] = 
parse_list_to_dict(det[1]) 84 | else: 85 | try: # try to set the attribute. may be provided without value 86 | try: # try to convert the value to float 87 | res[det[0]] = float(det[1]) 88 | except (TypeError, ValueError): 89 | res[det[0]] = det[1] 90 | except IndexError: 91 | pass 92 | return res 93 | 94 | 95 | def random_string(length=10): 96 | """ 97 | Returns a random N character long string. 98 | """ 99 | return "".join( # nosec 100 | random.choice(string.ascii_lowercase) for x in range(length) 101 | ) 102 | 103 | 104 | def quote_string(v): 105 | """ 106 | RedisGraph strings must be quoted, 107 | quote_string wraps given v with quotes incase 108 | v is a string. 109 | """ 110 | 111 | if isinstance(v, bytes): 112 | v = v.decode() 113 | elif not isinstance(v, str): 114 | return v 115 | if len(v) == 0: 116 | return '""' 117 | 118 | v = v.replace('"', '\\"') 119 | 120 | return f'"{v}"' 121 | 122 | 123 | def decode_dict_keys(obj): 124 | """Decode the keys of the given dictionary with utf-8.""" 125 | newobj = copy.copy(obj) 126 | for k in obj.keys(): 127 | if isinstance(k, bytes): 128 | newobj[k.decode("utf-8")] = newobj[k] 129 | newobj.pop(k) 130 | return newobj 131 | 132 | 133 | def stringify_param_value(value): 134 | """ 135 | Turn a parameter value into a string suitable for the params header of 136 | a Cypher command. 137 | You may pass any value that would be accepted by `json.dumps()`. 138 | 139 | Ways in which output differs from that of `str()`: 140 | * Strings are quoted. 141 | * None --> "null". 142 | * In dictionaries, keys are _not_ quoted. 143 | 144 | :param value: The parameter value to be turned into a string. 145 | :return: string 146 | """ 147 | 148 | if isinstance(value, str): 149 | return quote_string(value) 150 | elif value is None: 151 | return "null" 152 | elif isinstance(value, (list, tuple)): 153 | return f'[{",".join(map(stringify_param_value, value))}]' 154 | elif isinstance(value, dict): 155 | return f'{{{",".join(f"{k}:{stringify_param_value(v)}" for k, v in value.items())}}}' # noqa 156 | else: 157 | return str(value) 158 | -------------------------------------------------------------------------------- /aiokeydb/v1/commands/search/field.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from aiokeydb.v1.exceptions import DataError 4 | 5 | 6 | class Field: 7 | 8 | NUMERIC = "NUMERIC" 9 | TEXT = "TEXT" 10 | WEIGHT = "WEIGHT" 11 | GEO = "GEO" 12 | TAG = "TAG" 13 | VECTOR = "VECTOR" 14 | SORTABLE = "SORTABLE" 15 | NOINDEX = "NOINDEX" 16 | AS = "AS" 17 | 18 | def __init__( 19 | self, 20 | name: str, 21 | args: List[str] = None, 22 | sortable: bool = False, 23 | no_index: bool = False, 24 | as_name: str = None, 25 | ): 26 | if args is None: 27 | args = [] 28 | self.name = name 29 | self.args = args 30 | self.args_suffix = list() 31 | self.as_name = as_name 32 | 33 | if sortable: 34 | self.args_suffix.append(Field.SORTABLE) 35 | if no_index: 36 | self.args_suffix.append(Field.NOINDEX) 37 | 38 | if no_index and not sortable: 39 | raise ValueError("Non-Sortable non-Indexable fields are ignored") 40 | 41 | def append_arg(self, value): 42 | self.args.append(value) 43 | 44 | def keydb_args(self): 45 | args = [self.name] 46 | if self.as_name: 47 | args += [self.AS, self.as_name] 48 | args += self.args 49 | args += self.args_suffix 50 | return args 51 | 52 | 53 | class TextField(Field): 54 | """ 55 | TextField is used to define a text field in a schema definition 56 | """ 57 | 58 | NOSTEM = "NOSTEM" 59 | 
PHONETIC = "PHONETIC" 60 | 61 | def __init__( 62 | self, 63 | name: str, 64 | weight: float = 1.0, 65 | no_stem: bool = False, 66 | phonetic_matcher: str = None, 67 | **kwargs, 68 | ): 69 | Field.__init__(self, name, args=[Field.TEXT, Field.WEIGHT, weight], **kwargs) 70 | 71 | if no_stem: 72 | Field.append_arg(self, self.NOSTEM) 73 | if phonetic_matcher and phonetic_matcher in [ 74 | "dm:en", 75 | "dm:fr", 76 | "dm:pt", 77 | "dm:es", 78 | ]: 79 | Field.append_arg(self, self.PHONETIC) 80 | Field.append_arg(self, phonetic_matcher) 81 | 82 | 83 | class NumericField(Field): 84 | """ 85 | NumericField is used to define a numeric field in a schema definition 86 | """ 87 | 88 | def __init__(self, name: str, **kwargs): 89 | Field.__init__(self, name, args=[Field.NUMERIC], **kwargs) 90 | 91 | 92 | class GeoField(Field): 93 | """ 94 | GeoField is used to define a geo-indexing field in a schema definition 95 | """ 96 | 97 | def __init__(self, name: str, **kwargs): 98 | Field.__init__(self, name, args=[Field.GEO], **kwargs) 99 | 100 | 101 | class TagField(Field): 102 | """ 103 | TagField is a tag-indexing field with simpler compression and tokenization. 104 | See http://redisearch.io/Tags/ 105 | """ 106 | 107 | SEPARATOR = "SEPARATOR" 108 | CASESENSITIVE = "CASESENSITIVE" 109 | 110 | def __init__( 111 | self, name: str, separator: str = ",", case_sensitive: bool = False, **kwargs 112 | ): 113 | args = [Field.TAG, self.SEPARATOR, separator] 114 | if case_sensitive: 115 | args.append(self.CASESENSITIVE) 116 | 117 | Field.__init__(self, name, args=args, **kwargs) 118 | 119 | 120 | class VectorField(Field): 121 | """ 122 | Allows vector similarity queries against the value in this attribute. 123 | See https://oss.redis.com/redisearch/Vectors/#vector_fields. 124 | """ 125 | 126 | def __init__(self, name: str, algorithm: str, attributes: dict, **kwargs): 127 | """ 128 | Create Vector Field. Notice that Vector cannot have sortable or no_index tag, 129 | although it's also a Field. 130 | 131 | ``name`` is the name of the field. 132 | 133 | ``algorithm`` can be "FLAT" or "HNSW". 134 | 135 | ``attributes`` each algorithm can have specific attributes. Some of them 136 | are mandatory and some of them are optional. See 137 | https://oss.redis.com/redisearch/master/Vectors/#specific_creation_attributes_per_algorithm 138 | for more information. 139 | """ 140 | sort = kwargs.get("sortable", False) 141 | noindex = kwargs.get("no_index", False) 142 | 143 | if sort or noindex: 144 | raise DataError("Cannot set 'sortable' or 'no_index' in Vector fields.") 145 | 146 | if algorithm.upper() not in ["FLAT", "HNSW"]: 147 | raise DataError( 148 | "Realtime vector indexing supporting 2 Indexing Methods:" 149 | "'FLAT' and 'HNSW'." 
150 | ) 151 | 152 | attr_li = [] 153 | 154 | for key, value in attributes.items(): 155 | attr_li.extend([key, value]) 156 | 157 | Field.__init__( 158 | self, name, args=[Field.VECTOR, algorithm, len(attr_li), *attr_li], **kwargs 159 | ) 160 | -------------------------------------------------------------------------------- /aiokeydb/v1/commands/json/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from json import JSONDecodeError, JSONDecoder, JSONEncoder 4 | 5 | import aiokeydb 6 | 7 | from aiokeydb.v1.commands.helpers import nativestr 8 | from aiokeydb.v1.commands.json.commands import JSONCommands 9 | from aiokeydb.v1.commands.json.decoders import bulk_of_jsons, decode_list 10 | 11 | 12 | class JSON(JSONCommands): 13 | """ 14 | Create a client for talking to json. 15 | 16 | :param decoder: 17 | :type json.JSONDecoder: An instance of json.JSONDecoder 18 | 19 | :param encoder: 20 | :type json.JSONEncoder: An instance of json.JSONEncoder 21 | """ 22 | 23 | def __init__( 24 | self, client, version=None, decoder=JSONDecoder(), encoder=JSONEncoder() 25 | ): 26 | """ 27 | Create a client for talking to json. 28 | 29 | :param decoder: 30 | :type json.JSONDecoder: An instance of json.JSONDecoder 31 | 32 | :param encoder: 33 | :type json.JSONEncoder: An instance of json.JSONEncoder 34 | """ 35 | # Set the module commands' callbacks 36 | self.MODULE_CALLBACKS = { 37 | "JSON.CLEAR": int, 38 | "JSON.DEL": int, 39 | "JSON.FORGET": int, 40 | "JSON.GET": self._decode, 41 | "JSON.MGET": bulk_of_jsons(self._decode), 42 | "JSON.SET": lambda r: r and nativestr(r) == "OK", 43 | "JSON.NUMINCRBY": self._decode, 44 | "JSON.NUMMULTBY": self._decode, 45 | "JSON.TOGGLE": self._decode, 46 | "JSON.STRAPPEND": self._decode, 47 | "JSON.STRLEN": self._decode, 48 | "JSON.ARRAPPEND": self._decode, 49 | "JSON.ARRINDEX": self._decode, 50 | "JSON.ARRINSERT": self._decode, 51 | "JSON.ARRLEN": self._decode, 52 | "JSON.ARRPOP": self._decode, 53 | "JSON.ARRTRIM": self._decode, 54 | "JSON.OBJLEN": self._decode, 55 | "JSON.OBJKEYS": self._decode, 56 | "JSON.RESP": self._decode, 57 | "JSON.DEBUG": self._decode, 58 | } 59 | 60 | self.client = client 61 | self.execute_command = client.execute_command 62 | self.MODULE_VERSION = version 63 | 64 | for key, value in self.MODULE_CALLBACKS.items(): 65 | self.client.set_response_callback(key, value) 66 | 67 | self.__encoder__ = encoder 68 | self.__decoder__ = decoder 69 | 70 | def _decode(self, obj): 71 | """Get the decoder.""" 72 | if obj is None: 73 | return obj 74 | 75 | try: 76 | x = self.__decoder__.decode(obj) 77 | if x is None: 78 | raise TypeError 79 | return x 80 | except TypeError: 81 | try: 82 | return self.__decoder__.decode(obj.decode()) 83 | except AttributeError: 84 | return decode_list(obj) 85 | except (AttributeError, JSONDecodeError): 86 | return decode_list(obj) 87 | 88 | def _encode(self, obj): 89 | """Get the encoder.""" 90 | return self.__encoder__.encode(obj) 91 | 92 | def pipeline(self, transaction=True, shard_hint=None): 93 | """Creates a pipeline for the JSON module, that can be used for executing 94 | JSON commands, as well as classic core commands. 
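A sketch of a search schema assembled from the field classes in `aiokeydb/v1/commands/search/field.py` above; the HNSW attribute names follow common RediSearch usage and are assumptions, not taken from this file.

```python
from aiokeydb.v1.commands.search.field import (
    NumericField,
    TagField,
    TextField,
    VectorField,
)

schema = (
    TextField("title", weight=5.0, sortable=True),
    TextField("body", no_stem=True),
    NumericField("price", sortable=True),
    TagField("categories", separator="|"),
    VectorField(
        "embedding",
        "HNSW",
        {"TYPE": "FLOAT32", "DIM": 768, "DISTANCE_METRIC": "COSINE"},
    ),
)

# Each field renders itself into the argument list used to create the index
assert TextField("title", weight=5.0, sortable=True).keydb_args() == [
    "title", "TEXT", "WEIGHT", 5.0, "SORTABLE",
]
```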
95 | 96 | Usage example: 97 | 98 | r = redis.Redis() 99 | pipe = r.json().pipeline() 100 | pipe.jsonset('foo', '.', {'hello!': 'world'}) 101 | pipe.jsonget('foo') 102 | pipe.jsonget('notakey') 103 | """ 104 | if isinstance(self.client, aiokeydb.v1.KeyDBCluster): 105 | p = ClusterPipeline( 106 | nodes_manager=self.client.nodes_manager, 107 | commands_parser=self.client.commands_parser, 108 | startup_nodes=self.client.nodes_manager.startup_nodes, 109 | result_callbacks=self.client.result_callbacks, 110 | cluster_response_callbacks=self.client.cluster_response_callbacks, 111 | cluster_error_retry_attempts=self.client.cluster_error_retry_attempts, 112 | read_from_replicas=self.client.read_from_replicas, 113 | reinitialize_steps=self.client.reinitialize_steps, 114 | lock=self.client._lock, 115 | ) 116 | 117 | else: 118 | p = Pipeline( 119 | connection_pool=self.client.connection_pool, 120 | response_callbacks=self.MODULE_CALLBACKS, 121 | transaction=transaction, 122 | shard_hint=shard_hint, 123 | ) 124 | 125 | p._encode = self._encode 126 | p._decode = self._decode 127 | return p 128 | 129 | 130 | class ClusterPipeline(JSONCommands, aiokeydb.v1.cluster.ClusterPipeline): 131 | """Cluster pipeline for the module.""" 132 | 133 | 134 | class Pipeline(JSONCommands, aiokeydb.v1.client.Pipeline): 135 | """Pipeline for the module.""" 136 | -------------------------------------------------------------------------------- /aiokeydb/utils/base.py: -------------------------------------------------------------------------------- 1 | 2 | import hashlib 3 | import inspect 4 | import logging 5 | from contextlib import contextmanager, asynccontextmanager 6 | from typing import Union, Optional, List, Callable, Generator, AsyncGenerator, TYPE_CHECKING 7 | from aiokeydb.types import ENOVAL 8 | 9 | try: 10 | import hiredis # noqa 11 | 12 | # Only support Hiredis >= 1.0: 13 | HIREDIS_AVAILABLE = not hiredis.__version__.startswith("0.") 14 | HIREDIS_PACK_AVAILABLE = hasattr(hiredis, "pack_command") 15 | except ImportError: 16 | HIREDIS_AVAILABLE = False 17 | HIREDIS_PACK_AVAILABLE = False 18 | 19 | try: 20 | import cryptography # noqa 21 | 22 | CRYPTOGRAPHY_AVAILABLE = True 23 | except ImportError: 24 | CRYPTOGRAPHY_AVAILABLE = False 25 | 26 | from redis.utils import ( 27 | str_if_bytes, 28 | safe_str, 29 | dict_merge, 30 | list_keys_to_dict, 31 | merge_result, 32 | ) 33 | 34 | logger = logging.getLogger(__name__) 35 | 36 | if TYPE_CHECKING: 37 | from aiokeydb.client import KeyDB, AsyncKeyDB, Pipeline, AsyncPipeline 38 | 39 | def from_url(url, asyncio: bool = False, _is_async: Optional[bool] = None, **kwargs) -> Union["KeyDB", "AsyncKeyDB"]: 40 | """ 41 | Returns an active Redis client generated from the given database URL. 42 | 43 | Will attempt to extract the database id from the path url fragment, if 44 | none is provided. 
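A small sketch of the `from_url` helper from `aiokeydb/utils/base.py` above; the URI is a placeholder, and the `keydb://` scheme mirrors the one used in `tests/test_suite.py`.

```python
from aiokeydb.utils.base import from_url

# sync client
kdb = from_url("keydb://localhost:6379/0")
# async client, selected by the `asyncio` flag shown above
akdb = from_url("keydb://localhost:6379/0", asyncio=True)
```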
45 | """ 46 | _is_async = asyncio if _is_async is None else _is_async 47 | 48 | if _is_async: 49 | from aiokeydb.client import AsyncKeyDB 50 | return AsyncKeyDB.from_url(url, **kwargs) 51 | 52 | from aiokeydb.client import KeyDB 53 | return KeyDB.from_url(url, **kwargs) 54 | 55 | 56 | 57 | @contextmanager 58 | def pipeline(keydb_obj: "KeyDB") -> Generator["Pipeline", None, None]: 59 | p = keydb_obj.pipeline() 60 | yield p 61 | p.execute() 62 | del p 63 | 64 | @asynccontextmanager 65 | async def async_pipeline(keydb_obj: "AsyncKeyDB") -> AsyncGenerator["AsyncPipeline", None]: 66 | p = keydb_obj.pipeline() 67 | yield p 68 | await p.execute() 69 | del p 70 | 71 | def get_ulimits(): 72 | import resource 73 | soft_limit, _ = resource.getrlimit(resource.RLIMIT_NOFILE) 74 | return soft_limit 75 | 76 | def set_ulimits( 77 | max_connections: int = 500, 78 | verbose: bool = False, 79 | ): 80 | """ 81 | Sets the system ulimits 82 | to allow for the maximum number of open connections 83 | 84 | - if the current ulimit > max_connections, then it is ignored 85 | - if it is less, then we set it. 86 | """ 87 | import resource 88 | 89 | soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE) 90 | if soft_limit > max_connections: return 91 | if hard_limit < max_connections and verbose: 92 | logger.warning(f"The current hard limit ({hard_limit}) is less than max_connections ({max_connections}).") 93 | new_hard_limit = max(hard_limit, max_connections) 94 | if verbose: logger.info(f"Setting new ulimits to ({soft_limit}, {hard_limit}) -> ({max_connections}, {new_hard_limit})") 95 | resource.setrlimit(resource.RLIMIT_NOFILE, (max_connections + 10, new_hard_limit)) 96 | new_soft, new_hard = resource.getrlimit(resource.RLIMIT_NOFILE) 97 | if verbose: logger.info(f"New Limits: ({new_soft}, {new_hard})") 98 | 99 | 100 | def full_name(func, follow_wrapper_chains=True): 101 | """ 102 | Return full name of `func` by adding the module and function name. 103 | 104 | If this function is decorated, attempt to unwrap it till the original function to use that 105 | function name by setting `follow_wrapper_chains` to True. 106 | """ 107 | if follow_wrapper_chains: func = inspect.unwrap(func) 108 | return f'{func.__module__}.{func.__qualname__}' 109 | 110 | 111 | def args_to_key( 112 | base, 113 | args: Optional[tuple] = None, 114 | kwargs: Optional[dict] = None, 115 | typed: bool = False, 116 | exclude: Optional[List[str]] = None 117 | ): 118 | """Create cache key out of function arguments. 
119 | :param tuple base: base of key 120 | :param tuple args: function arguments 121 | :param dict kwargs: function keyword arguments 122 | :param bool typed: include types in cache key 123 | :return: cache key tuple 124 | """ 125 | key = base + args 126 | 127 | if kwargs: 128 | if exclude: kwargs = {k: v for k, v in kwargs.items() if k not in exclude} 129 | key += (ENOVAL,) 130 | sorted_items = sorted(kwargs.items()) 131 | 132 | for item in sorted_items: 133 | key += item 134 | 135 | if typed: 136 | key += tuple(type(arg) for arg in args) 137 | if kwargs: key += tuple(type(value) for _, value in sorted_items) 138 | 139 | cache_key = ':'.join(str(k) for k in key) 140 | return hashlib.md5(cache_key.encode()).hexdigest() 141 | 142 | 143 | def import_string(func: str) -> Callable: 144 | """Import a function from a string.""" 145 | module, func = func.rsplit('.', 1) 146 | return getattr(__import__(module, fromlist=[func]), func) 147 | 148 | -------------------------------------------------------------------------------- /aiokeydb/v2/utils/base.py: -------------------------------------------------------------------------------- 1 | 2 | import hashlib 3 | import inspect 4 | import logging 5 | from contextlib import contextmanager, asynccontextmanager 6 | from typing import Union, Optional, List, Callable, Generator, AsyncGenerator, TYPE_CHECKING 7 | from aiokeydb.v2.types import ENOVAL 8 | 9 | try: 10 | import hiredis # noqa 11 | 12 | # Only support Hiredis >= 1.0: 13 | HIREDIS_AVAILABLE = not hiredis.__version__.startswith("0.") 14 | HIREDIS_PACK_AVAILABLE = hasattr(hiredis, "pack_command") 15 | except ImportError: 16 | HIREDIS_AVAILABLE = False 17 | HIREDIS_PACK_AVAILABLE = False 18 | 19 | try: 20 | import cryptography # noqa 21 | 22 | CRYPTOGRAPHY_AVAILABLE = True 23 | except ImportError: 24 | CRYPTOGRAPHY_AVAILABLE = False 25 | 26 | from redis.utils import ( 27 | str_if_bytes, 28 | safe_str, 29 | dict_merge, 30 | list_keys_to_dict, 31 | merge_result, 32 | ) 33 | 34 | logger = logging.getLogger(__name__) 35 | 36 | if TYPE_CHECKING: 37 | from aiokeydb.client import KeyDB, AsyncKeyDB, Pipeline, AsyncPipeline 38 | 39 | def from_url(url, asyncio: bool = False, _is_async: Optional[bool] = None, **kwargs) -> Union["KeyDB", "AsyncKeyDB"]: 40 | """ 41 | Returns an active Redis client generated from the given database URL. 42 | 43 | Will attempt to extract the database id from the path url fragment, if 44 | none is provided. 45 | """ 46 | _is_async = asyncio if _is_async is None else _is_async 47 | 48 | if _is_async: 49 | from aiokeydb.client import AsyncKeyDB 50 | return AsyncKeyDB.from_url(url, **kwargs) 51 | 52 | from aiokeydb.client import KeyDB 53 | return KeyDB.from_url(url, **kwargs) 54 | 55 | 56 | 57 | @contextmanager 58 | def pipeline(keydb_obj: "KeyDB") -> Generator["Pipeline", None, None]: 59 | p = keydb_obj.pipeline() 60 | yield p 61 | p.execute() 62 | del p 63 | 64 | @asynccontextmanager 65 | async def async_pipeline(keydb_obj: "AsyncKeyDB") -> AsyncGenerator["AsyncPipeline", None]: 66 | p = keydb_obj.pipeline() 67 | yield p 68 | await p.execute() 69 | del p 70 | 71 | def get_ulimits(): 72 | import resource 73 | soft_limit, _ = resource.getrlimit(resource.RLIMIT_NOFILE) 74 | return soft_limit 75 | 76 | def set_ulimits( 77 | max_connections: int = 500, 78 | verbose: bool = False, 79 | ): 80 | """ 81 | Sets the system ulimits 82 | to allow for the maximum number of open connections 83 | 84 | - if the current ulimit > max_connections, then it is ignored 85 | - if it is less, then we set it. 
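How `args_to_key` above derives cache keys: the base tuple, positional args, and sorted keyword args are flattened into a string and hashed with md5, so identical call arguments always map to the same key. The names below are placeholders.

```python
from aiokeydb.utils.base import args_to_key

k1 = args_to_key(("myapp.fib",), args=(10,), kwargs={"retries": 3})
k2 = args_to_key(("myapp.fib",), args=(10,), kwargs={"retries": 3})
k3 = args_to_key(("myapp.fib",), args=(11,), kwargs={"retries": 3})

assert k1 == k2              # same arguments -> same key
assert k1 != k3              # different arguments -> different key
assert len(k1) == 32         # md5 hex digest
```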
86 | """ 87 | import resource 88 | 89 | soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE) 90 | if soft_limit > max_connections: return 91 | if hard_limit < max_connections and verbose: 92 | logger.warning(f"The current hard limit ({hard_limit}) is less than max_connections ({max_connections}).") 93 | new_hard_limit = max(hard_limit, max_connections) 94 | if verbose: logger.info(f"Setting new ulimits to ({soft_limit}, {hard_limit}) -> ({max_connections}, {new_hard_limit})") 95 | resource.setrlimit(resource.RLIMIT_NOFILE, (max_connections + 10, new_hard_limit)) 96 | new_soft, new_hard = resource.getrlimit(resource.RLIMIT_NOFILE) 97 | if verbose: logger.info(f"New Limits: ({new_soft}, {new_hard})") 98 | 99 | 100 | def full_name(func, follow_wrapper_chains=True): 101 | """ 102 | Return full name of `func` by adding the module and function name. 103 | 104 | If this function is decorated, attempt to unwrap it till the original function to use that 105 | function name by setting `follow_wrapper_chains` to True. 106 | """ 107 | if follow_wrapper_chains: func = inspect.unwrap(func) 108 | return f'{func.__module__}.{func.__qualname__}' 109 | 110 | 111 | def args_to_key( 112 | base, 113 | args: Optional[tuple] = None, 114 | kwargs: Optional[dict] = None, 115 | typed: bool = False, 116 | exclude: Optional[List[str]] = None 117 | ): 118 | """Create cache key out of function arguments. 119 | :param tuple base: base of key 120 | :param tuple args: function arguments 121 | :param dict kwargs: function keyword arguments 122 | :param bool typed: include types in cache key 123 | :return: cache key tuple 124 | """ 125 | key = base + args 126 | 127 | if kwargs: 128 | if exclude: kwargs = {k: v for k, v in kwargs.items() if k not in exclude} 129 | key += (ENOVAL,) 130 | sorted_items = sorted(kwargs.items()) 131 | 132 | for item in sorted_items: 133 | key += item 134 | 135 | if typed: 136 | key += tuple(type(arg) for arg in args) 137 | if kwargs: key += tuple(type(value) for _, value in sorted_items) 138 | 139 | cache_key = ':'.join(str(k) for k in key) 140 | return hashlib.md5(cache_key.encode()).hexdigest() 141 | 142 | 143 | def import_string(func: str) -> Callable: 144 | """Import a function from a string.""" 145 | module, func = func.rsplit('.', 1) 146 | return getattr(__import__(module, fromlist=[func]), func) 147 | 148 | -------------------------------------------------------------------------------- /tests/test_suite.py: -------------------------------------------------------------------------------- 1 | 2 | import time 3 | import asyncio 4 | import uuid 5 | from aiokeydb import KeyDBClient 6 | 7 | # The session can be explicitly initialized, or 8 | # will be lazily initialized on first use 9 | # through environment variables with all 10 | # params being prefixed with `KEYDB_` 11 | 12 | keydb_uri = "keydb://localhost:6379/0" 13 | 14 | # Initialize the Unified Client 15 | KeyDBClient.init_session( 16 | uri = keydb_uri, 17 | ) 18 | 19 | # Cache the results of these functions 20 | # cachify works for both sync and async functions 21 | # and has many params to customize the caching behavior 22 | # and supports both `redis` and `keydb` backends 23 | # as well as `api` frameworks such as `fastapi` and `starlette` 24 | 25 | @KeyDBClient.cachify() 26 | async def async_fibonacci(number: int): 27 | if number == 0: return 0 28 | elif number == 1: return 1 29 | return await async_fibonacci(number - 1) + await async_fibonacci(number - 2) 30 | 31 | @KeyDBClient.cachify() 32 | def fibonacci(number: 
int): 33 | if number == 0: return 0 34 | elif number == 1: return 1 35 | return fibonacci(number - 1) + fibonacci(number - 2) 36 | 37 | async def test_fib(n: int = 100, runs: int = 10): 38 | # Test that both results are the same. 39 | sync_t, async_t = 0.0, 0.0 40 | 41 | for i in range(runs): 42 | t = time.time() 43 | print(f'[Async - {i}/{runs}] Result: {await async_fibonacci(n)}') 44 | tt = time.time() - t 45 | print(f'[Async - {i}/{runs}] Time: {tt:.2f}s') 46 | async_t += tt 47 | 48 | t = time.time() 49 | print(f'[Sync - {i}/{runs}] Result: {fibonacci(n)}') 50 | tt = time.time() - t 51 | print(f'[Sync - {i}/{runs}] Time: {tt:.2f}s') 52 | sync_t += tt 53 | 54 | print(f'[Async] Cache Average Time: {async_t / runs:.2f}s | Total Time: {async_t:.2f}s') 55 | print(f'[Sync] Cache Average Time: {sync_t / runs:.2f}s | Total Time: {sync_t:.2f}s') 56 | 57 | async def test_setget(runs: int = 10): 58 | # By default, the client utilizes `pickle` to serialize 59 | # and deserialize objects. This can be changed by setting 60 | # the `serializer` 61 | 62 | sync_t, async_t = 0.0, 0.0 63 | for i in range(runs): 64 | value = str(uuid.uuid4()) 65 | key = f'async-test-{i}' 66 | t = time.time() 67 | await KeyDBClient.async_set(key, value) 68 | assert await KeyDBClient.async_get(key) == value 69 | tt = time.time() - t 70 | print(f'[Async - {i}/{runs}] Get/Set: {key} -> {value} = {tt:.2f}s') 71 | async_t += tt 72 | 73 | value = str(uuid.uuid4()) 74 | key = f'sync-test-{i}' 75 | t = time.time() 76 | KeyDBClient.set(key, value) 77 | assert KeyDBClient.get(key) == value 78 | tt = time.time() - t 79 | print(f'[Sync - {i}/{runs}] Get/Set: {key} -> {value} = {tt:.2f}s') 80 | sync_t += tt 81 | 82 | print(f'[Async] GetSet Average Time: {async_t / runs:.2f}s | Total Time: {async_t:.2f}s') 83 | print(f'[Sync] GetSet Average Time: {sync_t / runs:.2f}s | Total Time: {sync_t:.2f}s') 84 | 85 | 86 | async def run_tests(fib_n: int = 100, fib_runs: int = 10, setget_runs: int = 10): 87 | 88 | # You can explicitly wait for the client to be ready 89 | # Sync version 90 | # KeyDBClient.wait_for_ready() 91 | await KeyDBClient.async_wait_for_ready() 92 | 93 | # Run the tests 94 | await test_fib(n = fib_n, runs = fib_runs) 95 | await test_setget(runs = setget_runs) 96 | 97 | 98 | # Utilize the current session 99 | await KeyDBClient.async_set('async_test_0', 'test') 100 | assert await KeyDBClient.async_get('async_test_0') == 'test' 101 | 102 | KeyDBClient.set('sync_test_0', 'test') 103 | assert KeyDBClient.get('sync_test_0') == 'test' 104 | 105 | 106 | # you can access the `KeyDBSession` object directly 107 | # which mirrors the APIs in `KeyDBClient` 108 | 109 | await KeyDBClient.session.async_set('async_test_1', 'test') 110 | assert await KeyDBClient.session.async_get('async_test_1') == 'test' 111 | 112 | KeyDBClient.session.set('sync_test_1', 'test') 113 | assert KeyDBClient.session.get('sync_test_1') == 'test' 114 | 115 | # The underlying client can be accessed directly 116 | # if the desired api methods aren't mirrored 117 | 118 | # KeyDBClient.keydb 119 | # KeyDBClient.async_keydb 120 | # Since encoding / decoding is not handled by the client 121 | # you must encode / decode the data yourself 122 | await KeyDBClient.async_keydb.set('async_test_2', b'test') 123 | assert await KeyDBClient.async_keydb.get('async_test_2') == b'test' 124 | 125 | KeyDBClient.keydb.set('sync_test_2', b'test') 126 | assert KeyDBClient.keydb.get('sync_test_2') == b'test' 127 | 128 | # You can also explicitly close the client 129 | # However, this closes 
the connectionpool and will terminate 130 | # all connections. This is not recommended unless you are 131 | # explicitly closing the client. 132 | 133 | # Sync version 134 | # KeyDBClient.close() 135 | await KeyDBClient.aclose() 136 | 137 | asyncio.run(run_tests()) -------------------------------------------------------------------------------- /aiokeydb/v2/utils/helpers.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import inspect 4 | import logging 5 | from tenacity import retry, wait_exponential, stop_after_delay, before_sleep_log, retry_unless_exception_type, retry_if_exception_type, retry_if_exception 6 | from typing import Optional, Union, Tuple, Type, TYPE_CHECKING 7 | 8 | if TYPE_CHECKING: 9 | from aiokeydb.v2.core import KeyDB, AsyncKeyDB 10 | from tenacity import WrappedFn 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | _excluded_funcs = ['parse_response', 'ping'] 16 | 17 | class retry_if_type(retry_if_exception): 18 | """ 19 | Retries if the exception is of the given type 20 | """ 21 | def __init__( 22 | self, 23 | exception_types: Union[ 24 | Type[BaseException], 25 | Tuple[Type[BaseException], ...], 26 | ] = Exception, 27 | excluded_types: Union[ 28 | Type[BaseException], 29 | Tuple[Type[BaseException], ...], 30 | ] = None, 31 | ): 32 | self.exception_types = exception_types 33 | self.excluded_types = excluded_types 34 | 35 | super().__init__( 36 | lambda e: self.validate_exception(e) 37 | ) 38 | # super().__init__( 39 | # lambda e: isinstance(e, exception_types) and not isinstance(e, excluded_types) 40 | # ) 41 | 42 | def validate_exception(self, e: BaseException) -> bool: 43 | # Exclude ping by default 44 | if e.args and e.args[0] == 'PING': 45 | print('EXCLUDED PING') 46 | return False 47 | return isinstance(e, self.exception_types) and not isinstance(e, self.excluded_types) 48 | 49 | def get_retryable_wrapper( 50 | max_attempts: int = 15, 51 | max_delay: int = 60, 52 | logging_level: int = logging.DEBUG, 53 | **kwargs, 54 | ) -> 'WrappedFn': 55 | 56 | """ 57 | Creates a retryable decorator 58 | """ 59 | from aiokeydb import exceptions as aiokeydb_exceptions 60 | from redis import exceptions as redis_exceptions 61 | return retry( 62 | wait = wait_exponential(multiplier = 0.5, min = 1, max = max_attempts), 63 | stop = stop_after_delay(max_delay), 64 | before_sleep = before_sleep_log(logger, logging_level), 65 | retry = retry_if_type( 66 | exception_types = ( 67 | aiokeydb_exceptions.ConnectionError, 68 | aiokeydb_exceptions.TimeoutError, 69 | aiokeydb_exceptions.BusyLoadingError, 70 | redis_exceptions.ConnectionError, 71 | redis_exceptions.TimeoutError, 72 | redis_exceptions.BusyLoadingError, 73 | ), 74 | excluded_types = ( 75 | aiokeydb_exceptions.AuthenticationError, 76 | aiokeydb_exceptions.AuthorizationError, 77 | aiokeydb_exceptions.InvalidResponse, 78 | aiokeydb_exceptions.ResponseError, 79 | aiokeydb_exceptions.NoScriptError, 80 | redis_exceptions.AuthenticationError, 81 | redis_exceptions.AuthorizationError, 82 | redis_exceptions.InvalidResponse, 83 | redis_exceptions.ResponseError, 84 | redis_exceptions.NoScriptError, 85 | ) 86 | ) 87 | ) 88 | 89 | 90 | def create_retryable_client( 91 | client: Type[Union[KeyDB, AsyncKeyDB]], 92 | 93 | max_attempts: int = 15, 94 | max_delay: int = 60, 95 | logging_level: int = logging.DEBUG, 96 | 97 | verbose: Optional[bool] = False, 98 | **kwargs 99 | ) -> Type[Union[KeyDB, AsyncKeyDB]]: 100 | """ 101 | Creates a retryable client 102 
| """ 103 | if hasattr(client, '_is_retryable_wrapped'): return client 104 | decorator = get_retryable_wrapper( 105 | max_attempts = max_attempts, 106 | max_delay = max_delay, 107 | logging_level = logging_level, 108 | **kwargs 109 | ) 110 | for attr in dir(client): 111 | if attr.startswith('_'): continue 112 | if attr in _excluded_funcs: continue 113 | attr_val = getattr(client, attr) 114 | if inspect.isfunction(attr_val) or inspect.iscoroutinefunction(attr_val): 115 | if verbose: logger.info(f'Wrapping {attr} with retryable decorator') 116 | setattr(client, attr, decorator(attr_val)) 117 | 118 | setattr(client, '_is_retryable_wrapped', True) 119 | return client 120 | 121 | 122 | 123 | 124 | 125 | # def wrap_retryable( 126 | # sess: 'KeyDBSession' 127 | # ): 128 | # """ 129 | # Wraps the keydb-session with a retryable decorator 130 | # """ 131 | # if hasattr(sess, '_is_retryable_wrapped'): return sess 132 | # decorator = retry( 133 | # wait = wait_exponential(multiplier=0.5, min=1, max=15), 134 | # stop = stop_after_delay(60), 135 | # before_sleep = before_sleep_log(logger, logging.INFO) 136 | # ) 137 | # for attr in dir(sess): 138 | # attr_val = getattr(sess, attr) 139 | # if inspect.isfunction(attr_val) or inspect.iscoroutinefunction(attr_val): 140 | # logger.info(f'Wrapping {attr} with retryable decorator') 141 | # setattr(sess, attr, decorator(attr_val)) 142 | 143 | # setattr(sess, '_is_retryable_wrapped', True) 144 | # return sess 145 | -------------------------------------------------------------------------------- /aiokeydb/v1/exceptions.py: -------------------------------------------------------------------------------- 1 | "Core exceptions raised by the KeyDB client" 2 | import asyncio 3 | import builtins 4 | 5 | 6 | class KeyDBError(Exception): 7 | pass 8 | 9 | 10 | class ConnectionError(KeyDBError): 11 | pass 12 | 13 | 14 | class TimeoutError(KeyDBError): 15 | pass 16 | 17 | 18 | class AuthenticationError(ConnectionError): 19 | pass 20 | 21 | 22 | class AuthorizationError(ConnectionError): 23 | pass 24 | 25 | 26 | class BusyLoadingError(ConnectionError): 27 | pass 28 | 29 | 30 | class InvalidResponse(KeyDBError): 31 | pass 32 | 33 | 34 | class ResponseError(KeyDBError): 35 | pass 36 | 37 | 38 | class DataError(KeyDBError): 39 | pass 40 | 41 | 42 | class PubSubError(KeyDBError): 43 | pass 44 | 45 | 46 | class WatchError(KeyDBError): 47 | pass 48 | 49 | 50 | class NoScriptError(ResponseError): 51 | pass 52 | 53 | 54 | class ExecAbortError(ResponseError): 55 | pass 56 | 57 | 58 | class ReadOnlyError(ResponseError): 59 | pass 60 | 61 | 62 | class NoPermissionError(ResponseError): 63 | pass 64 | 65 | 66 | class ModuleError(ResponseError): 67 | pass 68 | 69 | 70 | class LockError(KeyDBError, ValueError): 71 | "Errors acquiring or releasing a lock" 72 | # NOTE: For backwards compatibility, this class derives from ValueError. 73 | # This was originally chosen to behave like threading.Lock. 
74 | pass 75 | 76 | 77 | class LockNotOwnedError(LockError): 78 | "Error trying to extend or release a lock that is (no longer) owned" 79 | pass 80 | 81 | 82 | class ChildDeadlockedError(Exception): 83 | "Error indicating that a child process is deadlocked after a fork()" 84 | pass 85 | 86 | 87 | class AuthenticationWrongNumberOfArgsError(ResponseError): 88 | """ 89 | An error to indicate that the wrong number of args 90 | were sent to the AUTH command 91 | """ 92 | 93 | pass 94 | 95 | 96 | class KeyDBClusterException(Exception): 97 | """ 98 | Base exception for the KeyDBCluster client 99 | """ 100 | 101 | pass 102 | 103 | 104 | class ClusterError(KeyDBError): 105 | """ 106 | Cluster errors occurred multiple times, resulting in an exhaustion of the 107 | command execution TTL 108 | """ 109 | 110 | pass 111 | 112 | 113 | class ClusterDownError(ClusterError, ResponseError): 114 | """ 115 | Error indicated CLUSTERDOWN error received from cluster. 116 | By default KeyDB Cluster nodes stop accepting queries if they detect there 117 | is at least a hash slot uncovered (no available node is serving it). 118 | This way if the cluster is partially down (for example a range of hash 119 | slots are no longer covered) the entire cluster eventually becomes 120 | unavailable. It automatically returns available as soon as all the slots 121 | are covered again. 122 | """ 123 | 124 | def __init__(self, resp): 125 | self.args = (resp,) 126 | self.message = resp 127 | 128 | 129 | class AskError(ResponseError): 130 | """ 131 | Error indicated ASK error received from cluster. 132 | When a slot is set as MIGRATING, the node will accept all queries that 133 | pertain to this hash slot, but only if the key in question exists, 134 | otherwise the query is forwarded using a -ASK redirection to the node that 135 | is target of the migration. 136 | src node: MIGRATING to dst node 137 | get > ASK error 138 | ask dst node > ASKING command 139 | dst node: IMPORTING from src node 140 | asking command only affects next command 141 | any op will be allowed after asking command 142 | """ 143 | 144 | def __init__(self, resp): 145 | """should only redirect to master node""" 146 | self.args = (resp,) 147 | self.message = resp 148 | slot_id, new_node = resp.split(" ") 149 | host, port = new_node.rsplit(":", 1) 150 | self.slot_id = int(slot_id) 151 | self.node_addr = self.host, self.port = host, int(port) 152 | 153 | 154 | class TryAgainError(ResponseError): 155 | """ 156 | Error indicated TRYAGAIN error received from cluster. 157 | Operations on keys that don't exist or are - during resharding - split 158 | between the source and destination nodes, will generate a -TRYAGAIN error. 159 | """ 160 | 161 | def __init__(self, *args, **kwargs): 162 | pass 163 | 164 | 165 | class ClusterCrossSlotError(ResponseError): 166 | """ 167 | Error indicated CROSSSLOT error received from cluster. 168 | A CROSSSLOT error is generated when keys in a request don't hash to the 169 | same slot. 170 | """ 171 | 172 | message = "Keys in request don't hash to the same slot" 173 | 174 | 175 | class MovedError(AskError): 176 | """ 177 | Error indicated MOVED error received from cluster. 178 | A request sent to a node that doesn't serve this key will be replayed with 179 | a MOVED error that points to the correct node. 180 | """ 181 | 182 | pass 183 | 184 | 185 | class MasterDownError(ClusterDownError): 186 | """ 187 | Error indicated MASTERDOWN error received from cluster. 188 | Link with MASTER is down and replica-serve-stale-data is set to 'no'. 
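How the cluster redirection errors above carry their target: the server's response text is split into a slot id and a `(host, port)` pair. The values below are made up.

```python
from aiokeydb.v1.exceptions import AskError, MovedError

err = AskError("3999 10.0.0.12:6381")
assert err.slot_id == 3999
assert err.node_addr == ("10.0.0.12", 6381)

# MovedError reuses the same parsing
moved = MovedError("12182 10.0.0.99:6380")
assert (moved.host, moved.port) == ("10.0.0.99", 6380)
```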
189 | """ 190 | 191 | pass 192 | 193 | 194 | class SlotNotCoveredError(KeyDBClusterException): 195 | """ 196 | This error only happens in the case where the connection pool will try to 197 | fetch what node that is covered by a given slot. 198 | 199 | If this error is raised the client should drop the current node layout and 200 | attempt to reconnect and refresh the node layout again 201 | """ 202 | 203 | pass 204 | 205 | 206 | class MaxConnectionsError(ConnectionError): 207 | ... 208 | --------------------------------------------------------------------------------