├── requirements.txt ├── tests ├── __init__.py ├── registry_tests.py ├── memory_tests.py ├── redis_tests.py └── ttldict_tests.py ├── fastapi_cache ├── backends │ ├── __init__.py │ ├── utils │ │ ├── __init__.py │ │ └── ttldict.py │ ├── base.py │ ├── memory.py │ └── redis.py ├── version.py ├── __init__.py └── registry.py ├── requirements-test.txt ├── pytest.ini ├── tox.ini ├── .travis.yml ├── setup.py ├── LICENSE ├── README.md └── .gitignore /requirements.txt: -------------------------------------------------------------------------------- 1 | . -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fastapi_cache/backends/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fastapi_cache/backends/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | pytest==6.2.1 2 | pytest-cov==2.10.0 3 | pytest-asyncio==0.14.0 4 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | python_files = tests.py test_*.py *_tests.py *_test.py 3 | addopts = --doctest-modules -------------------------------------------------------------------------------- /fastapi_cache/version.py: -------------------------------------------------------------------------------- 1 | VERSION = (0, 1, 0) 2 | 3 | __author__ = 'Ivan Sushkov ' 4 | __version__ = '.'.join(str(x) for x in VERSION) 5 | 
-------------------------------------------------------------------------------- /fastapi_cache/__init__.py: -------------------------------------------------------------------------------- 1 | from .registry import CacheRegistry 2 | 3 | caches = CacheRegistry 4 | 5 | 6 | async def close_caches(): 7 | for cache in caches.all(): 8 | await cache.close() 9 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py36,py37,py38 3 | 4 | [testenv] 5 | deps = 6 | -r requirements.txt 7 | -r requirements-test.txt 8 | 9 | commands = pytest \ 10 | --cov=fastapi_cache/ \ 11 | --cov-config="{toxinidir}/tox.ini" \ 12 | --cov-append -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | python: 4 | - "3.6" 5 | - "3.7" 6 | - "3.8" 7 | 8 | dist: xenial 9 | 10 | services: 11 | - redis-server 12 | 13 | install: 14 | - pip install tox-travis 15 | 16 | script: 17 | - tox 18 | 19 | notifications: 20 | email: 21 | on_success: never 22 | on_failure: never -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, join 2 | 3 | from setuptools import find_packages, setup 4 | 5 | from fastapi_cache.version import __author__, __version__ 6 | 7 | setup( 8 | name='fastapi-cache', 9 | version=__version__, 10 | description='FastAPI simple cache', 11 | author=__author__, 12 | url='https://github.com/comeuplater/fastapi_cache', 13 | packages=find_packages(exclude=('tests',)), 14 | long_description=open(join(dirname(__file__), 'README.md')).read(), 15 | long_description_content_type='text/markdown', 16 | license='MIT License', 17 | keywords=[ 18 | 'redis', 'aioredis', 
'asyncio', 'fastapi', 'starlette', 'cache' 19 | ], 20 | install_requires=[ 21 | 'aioredis==1.3.1', 22 | ], 23 | ) 24 | -------------------------------------------------------------------------------- /fastapi_cache/registry.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple, Union 2 | 3 | from .backends.base import BaseCacheBackend 4 | 5 | 6 | class CacheRegistry: 7 | _caches: Dict[str, BaseCacheBackend] = {} 8 | 9 | @classmethod 10 | def get(cls, name: str) -> Union[BaseCacheBackend, None]: 11 | return cls._caches.get(name, None) 12 | 13 | @classmethod 14 | def set(cls, name: str, cache: BaseCacheBackend) -> None: 15 | if name in cls._caches: 16 | raise NameError('Cache with the same name already registered') 17 | 18 | cls._caches[name] = cache 19 | 20 | @classmethod 21 | def all(cls) -> Tuple[BaseCacheBackend]: 22 | return tuple(cls._caches.values()) 23 | 24 | @classmethod 25 | def remove(cls, name: str) -> None: 26 | if name not in cls._caches: 27 | raise NameError('Cache with the same name not registered') 28 | 29 | del cls._caches[name] 30 | 31 | @classmethod 32 | def flush(cls) -> None: 33 | cls._caches = {} 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Ivan Sushkov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or 
substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /fastapi_cache/backends/base.py: -------------------------------------------------------------------------------- 1 | from typing import TypeVar, Generic, Tuple 2 | 3 | KT = TypeVar('KT') 4 | VT = TypeVar('VT') 5 | 6 | 7 | class BaseCacheBackend(Generic[KT, VT]): 8 | async def add( 9 | self, 10 | key: KT, 11 | value: VT, 12 | **kwargs 13 | ) -> bool: 14 | raise NotImplementedError 15 | 16 | async def get( 17 | self, 18 | key: KT, 19 | default: VT = None, 20 | **kwargs 21 | ) -> VT: 22 | raise NotImplementedError 23 | 24 | async def set( 25 | self, 26 | key: KT, 27 | value: VT, 28 | **kwargs 29 | ) -> bool: 30 | raise NotImplementedError 31 | 32 | async def expire( 33 | self, 34 | key: KT, 35 | ttl: int 36 | ) -> bool: 37 | raise NotImplementedError 38 | 39 | async def exists(self, *keys: KT) -> bool: 40 | raise NotImplementedError 41 | 42 | async def delete(self, key: KT) -> bool: 43 | raise NotImplementedError 44 | 45 | async def flush(self) -> None: 46 | raise NotImplementedError 47 | 48 | async def close(self) -> None: 49 | raise NotImplementedError 50 | -------------------------------------------------------------------------------- /fastapi_cache/backends/memory.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Hashable 2 | 3 | from .base import BaseCacheBackend 4 | from .utils.ttldict 
import TTLDict 5 | 6 | CACHE_KEY = 'IN_MEMORY' 7 | 8 | 9 | class InMemoryCacheBackend(BaseCacheBackend[Hashable, Any]): 10 | def __init__(self): 11 | self._cache = TTLDict() 12 | 13 | async def add( 14 | self, 15 | key: Hashable, 16 | value: Any, 17 | **kwargs, 18 | ) -> bool: 19 | return self._cache.add(key, value, **kwargs) 20 | 21 | async def get( 22 | self, 23 | key: Hashable, 24 | default: Any = None, 25 | **kwargs 26 | ) -> Any: 27 | return self._cache.get(key, default) 28 | 29 | async def set( 30 | self, 31 | key: Hashable, 32 | value: Any, 33 | **kwargs, 34 | ) -> bool: 35 | return self._cache.set(key, value, **kwargs) 36 | 37 | async def expire( 38 | self, 39 | key: Hashable, 40 | ttl: int 41 | ) -> bool: 42 | return self._cache.expire(key, ttl) 43 | 44 | async def exists(self, *keys: Hashable) -> bool: 45 | return self._cache.exists(*keys) 46 | 47 | async def delete(self, key: Hashable) -> bool: 48 | return self._cache.delete(key) 49 | 50 | async def flush(self) -> None: 51 | return self._cache.flush() 52 | 53 | async def close(self) -> None: # pragma: no cover 54 | return None 55 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FastAPI Cache 2 | 3 | [![Codacy Badge](https://api.codacy.com/project/badge/Grade/2ec5c44e899943c8920d3c3e31616784)](https://app.codacy.com/manual/ivan.sushkov/fastapi_cache?utm_source=github.com&utm_medium=referral&utm_content=comeuplater/fastapi_cache&utm_campaign=Badge_Grade_Dashboard) 4 | [![License](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT) 5 | [![PyPi Version](https://img.shields.io/pypi/v/fastapi-cache.svg)](https://pypi.python.org/pypi/fastapi-cache/) 6 | [![Downloads](https://pepy.tech/badge/fastapi-cache)](https://pepy.tech/project/fastapi-cache) 7 | [![Build 
Status](https://travis-ci.com/comeuplater/fastapi_cache.svg?branch=master)](https://travis-ci.com/comeuplater/fastapi_cache) 8 | 9 | Implements a simple, lightweight cache system exposed as FastAPI dependencies. 10 | 11 | ## Installation 12 | 13 | ```sh 14 | pip install fastapi-cache 15 | ``` 16 | 17 | ## Usage example 18 | ```python 19 | from fastapi import Depends, FastAPI 20 | 21 | from fastapi_cache import caches, close_caches 22 | from fastapi_cache.backends.redis import CACHE_KEY, RedisCacheBackend 23 | 24 | app = FastAPI() 25 | 26 | 27 | def redis_cache(): 28 | return caches.get(CACHE_KEY) 29 | 30 | 31 | @app.get('/') 32 | async def hello( 33 | cache: RedisCacheBackend = Depends(redis_cache) 34 | ): 35 | in_cache = await cache.get('some_cached_key') 36 | if not in_cache: 37 | await cache.set('some_cached_key', 'new_value', expire=5) 38 | 39 | return {'response': in_cache or 'default'} 40 | 41 | 42 | @app.on_event('startup') 43 | async def on_startup() -> None: 44 | rc = RedisCacheBackend('redis://redis') 45 | caches.set(CACHE_KEY, rc) 46 | 47 | 48 | @app.on_event('shutdown') 49 | async def on_shutdown() -> None: 50 | await close_caches() 51 | ``` 52 | 53 | ## TODO 54 | 55 | * [X] Add tests 56 | * [ ] ~~Add registry decorator~~ 57 | * [ ] Add dependency for requests caching 58 | 59 | ## Acknowledgments 60 | 61 | * [Balburdia](https://github.com/Balburdia) 62 | * [xobtoor](https://github.com/xobtoor) 63 | * [jersobh](https://github.com/jersobh) 64 | 65 | 66 | ## Changelog 67 | 68 | * 0.0.6 Added typings for backends. Backend-specific arguments must now be passed through **kwargs. 69 | Set the default encoding to utf-8 for the redis backend; removed the default TTL for redis keys. 70 | 71 | * 0.1.0 Added TTL support for InMemoryCacheBackend. Added an `expire()` method that updates the TTL of a key.
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | .idea/ -------------------------------------------------------------------------------- /fastapi_cache/backends/utils/ttldict.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Any, Hashable, Optional, Union 3 | 4 | 5 | class TTLDict: 6 | def __init__(self): 7 | self._base = dict() 8 | 9 | def set(self, key: Hashable, value: Any, *, ttl: Optional[int] = None) -> bool: 10 | try: 11 | self._base[key] = ( 12 | self._get_ttl_timestamp(ttl), value 13 | ) 14 | print(self._base) 15 | except Exception: 16 | return False 17 | 18 | return True 19 | 20 | def add(self, key: Hashable, value: Any, *, ttl: Optional[int] = None) -> bool: 21 | if key in self._base: 22 | return False 23 | 24 | return self.set(key, value, ttl=ttl) 25 | 26 | def get(self, key: Hashable, default: Optional[Any] = None) -> Any: 27 | ttl, value = self._base.get(key, (None, default)) 28 | print(self._base) 29 | if ttl is not None and self._is_ttl_expired(ttl): 30 | self.delete(key) 31 | return default 32 | 33 | return value 34 | 35 | def delete(self, key: Hashable) -> bool: 36 | if key not in self._base: 37 | return False 38 | 39 | del self._base[key] 40 | return True 41 | 42 | def exists(self, *keys: Hashable) -> bool: 43 | hits = [] 44 | for key in keys: 45 | if key not in self._base: 46 | 
hits.append(False) 47 | continue 48 | 49 | ttl, value = self._base.get(key, (None, None)) 50 | if ttl is not None and self._is_ttl_expired(ttl): 51 | hits.append(False) 52 | else: 53 | hits.append(True) 54 | 55 | return any(hits) 56 | 57 | def expire( 58 | self, 59 | key: Hashable, 60 | ttl: int 61 | ) -> bool: 62 | if key in self._base: 63 | _, value = self._base.get(key) 64 | self._base[key] = ( 65 | self._get_ttl_timestamp(ttl), 66 | value 67 | ) 68 | return True 69 | 70 | return False 71 | 72 | def flush(self) -> None: 73 | self._base.clear() 74 | 75 | def _get_ttl_timestamp(self, ttl: Union[int, None]) -> Union[int, None]: 76 | if ttl is None: 77 | return None 78 | 79 | return int(time.time()) + ttl 80 | 81 | def _is_ttl_expired(self, timestamp: Union[int, None]) -> bool: 82 | return int(time.time()) >= timestamp 83 | -------------------------------------------------------------------------------- /tests/registry_tests.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from fastapi_cache.backends.memory import BaseCacheBackend, CACHE_KEY 4 | from fastapi_cache.registry import CacheRegistry 5 | 6 | 7 | @pytest.fixture 8 | def cache_registry() -> CacheRegistry: 9 | registry = CacheRegistry() 10 | yield registry 11 | registry.flush() 12 | 13 | 14 | def test_get_from_registry_should_return_none_if_cache_not_registered( 15 | cache_registry: CacheRegistry 16 | ) -> None: 17 | assert cache_registry.get('') is None 18 | 19 | 20 | def test_get_from_registry_should_return_cache_instance( 21 | cache_registry: CacheRegistry 22 | ) -> None: 23 | cache = BaseCacheBackend() 24 | cache_registry.set(CACHE_KEY, cache) 25 | 26 | assert cache_registry.get(CACHE_KEY) == cache 27 | 28 | 29 | def test_retrieve_all_registered_caches_from_registry( 30 | cache_registry: CacheRegistry 31 | ) -> None: 32 | cache = BaseCacheBackend() 33 | 34 | cache_registry.set(CACHE_KEY, cache) 35 | cache_registry.set('OTHER_CACHE_KEY', cache) 
36 | 37 | assert cache_registry.all() == (cache, cache) 38 | 39 | 40 | def test_registry_should_raise_error_on_dublicate_cache_key( 41 | cache_registry: CacheRegistry 42 | ) -> None: 43 | cache = BaseCacheBackend() 44 | cache_registry.set(CACHE_KEY, cache) 45 | 46 | with pytest.raises(NameError, match='Cache with the same name already registered'): 47 | cache_registry.set(CACHE_KEY, cache) 48 | 49 | 50 | def test_remove_cache_from_registry( 51 | cache_registry: CacheRegistry 52 | ) -> None: 53 | cache = BaseCacheBackend() 54 | cache_registry.set(CACHE_KEY, cache) 55 | cache_registry.remove(CACHE_KEY) 56 | 57 | assert cache_registry.get(CACHE_KEY) is None 58 | 59 | 60 | def test_remove_cache_from_registry_should_raise_error_if_cache_not_register( 61 | cache_registry: CacheRegistry 62 | ) -> None: 63 | with pytest.raises(NameError, match='Cache with the same name not registered'): 64 | cache_registry.remove(CACHE_KEY) 65 | 66 | 67 | def test_flush_should_remove_all_registered_cashes( 68 | cache_registry: CacheRegistry 69 | ) -> None: 70 | cache = BaseCacheBackend() 71 | 72 | cache_registry.set(CACHE_KEY, cache) 73 | cache_registry.set('OTHER_CACHE_KEY', cache) 74 | 75 | cache_registry.flush() 76 | 77 | assert cache_registry.get(CACHE_KEY) is None 78 | assert cache_registry.get('OTHER_CACHE_KEY') is None 79 | 80 | 81 | @pytest.mark.backwards 82 | def test_registry_can_be_imported_by_older_path() -> None: 83 | import importlib 84 | 85 | fastapi_cache = importlib.import_module("fastapi_cache") 86 | assert hasattr(fastapi_cache, 'caches') -------------------------------------------------------------------------------- /fastapi_cache/backends/redis.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Optional, AnyStr 2 | 3 | import aioredis 4 | from aioredis import Redis 5 | 6 | from .base import BaseCacheBackend 7 | 8 | DEFAULT_ENCODING = 'utf-8' 9 | DEFAULT_POOL_MIN_SIZE = 5 10 | CACHE_KEY = 'REDIS' 11 | 12 | 
# expected to be of bytearray, bytes, float, int, or str type 13 | 14 | RedisKey = Union[AnyStr, float, int] 15 | RedisValue = Union[AnyStr, float, int] 16 | 17 | 18 | class RedisCacheBackend(BaseCacheBackend[RedisKey, RedisValue]): 19 | def __init__( 20 | self, 21 | address: str, 22 | pool_minsize: Optional[int] = DEFAULT_POOL_MIN_SIZE, 23 | encoding: Optional[str] = DEFAULT_ENCODING, 24 | ) -> None: 25 | self._redis_address = address 26 | self._redis_pool_minsize = pool_minsize 27 | self._encoding = encoding 28 | 29 | self._pool: Optional[Redis] = None 30 | 31 | @property 32 | async def _client(self) -> Redis: 33 | if self._pool is None: 34 | self._pool = await self._create_connection() 35 | 36 | return self._pool 37 | 38 | async def _create_connection(self) -> Redis: 39 | return await aioredis.create_redis_pool( 40 | self._redis_address, 41 | minsize=self._redis_pool_minsize, 42 | ) 43 | 44 | async def add( 45 | self, 46 | key: RedisKey, 47 | value: RedisValue, 48 | **kwargs 49 | ) -> bool: 50 | """ 51 | Bad temporary solution, this approach prone to 52 | race condition error. 
53 | 54 | # TODO: Implement via Lua 55 | 56 | if redis.call("GET", KEYS[1]) then 57 | return 0 58 | else 59 | redis.call("SET", KEYS[1], ARGV[1]) 60 | return 1 61 | end 62 | """ 63 | 64 | client = await self._client 65 | in_cache = await client.get(key) 66 | 67 | if in_cache is not None: 68 | return False 69 | 70 | return await client.set(key, value, **kwargs) 71 | 72 | async def get( 73 | self, 74 | key: RedisKey, 75 | default: RedisValue = None, 76 | **kwargs, 77 | ) -> AnyStr: 78 | kwargs.setdefault('encoding', self._encoding) 79 | 80 | client = await self._client 81 | cached_value = await client.get(key, **kwargs) 82 | 83 | return cached_value if cached_value is not None else default 84 | 85 | async def set( 86 | self, 87 | key: RedisKey, 88 | value: RedisValue, 89 | **kwargs, 90 | ) -> bool: 91 | client = await self._client 92 | 93 | return await client.set(key, value, **kwargs) 94 | 95 | async def exists(self, *keys: RedisKey) -> bool: 96 | client = await self._client 97 | exists = await client.exists(*keys) 98 | 99 | return bool(exists) 100 | 101 | async def delete(self, key: RedisKey) -> bool: 102 | client = await self._client 103 | 104 | return await client.delete(key) 105 | 106 | async def flush(self) -> None: 107 | client = await self._client 108 | await client.flushdb() 109 | 110 | async def expire( 111 | self, 112 | key: RedisKey, 113 | ttl: int 114 | ) -> bool: 115 | client = await self._client 116 | 117 | return await client.expire(key, ttl) 118 | 119 | async def close(self) -> None: 120 | client = await self._client 121 | client.close() 122 | await client.wait_closed() 123 | -------------------------------------------------------------------------------- /tests/memory_tests.py: -------------------------------------------------------------------------------- 1 | from typing import Hashable, Any, Tuple 2 | 3 | import pytest 4 | 5 | from fastapi_cache.backends.memory import InMemoryCacheBackend 6 | 7 | TEST_KEY = 'constant' 8 | TEST_VALUE = '0' 9 | 10 
| 11 | @pytest.fixture 12 | def f_backend() -> InMemoryCacheBackend: 13 | return InMemoryCacheBackend() 14 | 15 | 16 | @pytest.mark.asyncio 17 | async def test_should_add_n_get_data( 18 | f_backend: InMemoryCacheBackend 19 | ) -> None: 20 | is_added = await f_backend.add(TEST_KEY, TEST_VALUE) 21 | 22 | assert is_added is True 23 | assert await f_backend.get(TEST_KEY) == TEST_VALUE 24 | 25 | 26 | @pytest.mark.asyncio 27 | async def test_add_should_return_false_if_key_exists( 28 | f_backend: InMemoryCacheBackend 29 | ) -> None: 30 | await f_backend.add(TEST_KEY, TEST_VALUE) 31 | is_added = await f_backend.add(TEST_KEY, TEST_VALUE) 32 | 33 | assert is_added is False 34 | 35 | 36 | @pytest.mark.asyncio 37 | async def test_should_return_default_if_key_not_exists( 38 | f_backend: InMemoryCacheBackend 39 | ) -> None: 40 | default = '3.14159' 41 | fetched_value = await f_backend.get('not_exists', default) 42 | 43 | assert fetched_value == default 44 | 45 | 46 | @pytest.mark.asyncio 47 | async def test_set_should_rewrite_value( 48 | f_backend: InMemoryCacheBackend 49 | ) -> None: 50 | eulers_number = '2.71828' 51 | 52 | await f_backend.add(TEST_KEY, TEST_VALUE) 53 | await f_backend.set(TEST_KEY, eulers_number) 54 | 55 | fetched_value = await f_backend.get(TEST_KEY) 56 | 57 | assert fetched_value == eulers_number 58 | 59 | 60 | @pytest.mark.asyncio 61 | async def test_delete_should_remove_from_cache( 62 | f_backend: InMemoryCacheBackend 63 | ) -> None: 64 | await f_backend.add(TEST_KEY, TEST_VALUE) 65 | await f_backend.delete(TEST_KEY) 66 | 67 | fetched_value = await f_backend.get(TEST_KEY) 68 | 69 | assert fetched_value is None 70 | 71 | 72 | @pytest.mark.asyncio 73 | async def test_flush_should_remove_all_objects_from_cache( 74 | f_backend: InMemoryCacheBackend 75 | ) -> None: 76 | await f_backend.add('pi', '3.14159') 77 | await f_backend.add('golden_ratio', '1.61803') 78 | 79 | await f_backend.flush() 80 | 81 | assert await f_backend.get('pi') is None 82 | assert await 
f_backend.get('golden_ratio') is None 83 | 84 | 85 | @pytest.mark.asyncio 86 | @pytest.mark.parametrize('key,value,ttl,expected', [ 87 | ['hello', 'world', 0, None], 88 | ['hello', 'world', 100, 'world'], 89 | ]) 90 | async def test_should_set_value_with_ttl( 91 | key: Hashable, 92 | value: Any, 93 | ttl: int, 94 | expected: Any, 95 | f_backend: InMemoryCacheBackend 96 | ) -> None: 97 | await f_backend.set(key, value, ttl=ttl) 98 | fetched_value = await f_backend.get(key) 99 | 100 | assert fetched_value == expected 101 | 102 | 103 | @pytest.mark.asyncio 104 | @pytest.mark.parametrize('key,value,ttl,expected', [ 105 | ['hello', 'world', 0, None], 106 | ['hello', 'world', 100, 'world'], 107 | ]) 108 | async def test_should_add_value_with_ttl( 109 | key: Hashable, 110 | value: Any, 111 | ttl: int, 112 | expected: Any, 113 | f_backend: InMemoryCacheBackend 114 | ) -> None: 115 | await f_backend.add(key, value, ttl=ttl) 116 | fetched_value = await f_backend.get(key) 117 | 118 | assert fetched_value == expected 119 | 120 | 121 | @pytest.mark.asyncio 122 | @pytest.mark.parametrize('keys', [ 123 | ('hello', 'world'), 124 | ]) 125 | async def test_key_should_check_for_exists( 126 | keys: Tuple[Hashable], 127 | f_backend: InMemoryCacheBackend 128 | ) -> None: 129 | for key in keys: 130 | await f_backend.set(key, key) 131 | 132 | assert await f_backend.exists(*keys) is True 133 | 134 | 135 | @pytest.mark.asyncio 136 | @pytest.mark.parametrize('keys,ttl,exists', [ 137 | [('hello', 'world'), 0, False], 138 | [('hello', 'world'), 100, True], 139 | ]) 140 | async def test_key_should_check_for_exists_with_ttl( 141 | keys: Tuple[Hashable], 142 | ttl: int, 143 | exists: bool, 144 | f_backend: InMemoryCacheBackend 145 | ) -> None: 146 | for key in keys: 147 | await f_backend.set(key, key, ttl=ttl) 148 | 149 | assert await f_backend.exists(*keys) is exists 150 | 151 | 152 | @pytest.mark.asyncio 153 | @pytest.mark.parametrize('keys', [ 154 | ('hello', 'world'), 155 | ]) 156 | async def 
test_should_return_false_if_keys_not_exist( 157 | keys: Tuple[Hashable], 158 | f_backend: InMemoryCacheBackend 159 | ) -> None: 160 | assert await f_backend.exists(*keys) is False 161 | 162 | 163 | @pytest.mark.asyncio 164 | @pytest.mark.parametrize('key,value,ttl,expected', [ 165 | [TEST_KEY, TEST_VALUE, 0, None], 166 | [TEST_KEY, TEST_VALUE, 10, TEST_VALUE], 167 | ]) 168 | async def test_expire_from_cache( 169 | key: Hashable, 170 | value: Any, 171 | ttl: int, 172 | expected: Any, 173 | f_backend: InMemoryCacheBackend 174 | ) -> None: 175 | await f_backend.add(key, value) 176 | await f_backend.expire(key, ttl) 177 | fetched_value = await f_backend.get(key) 178 | 179 | assert fetched_value == expected 180 | -------------------------------------------------------------------------------- /tests/redis_tests.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List, Any 2 | 3 | import aioredis 4 | import pytest 5 | 6 | from fastapi_cache.backends.redis import RedisCacheBackend, RedisKey 7 | 8 | TEST_KEY = 'constant' 9 | TEST_VALUE = '0' 10 | 11 | 12 | @pytest.fixture 13 | def f_backend() -> RedisCacheBackend: 14 | return RedisCacheBackend( 15 | 'redis://localhost' 16 | ) 17 | 18 | 19 | @pytest.mark.asyncio 20 | async def test_should_add_n_get_data( 21 | f_backend: RedisCacheBackend 22 | ) -> None: 23 | is_added = await f_backend.add(TEST_KEY, TEST_VALUE) 24 | 25 | assert is_added is True 26 | assert await f_backend.get(TEST_KEY) == TEST_VALUE 27 | 28 | 29 | @pytest.mark.asyncio 30 | async def test_should_add_n_get_data_no_encoding( 31 | f_backend: RedisCacheBackend 32 | ) -> None: 33 | NO_ENCODING_KEY = 'bytes' 34 | NO_ENCODING_VALUE = b'test' 35 | is_added = await f_backend.add(NO_ENCODING_KEY, NO_ENCODING_VALUE) 36 | 37 | assert is_added is True 38 | assert await f_backend.get(NO_ENCODING_KEY, encoding=None) == bytes(NO_ENCODING_VALUE) 39 | 40 | 41 | @pytest.mark.asyncio 42 | async def 
test_add_should_return_false_if_key_exists( 43 | f_backend: RedisCacheBackend 44 | ) -> None: 45 | await f_backend.add(TEST_KEY, TEST_VALUE) 46 | is_added = await f_backend.add(TEST_KEY, TEST_VALUE) 47 | 48 | assert is_added is False 49 | 50 | 51 | @pytest.mark.asyncio 52 | async def test_should_return_default_if_key_not_exists( 53 | f_backend: RedisCacheBackend 54 | ) -> None: 55 | default = '3.14159' 56 | fetched_value = await f_backend.get('not_exists', default) 57 | 58 | assert fetched_value == default 59 | 60 | 61 | @pytest.mark.parametrize('preset,keys,exists', ( 62 | [ 63 | [ 64 | ('key1', 'value1'), 65 | ('key2', 'value2'), 66 | ], 67 | ['key1', 'key2'], 68 | True, 69 | ], 70 | [ 71 | [ 72 | ('key1', 'value1'), 73 | ], 74 | ['key1', 'key2'], 75 | True, 76 | ], 77 | [ 78 | [ 79 | ('key1', 'value1'), 80 | ('key2', 'value2'), 81 | ], 82 | ['key3', 'key4'], 83 | False, 84 | ], 85 | [ 86 | [], 87 | ['key3', 'key4'], 88 | False, 89 | ], 90 | 91 | )) 92 | @pytest.mark.asyncio 93 | async def test_should_check_is_several_keys_exists( 94 | preset: List[Tuple[str, str]], 95 | keys: List[str], 96 | exists: bool, 97 | f_backend: RedisCacheBackend 98 | ) -> None: 99 | for key, value in preset: 100 | await f_backend.add(key, value) 101 | 102 | assert await f_backend.exists(*keys) == exists 103 | 104 | 105 | @pytest.mark.asyncio 106 | async def test_set_should_rewrite_value( 107 | f_backend: RedisCacheBackend 108 | ) -> None: 109 | eulers_number = '2.71828' 110 | 111 | await f_backend.add(TEST_KEY, TEST_VALUE) 112 | await f_backend.set(TEST_KEY, eulers_number) 113 | 114 | fetched_value = await f_backend.get(TEST_KEY) 115 | 116 | assert fetched_value == eulers_number 117 | 118 | 119 | @pytest.mark.asyncio 120 | async def test_delete_should_remove_from_cache( 121 | f_backend: RedisCacheBackend 122 | ) -> None: 123 | await f_backend.add(TEST_KEY, TEST_VALUE) 124 | await f_backend.delete(TEST_KEY) 125 | 126 | fetched_value = await f_backend.get(TEST_KEY) 127 | 128 | assert 
fetched_value is None 129 | 130 | 131 | @pytest.mark.asyncio 132 | @pytest.mark.parametrize('key,value,ttl,expected', [ 133 | [TEST_KEY, TEST_VALUE, 0, None], 134 | [TEST_KEY, TEST_VALUE, 10, TEST_VALUE], 135 | ]) 136 | async def test_expire_from_cache( 137 | key: RedisKey, 138 | value: Any, 139 | ttl: int, 140 | expected: Any, 141 | f_backend: RedisCacheBackend 142 | ) -> None: 143 | await f_backend.add(key, value) 144 | await f_backend.expire(key, ttl) 145 | fetched_value = await f_backend.get(key) 146 | 147 | assert fetched_value == expected 148 | 149 | 150 | @pytest.mark.asyncio 151 | async def test_flush_should_remove_all_objects_from_cache( 152 | f_backend: RedisCacheBackend 153 | ) -> None: 154 | await f_backend.add('pi', '3.14159') 155 | await f_backend.add('golden_ratio', '1.61803') 156 | 157 | await f_backend.flush() 158 | 159 | assert await f_backend.get('pi') is None 160 | assert await f_backend.get('golden_ratio') is None 161 | 162 | 163 | @pytest.mark.asyncio 164 | async def test_close_should_close_connection( 165 | f_backend: RedisCacheBackend 166 | ) -> None: 167 | await f_backend.close() 168 | with pytest.raises(aioredis.errors.PoolClosedError): 169 | await f_backend.add(TEST_KEY, TEST_VALUE) 170 | 171 | 172 | @pytest.mark.asyncio 173 | @pytest.mark.parametrize('key,value,expected', [ 174 | [1, 1, '1'], 175 | ['1', 1, '1'], 176 | [1.5, 1, '1'], 177 | [b'bytes', 1, '1'], 178 | ['2', 1.5, '1.5'], 179 | [3, b'bytes', 'bytes'], # default encoding 180 | [3.5, '3.5', '3.5'], 181 | ]) 182 | async def test_scalar_types( 183 | key: RedisKey, 184 | value: Any, 185 | expected: Any, 186 | f_backend: RedisCacheBackend 187 | ) -> None: 188 | await f_backend.set(key, value) 189 | assert await f_backend.get(key) == expected 190 | -------------------------------------------------------------------------------- /tests/ttldict_tests.py: -------------------------------------------------------------------------------- 1 | from typing import Hashable, Any, Tuple 2 | 
from unittest import mock 3 | 4 | import pytest 5 | 6 | from fastapi_cache.backends.utils.ttldict import TTLDict 7 | 8 | MOCKED_DICT = dict() 9 | 10 | 11 | @pytest.fixture 12 | def ttl_dict() -> TTLDict: 13 | global MOCKED_DICT 14 | MOCKED_DICT = dict() 15 | 16 | with mock.patch('fastapi_cache.backends.utils.ttldict.dict', return_value=MOCKED_DICT): 17 | return TTLDict() 18 | 19 | 20 | @pytest.mark.parametrize('key,value', [ 21 | ('hello', 'world'), 22 | ]) 23 | def test_set_value_to_dict( 24 | key: Hashable, 25 | value: Any, 26 | ttl_dict: TTLDict, 27 | ) -> None: 28 | assert ttl_dict.set(key, value) is True 29 | assert key in MOCKED_DICT 30 | assert MOCKED_DICT[key] == (None, value) 31 | 32 | 33 | @pytest.mark.parametrize('key,value,rewrite_value', [ 34 | ('hello', 'world', 'kitty'), 35 | ]) 36 | def test_ser_should_rewrite_value_if_exits( 37 | key: Hashable, 38 | value: Any, 39 | rewrite_value: Any, 40 | ttl_dict: TTLDict 41 | ) -> None: 42 | ttl_dict.set(key, value) 43 | 44 | assert ttl_dict.set(key, rewrite_value) is True 45 | assert MOCKED_DICT.get(key) == (None, rewrite_value) 46 | 47 | 48 | @pytest.mark.parametrize('key,value', [ 49 | ('hello', 'world'), 50 | ]) 51 | def test_add_value_to_dict( 52 | key: Hashable, 53 | value: Any, 54 | ttl_dict: TTLDict 55 | ): 56 | assert ttl_dict.add(key, value) is True 57 | assert key in MOCKED_DICT 58 | assert MOCKED_DICT.get(key) == (None, value) 59 | 60 | 61 | @pytest.mark.parametrize('key,value,rewrite_value', [ 62 | ('hello', 'world', 'kitty'), 63 | ]) 64 | def test_add_should_return_false_if_key_exists( 65 | key: Hashable, 66 | value: Any, 67 | rewrite_value: Any, 68 | ttl_dict: TTLDict 69 | ) -> None: 70 | ttl_dict.add(key, value) 71 | 72 | assert ttl_dict.add(key, rewrite_value) is False 73 | assert MOCKED_DICT.get(key) == (None, value) 74 | 75 | 76 | @pytest.mark.parametrize('key,value', [ 77 | ('hello', 'world'), 78 | ]) 79 | def test_get_should_return_value_if_exists( 80 | key: Hashable, 81 | value: Any, 82 | 
ttl_dict: TTLDict 83 | ) -> None: 84 | ttl_dict.set(key, value) 85 | 86 | assert ttl_dict.get(key) == value 87 | 88 | 89 | @pytest.mark.parametrize('key,default', [ 90 | ('hello', 'WOW'), 91 | ]) 92 | def test_get_should_return_default_if_key_not_exists( 93 | key: Hashable, 94 | default: Any, 95 | ttl_dict: TTLDict 96 | ) -> None: 97 | assert ttl_dict.get(key, default) == default 98 | 99 | 100 | @pytest.mark.parametrize('key,value', [ 101 | ('hello', 'world'), 102 | ]) 103 | def test_delete_should_remove_key( 104 | key: Hashable, 105 | value: Any, 106 | ttl_dict: TTLDict 107 | ) -> None: 108 | ttl_dict.set(key, value) 109 | ttl_dict.delete(key) 110 | 111 | assert key not in MOCKED_DICT 112 | 113 | 114 | def test_flush_should_remove_all_keys(ttl_dict: TTLDict) -> None: 115 | for num in range(10): 116 | ttl_dict.set(str(num), num) 117 | 118 | ttl_dict.flush() 119 | assert MOCKED_DICT == {} 120 | 121 | 122 | @pytest.mark.parametrize('key,value,default', [ 123 | ('hello', 'world', 'WOW'), 124 | ]) 125 | def test_get_should_return_default_if_ttl_expired( 126 | key: Hashable, 127 | value: Any, 128 | default: Any, 129 | ttl_dict: TTLDict, 130 | ) -> None: 131 | ttl_dict.set(key, value, ttl=0) 132 | 133 | assert ttl_dict.get(key, default) == default 134 | 135 | 136 | @pytest.mark.parametrize('key,value', [ 137 | ('hello', 'world'), 138 | ]) 139 | def test_get_should_return_value_if_ttl_not_expired( 140 | key: Hashable, 141 | value: Any, 142 | ttl_dict: TTLDict, 143 | ) -> None: 144 | ttl_dict.set(key, value, ttl=100) 145 | assert ttl_dict.get(key, 'NO OOPS') == value 146 | 147 | 148 | @pytest.mark.parametrize('keys', [ 149 | ('hello', 'world'), 150 | ]) 151 | def test_key_should_check_for_exists( 152 | keys: Tuple[Hashable], 153 | ttl_dict: TTLDict 154 | ) -> None: 155 | for key in keys: 156 | ttl_dict.set(key, key) 157 | 158 | assert ttl_dict.exists(*keys) is True 159 | 160 | 161 | @pytest.mark.parametrize('keys,ttl,exists', [ 162 | [('hello', 'world'), 0, False], 163 | 
[('hello', 'world'), 100, True], 164 | ]) 165 | def test_key_should_check_for_exists_with_ttl( 166 | keys: Tuple[Hashable], 167 | ttl: int, 168 | exists: bool, 169 | ttl_dict: TTLDict 170 | ) -> None: 171 | for key in keys: 172 | ttl_dict.set(key, key, ttl=ttl) 173 | 174 | assert ttl_dict.exists(*keys) is exists 175 | 176 | 177 | @pytest.mark.parametrize('keys', [ 178 | ('hello', 'world'), 179 | ]) 180 | def test_should_return_false_if_keys_not_exist( 181 | keys: Tuple[Hashable], 182 | ttl_dict: TTLDict 183 | ) -> None: 184 | assert ttl_dict.exists(*keys) is False 185 | 186 | 187 | @pytest.mark.parametrize('key,value,ttl,expected', [ 188 | ['hello', 'world', 0, None], 189 | ['hello', 'world', 10, 'world'], 190 | ]) 191 | def test_expire_from_cache( 192 | key: Hashable, 193 | value: Any, 194 | ttl: int, 195 | expected: Any, 196 | ttl_dict: TTLDict 197 | ) -> None: 198 | ttl_dict.add(key, value) 199 | ttl_dict.expire(key, ttl) 200 | 201 | assert ttl_dict.get(key) == expected 202 | --------------------------------------------------------------------------------