├── .editorconfig ├── .flake8 ├── .github └── workflows │ └── python-package.yml ├── .gitignore ├── .isort.cfg ├── .pre-commit-config.yaml ├── README.md ├── examples ├── inmem_app │ ├── README.md │ ├── __init__.py │ ├── app.py │ ├── db.py │ ├── db_types.py │ ├── models.py │ ├── settings.py │ └── utils.py └── redis_app │ ├── README.md │ ├── __init__.py │ ├── app.py │ ├── db.py │ ├── db_types.py │ ├── models.py │ ├── settings.py │ └── utils.py ├── fastapi_caching ├── __init__.py ├── backends.py ├── constants.py ├── dependencies.py ├── exceptions.py ├── hashers.py ├── manager.py ├── objects.py ├── py.typed └── raw.py ├── pyproject.toml ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── conftest.py ├── helpers.py ├── test_backends.py ├── test_dependencies.py └── test_manager.py /.editorconfig: -------------------------------------------------------------------------------- 1 | # EditorConfig helps developers define and maintain consistent 2 | # coding styles between different editors and IDEs 3 | # editorconfig.org 4 | 5 | root = true 6 | 7 | 8 | [*] 9 | indent_style = space 10 | end_of_line = lf 11 | charset = utf-8 12 | trim_trailing_whitespace = true 13 | insert_final_newline = true 14 | 15 | [*.{md}] 16 | trim_trailing_whitespace = false 17 | 18 | [*.py] 19 | indent_size = 4 20 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E203,E501,E711,E712,W503,W504 3 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python 
package 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: [3.6, 3.7, 3.8] 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install -e '.[all]' 30 | - name: Test with pytest 31 | run: | 32 | pytest 33 | - name: Ensure code standard is followed via black 34 | run: | 35 | black --diff . 36 | black --check . 37 | - name: Ensure that imports are sorted correctly with isort 38 | run: | 39 | isort --diff . 40 | isort --check . 41 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.db 2 | *.egg-info 3 | *.pyc 4 | *.sublime-* 5 | /.coverage 6 | /.pytest_cache 7 | /build 8 | /dist 9 | /env 10 | /pip-wheel-metadata 11 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | multi_line_output = 3 3 | include_trailing_comma = True 4 | force_grid_wrap = 0 5 | use_parentheses = True 6 | line_length = 88 7 | known_third_party = aioredis,cachetools,databases,fakeredis,fastapi,httpx,pkg_resources,pydantic,pytest,setuptools,sqlalchemy,starlette 8 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://gitlab.com/pycqa/flake8 3 | rev: master 4 | hooks: 5 | - id: flake8 6 | - repo: https://github.com/asottile/seed-isort-config 7 | rev: v1.9.2 8 | hooks: 9 | - id: 
seed-isort-config 10 | - repo: https://github.com/pre-commit/mirrors-isort 11 | rev: master 12 | hooks: 13 | - id: isort 14 | - repo: https://github.com/psf/black 15 | rev: stable 16 | hooks: 17 | - id: black 18 | language_version: python3.7 19 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FastAPI-Caching 2 | 3 | Cache library for FastAPI with tag based invalidation. Asynchronous only for the time being. 4 | 5 | ## Features 6 | 7 | - Automatic response cache fetching using FastAPI dependencies 8 | - Fine-grained control over when to return and set the cache 9 | - Ability to invalidate cached objects based on a concept of associated tags. See [examples/redis_app](/examples/redis_app) for an example. 10 | 11 | ## Installation 12 | 13 | With in-memory support only: 14 | ```bash 15 | pip install fastapi-caching 16 | ``` 17 | 18 | NOTE: In-memory backend is only recommended when your app is only run as a single instance. 19 | 20 | With redis support (through the [aioredis](https://aioredis.readthedocs.io/) library): 21 | ```bash 22 | pip install fastapi-caching[redis] 23 | ``` 24 | 25 | ## Usage examples 26 | 27 | Examples on how to use [can be found here](/examples). 28 | 29 | 30 | ## Limitations 31 | 32 | - Only supported within async contexts. 33 | 34 | ## Changelog 35 | 36 | ### v0.1.2, 2020-08-16 37 | 38 | - Feature: Support lazy setup of cache manager 39 | 40 | ### v0.2.0, 2020-08-16 41 | 42 | - Breaking change: Revert to requiring `backend` to be passed to CacheManager 43 | - Feature: Support lazy configuration of the caching backend 44 | 45 | ### v0.3.0, 2020-08-16 46 | 47 | - Feature: Add functionality to disable (and re-enable) caching. 
48 | -------------------------------------------------------------------------------- /examples/inmem_app/README.md: -------------------------------------------------------------------------------- 1 | # FastAPI-Caching In-Memory Backend Example App 2 | 3 | To try out: 4 | 5 | ```bash 6 | pip install 'fastapi-caching[all]' 7 | uvicorn examples.inmem_app:app 8 | ``` 9 | -------------------------------------------------------------------------------- /examples/inmem_app/__init__.py: -------------------------------------------------------------------------------- 1 | from .app import app # noqa 2 | -------------------------------------------------------------------------------- /examples/inmem_app/app.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import uuid 4 | from typing import List 5 | 6 | from fastapi import FastAPI, HTTPException 7 | 8 | from fastapi_caching import CacheManager, InMemoryBackend, ResponseCache 9 | 10 | from .
import db 11 | from .models import Product 12 | from .settings import get_settings 13 | 14 | app = FastAPI() 15 | 16 | settings = get_settings() 17 | logger = logging.getLogger(__name__) 18 | 19 | cache_backend = InMemoryBackend() 20 | cache_manager = CacheManager(cache_backend) 21 | 22 | 23 | @app.on_event("startup") 24 | async def connect_db(): 25 | await db.database.connect() 26 | db.create_all() 27 | 28 | 29 | @app.on_event("shutdown") 30 | async def disconnect_db(): 31 | await db.database.disconnect() 32 | 33 | 34 | @app.post("/products", response_model=Product) 35 | async def create_product(product: Product): 36 | await db.create_product(product) 37 | await cache_manager.invalidate_tag("all-products") 38 | return product 39 | 40 | 41 | @app.put("/products/{product_id}", response_model=Product) 42 | async def update_product(product_id: uuid.UUID, product: Product): 43 | product.id = product_id 44 | 45 | if not await db.product_exists(product_id): 46 | raise HTTPException(404, detail="Product does not exist") 47 | 48 | await db.update_product(product) 49 | await cache_manager.invalidate_tags([f"product-{product.id}", "all-products"]) 50 | return product 51 | 52 | 53 | @app.get("/products", response_model=List[Product]) 54 | async def list_products(rcache: ResponseCache = cache_manager.from_request()): 55 | if rcache.exists(): 56 | return rcache.data 57 | 58 | products = await db.fetch_products() 59 | 60 | await asyncio.sleep(1) # Some heavy processing... 
61 | 62 | await rcache.set(products, tag="all-products") 63 | 64 | return products 65 | 66 | 67 | @app.get("/products/{product_id}", response_model=Product) 68 | async def get_product( 69 | product_id: uuid.UUID, rcache: ResponseCache = cache_manager.from_request() 70 | ): 71 | if rcache.exists(): 72 | return rcache.data 73 | 74 | product = await db.fetch_product(product_id) 75 | 76 | if product is None: 77 | raise HTTPException(404, detail="Product not found") 78 | 79 | await rcache.set(product, tag=f"product-{product_id}") 80 | 81 | return product 82 | 83 | 84 | @app.get("/__admin/reset/all") 85 | async def reset_all_data(): 86 | db.clear_data() 87 | await cache_backend.reset() 88 | return {"ok": True} 89 | 90 | 91 | @app.get("/__admin/reset/cache") 92 | async def reset_cache(): 93 | await cache_backend.reset() 94 | return {"ok": True} 95 | 96 | 97 | @app.get("/__admin/reset/db") 98 | async def reset_db(): 99 | db.clear_data() 100 | return {"ok": True} 101 | -------------------------------------------------------------------------------- /examples/inmem_app/db.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | import sqlalchemy as sa 4 | from databases import Database 5 | 6 | from . 
import db_types 7 | from .settings import get_settings 8 | 9 | db_uri = get_settings().db_uri 10 | 11 | database = Database(db_uri) 12 | metadata = sa.MetaData() 13 | 14 | product_tbl = sa.Table( 15 | "product", 16 | metadata, 17 | sa.Column("id", db_types.Uuid, primary_key=True), 18 | sa.Column("name", sa.Text), 19 | sa.Column("description", sa.Text), 20 | sa.Column("created_at", sa.DateTime), 21 | ) 22 | 23 | 24 | def create_all(): 25 | metadata.create_all(bind=sa.create_engine(db_uri)) 26 | 27 | 28 | def clear_data(): 29 | engine = sa.create_engine(db_uri) 30 | for tbl in metadata.tables.values(): 31 | engine.execute(tbl.delete()) 32 | 33 | 34 | async def create_product(product): 35 | await database.execute(product_tbl.insert().values(product.dict())) 36 | 37 | 38 | async def update_product(product): 39 | values = product.dict() 40 | await database.execute( 41 | product_tbl.update().values(values).where(product_tbl.c.id == product.id) 42 | ) 43 | 44 | 45 | async def fetch_products(): 46 | return await database.fetch_all(product_tbl.select()) 47 | 48 | 49 | async def product_exists(product_id: uuid.UUID) -> bool: 50 | val = await database.fetch_val( 51 | sa.select([product_tbl.c.id]).where(product_tbl.c.id == product_id).limit(1) 52 | ) 53 | return val is not None 54 | 55 | 56 | async def fetch_product(product_id: uuid.UUID): 57 | return await database.fetch_one( 58 | product_tbl.select().where(product_tbl.c.id == product_id) 59 | ) 60 | -------------------------------------------------------------------------------- /examples/inmem_app/db_types.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | from sqlalchemy.dialects import postgresql 4 | from sqlalchemy.types import CHAR, TypeDecorator 5 | 6 | 7 | class Uuid(TypeDecorator): 8 | """Platform-independent UUID type. 9 | 10 | Uses PostgreSQL's UUID type, otherwise uses 11 | CHAR(32), storing as stringified hex values. 
12 | 13 | From https://docs.sqlalchemy.org/en/13/core/custom_types.html#backend-agnostic-guid-type 14 | """ 15 | 16 | impl = CHAR 17 | 18 | def load_dialect_impl(self, dialect): 19 | if dialect.name == "postgresql": 20 | type_ = postgresql.UUID() 21 | else: 22 | type_ = CHAR(32) 23 | return dialect.type_descriptor(type_) 24 | 25 | def process_bind_param(self, value, dialect): 26 | if value is None: 27 | return value 28 | elif dialect.name == "postgresql": 29 | return str(value) 30 | else: 31 | if not isinstance(value, uuid.UUID): 32 | return "%.32x" % uuid.UUID(value).int 33 | else: 34 | # hexstring 35 | return "%.32x" % value.int 36 | 37 | def process_result_value(self, value, dialect): 38 | if value is None: 39 | return value 40 | else: 41 | if not isinstance(value, uuid.UUID): 42 | value = uuid.UUID(value) 43 | return value 44 | -------------------------------------------------------------------------------- /examples/inmem_app/models.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from datetime import datetime 3 | 4 | from pydantic import BaseModel, Field, validator 5 | 6 | from .utils import random_ascii 7 | 8 | 9 | class Product(BaseModel): 10 | id: uuid.UUID = Field(default_factory=uuid.uuid4) 11 | name: str 12 | description: str = "" 13 | created_at: datetime = Field(default_factory=datetime.utcnow) 14 | 15 | @validator("description", pre=True, always=True) 16 | def ensure_random_description(cls, v): 17 | return v or random_ascii(20_000) 18 | -------------------------------------------------------------------------------- /examples/inmem_app/settings.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | 3 | from pydantic import BaseSettings 4 | 5 | __all__ = ("get_settings",) 6 | 7 | 8 | class _Settings(BaseSettings): 9 | log_level: str = "INFO" 10 | db_uri: str = "sqlite:///inmem_app.db" 11 | 12 | 13 | @lru_cache() 14 | def 
get_settings() -> _Settings: 15 | return _Settings()  # Reads variables from environment 16 | -------------------------------------------------------------------------------- /examples/inmem_app/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import string 3 | 4 | ASCII_ALPHA_NUM = string.ascii_lowercase + string.ascii_uppercase + string.digits + " " 5 | 6 | 7 | def random_ascii(n: int): 8 | return "".join(random.choice(ASCII_ALPHA_NUM) for _ in range(n)) 9 | -------------------------------------------------------------------------------- /examples/redis_app/README.md: -------------------------------------------------------------------------------- 1 | # FastAPI-Caching Redis Backend Example App 2 | 3 | To try out: 4 | 5 | 1. Make sure `redis` is installed locally 6 | 2. Execute: 7 | 8 | ```bash 9 | pip install 'fastapi-caching[all]' 10 | uvicorn examples.redis_app:app 11 | ``` 12 | -------------------------------------------------------------------------------- /examples/redis_app/__init__.py: -------------------------------------------------------------------------------- 1 | from .app import app # noqa 2 | -------------------------------------------------------------------------------- /examples/redis_app/app.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import uuid 4 | from pathlib import Path 5 | from typing import List 6 | 7 | from fastapi import FastAPI, HTTPException 8 | 9 | from fastapi_caching import CacheManager, RedisBackend, ResponseCache, hashers 10 | 11 | from .
import db 12 | from .models import Product 13 | from .settings import get_settings 14 | 15 | app = FastAPI() 16 | 17 | settings = get_settings() 18 | logger = logging.getLogger(__name__) 19 | 20 | # NOTE: In a real world scenario a version identifier for the given code deployment 21 | # should be used, commonly this would be the commit hash from your VCS. 22 | # Here we're using a hash of installed packages combined with a hash of the files in 23 | # the redis_app directory. 24 | app_version = "-".join( 25 | [hashers.installed_packages_hash(), hashers.files_hash(Path(__file__).parent)] 26 | ) 27 | 28 | cache_backend = RedisBackend( 29 | host=settings.redis_host, 30 | port=settings.redis_port, 31 | prefix="redis_app", 32 | app_version=app_version, 33 | ) 34 | cache_manager = CacheManager(cache_backend) 35 | 36 | logger.info(f"App version: {app_version}") 37 | 38 | 39 | @app.on_event("startup") 40 | async def connect_db(): 41 | await db.database.connect() 42 | db.create_all() 43 | 44 | 45 | @app.on_event("shutdown") 46 | async def disconnect_db(): 47 | await db.database.disconnect() 48 | 49 | 50 | @app.post("/products", response_model=Product) 51 | async def create_product(product: Product): 52 | await db.create_product(product) 53 | await cache_manager.invalidate_tag("all-products") 54 | return product 55 | 56 | 57 | @app.put("/products/{product_id}", response_model=Product) 58 | async def update_product(product_id: uuid.UUID, product: Product): 59 | product.id = product_id 60 | 61 | if not await db.product_exists(product_id): 62 | raise HTTPException(404, detail="Product does not exist") 63 | 64 | await db.update_product(product) 65 | await cache_manager.invalidate_tags([f"product-{product.id}", "all-products"]) 66 | return product 67 | 68 | 69 | @app.get("/products", response_model=List[Product]) 70 | async def list_products(rcache: ResponseCache = cache_manager.from_request()): 71 | if rcache.exists(): 72 | return rcache.data 73 | 74 | products = await 
db.fetch_products() 75 | 76 | await asyncio.sleep(1) # Some heavy processing... 77 | 78 | await rcache.set(products, tag="all-products") 79 | 80 | return products 81 | 82 | 83 | @app.get("/products/{product_id}", response_model=Product) 84 | async def get_product( 85 | product_id: uuid.UUID, rcache: ResponseCache = cache_manager.from_request() 86 | ): 87 | if rcache.exists(): 88 | return rcache.data 89 | 90 | product = await db.fetch_product(product_id) 91 | 92 | if product is None: 93 | raise HTTPException(404, detail="Product not found") 94 | 95 | await rcache.set(product, tag=f"product-{product_id}") 96 | 97 | return product 98 | 99 | 100 | @app.get("/__admin/reset/all") 101 | async def reset_all_data(): 102 | db.clear_data() 103 | await cache_backend.reset() 104 | return {"ok": True} 105 | 106 | 107 | @app.get("/__admin/reset/cache") 108 | async def reset_cache(): 109 | await cache_backend.reset() 110 | return {"ok": True} 111 | 112 | 113 | @app.get("/__admin/reset/db") 114 | async def reset_db(): 115 | db.clear_data() 116 | return {"ok": True} 117 | -------------------------------------------------------------------------------- /examples/redis_app/db.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | import sqlalchemy as sa 4 | from databases import Database 5 | 6 | from . 
import db_types 7 | from .settings import get_settings 8 | 9 | db_uri = get_settings().db_uri 10 | 11 | database = Database(db_uri) 12 | metadata = sa.MetaData() 13 | 14 | product_tbl = sa.Table( 15 | "product", 16 | metadata, 17 | sa.Column("id", db_types.Uuid, primary_key=True), 18 | sa.Column("name", sa.Text), 19 | sa.Column("description", sa.Text), 20 | sa.Column("created_at", sa.DateTime), 21 | ) 22 | 23 | 24 | def create_all(): 25 | metadata.create_all(bind=sa.create_engine(db_uri)) 26 | 27 | 28 | def clear_data(): 29 | engine = sa.create_engine(db_uri) 30 | for tbl in metadata.tables.values(): 31 | engine.execute(tbl.delete()) 32 | 33 | 34 | async def create_product(product): 35 | await database.execute(product_tbl.insert().values(product.dict())) 36 | 37 | 38 | async def update_product(product): 39 | values = product.dict() 40 | await database.execute( 41 | product_tbl.update().values(values).where(product_tbl.c.id == product.id) 42 | ) 43 | 44 | 45 | async def fetch_products(): 46 | return await database.fetch_all(product_tbl.select()) 47 | 48 | 49 | async def product_exists(product_id: uuid.UUID) -> bool: 50 | val = await database.fetch_val( 51 | sa.select([product_tbl.c.id]).where(product_tbl.c.id == product_id).limit(1) 52 | ) 53 | return val is not None 54 | 55 | 56 | async def fetch_product(product_id: uuid.UUID): 57 | return await database.fetch_one( 58 | product_tbl.select().where(product_tbl.c.id == product_id) 59 | ) 60 | -------------------------------------------------------------------------------- /examples/redis_app/db_types.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | from sqlalchemy.dialects import postgresql 4 | from sqlalchemy.types import CHAR, TypeDecorator 5 | 6 | 7 | class Uuid(TypeDecorator): 8 | """Platform-independent UUID type. 9 | 10 | Uses PostgreSQL's UUID type, otherwise uses 11 | CHAR(32), storing as stringified hex values. 
12 | 13 | From https://docs.sqlalchemy.org/en/13/core/custom_types.html#backend-agnostic-guid-type 14 | """ 15 | 16 | impl = CHAR 17 | 18 | def load_dialect_impl(self, dialect): 19 | if dialect.name == "postgresql": 20 | type_ = postgresql.UUID() 21 | else: 22 | type_ = CHAR(32) 23 | return dialect.type_descriptor(type_) 24 | 25 | def process_bind_param(self, value, dialect): 26 | if value is None: 27 | return value 28 | elif dialect.name == "postgresql": 29 | return str(value) 30 | else: 31 | if not isinstance(value, uuid.UUID): 32 | return "%.32x" % uuid.UUID(value).int 33 | else: 34 | # hexstring 35 | return "%.32x" % value.int 36 | 37 | def process_result_value(self, value, dialect): 38 | if value is None: 39 | return value 40 | else: 41 | if not isinstance(value, uuid.UUID): 42 | value = uuid.UUID(value) 43 | return value 44 | -------------------------------------------------------------------------------- /examples/redis_app/models.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from datetime import datetime 3 | 4 | from pydantic import BaseModel, Field, validator 5 | 6 | from .utils import random_ascii 7 | 8 | 9 | class Product(BaseModel): 10 | id: uuid.UUID = Field(default_factory=uuid.uuid4) 11 | name: str 12 | description: str = "" 13 | created_at: datetime = Field(default_factory=datetime.utcnow) 14 | 15 | @validator("description", pre=True, always=True) 16 | def ensure_random_description(cls, v): 17 | return v or random_ascii(20_000) 18 | -------------------------------------------------------------------------------- /examples/redis_app/settings.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | 3 | from pydantic import BaseSettings 4 | 5 | __all__ = ("get_settings",) 6 | 7 | 8 | class _Settings(BaseSettings): 9 | log_level: str = "INFO" 10 | redis_host: str = "127.0.0.1" 11 | redis_port: int = 6379 12 | db_uri: str = 
"sqlite:///redis_app.db" 13 | 14 | 15 | @lru_cache() 16 | def get_settings() -> _Settings: 17 | return _Settings() # Reads variables from environment 18 | -------------------------------------------------------------------------------- /examples/redis_app/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import string 3 | 4 | ASCII_ALPHA_NUM = string.ascii_lowercase + string.ascii_uppercase + string.digits + " " 5 | 6 | 7 | def random_ascii(n: int): 8 | return "".join(random.choice(ASCII_ALPHA_NUM) for _ in range(n)) 9 | -------------------------------------------------------------------------------- /fastapi_caching/__init__.py: -------------------------------------------------------------------------------- 1 | from .backends import * # noqa 2 | from .exceptions import * # noqa 3 | from .manager import * # noqa 4 | from .objects import * # noqa 5 | -------------------------------------------------------------------------------- /fastapi_caching/backends.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pickle 3 | from typing import Any, Optional, Sequence 4 | 5 | import cachetools 6 | 7 | from . 
import constants 8 | from .exceptions import CachingNotEnabled 9 | from .raw import RawCacheObject 10 | 11 | try: 12 | import aioredis 13 | except ImportError: 14 | aioredis = None 15 | 16 | 17 | __all__ = ("RedisBackend", "InMemoryBackend", "NoOpBackend") 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class CacheBackendBase: 23 | def setup(self): 24 | """Configure backend lazily, may be needed in advanced use cases""" 25 | # NOTE: Default method is a no-op 26 | 27 | def enable(self): 28 | self._is_enabled = True 29 | 30 | def disable(self): 31 | self._is_enabled = False 32 | 33 | def is_enabled(self) -> bool: 34 | return getattr(self, "_is_enabled", True) 35 | 36 | async def get(self, key: str) -> Optional[RawCacheObject]: 37 | self._ensure_enabled() 38 | return await self._get_impl(key) 39 | 40 | async def set( 41 | self, 42 | key: str, 43 | obj: RawCacheObject, 44 | *, 45 | tags: Sequence[str] = (), 46 | ttl: int = None, 47 | ) -> bool: 48 | self._ensure_enabled() 49 | if not isinstance(obj, RawCacheObject): 50 | obj = RawCacheObject(data=obj) 51 | return await self._set_impl(key, obj, tags=tags, ttl=ttl) 52 | 53 | async def invalidate_tag(self, tag: str): 54 | """Delete cache entries associated with the given tag""" 55 | self._ensure_enabled() 56 | return await self._invalidate_tag_impl(tag) 57 | 58 | async def invalidate_tags(self, tags: Sequence[str]): 59 | """Delete cache entries associated with the given tags""" 60 | self._ensure_enabled() 61 | return await self._invalidate_tags_impl(tags) 62 | 63 | async def reset(self): 64 | """Delete all stored cache related keys""" 65 | self._ensure_enabled() 66 | return await self._reset_impl() 67 | 68 | async def _get_impl(self, key: str) -> Optional[RawCacheObject]: 69 | raise NotImplementedError 70 | 71 | async def _set_impl( 72 | self, 73 | key: str, 74 | cache_object: RawCacheObject, 75 | *, 76 | tags: Sequence[str] = (), 77 | ttl: int = None, 78 | ) -> bool: 79 | raise NotImplementedError 80 | 81 | 
async def _invalidate_tag_impl(self, tag: str): 82 | raise NotImplementedError 83 | 84 | async def _invalidate_tags_impl(self, tags: Sequence[str]): 85 | raise NotImplementedError 86 | 87 | async def _reset_impl(self): 88 | raise NotImplementedError 89 | 90 | def _dumps(self, obj: Any) -> bytes: 91 | return pickle.dumps(obj) 92 | 93 | def _loads(self, raw: bytes) -> Any: 94 | return pickle.loads(raw) 95 | 96 | def _ensure_enabled(self): 97 | if not self.is_enabled(): 98 | raise CachingNotEnabled() 99 | 100 | 101 | class NoOpBackend(CacheBackendBase): 102 | async def _get_impl(self, key: str) -> Optional[RawCacheObject]: 103 | return None 104 | 105 | async def _set_impl( 106 | self, 107 | key: str, 108 | cache_object: RawCacheObject, 109 | *, 110 | tags: Sequence[str] = (), 111 | ttl: int = None, 112 | ) -> bool: 113 | return True 114 | 115 | async def _invalidate_tag_impl(self, tag: str): 116 | pass 117 | 118 | async def _invalidate_tags_impl(self, tags: Sequence[str]): 119 | pass 120 | 121 | async def _reset_impl(self): 122 | pass 123 | 124 | 125 | class InMemoryBackend(CacheBackendBase): 126 | def __init__( 127 | self, maxsize: int = 50_000, ttl: int = constants.DEFAULT_TTL, 128 | ): 129 | self._cached = cachetools.TTLCache(maxsize, ttl) 130 | self._tag_to_keys = {} 131 | 132 | def setup( 133 | self, *, maxsize: int = 50_000, ttl: int = constants.DEFAULT_TTL, 134 | ): 135 | """Configure backend lazily, may be needed in advanced use cases""" 136 | self._cached = cachetools.TTLCache(maxsize, ttl) 137 | 138 | async def _get_impl(self, key: str) -> Optional[RawCacheObject]: 139 | try: 140 | obj = self._cached[key] 141 | except KeyError: 142 | return None 143 | else: 144 | return self._loads(obj) 145 | 146 | async def _set_impl( 147 | self, 148 | key: str, 149 | cache_object: RawCacheObject, 150 | *, 151 | tags: Sequence[str] = (), 152 | ttl: int = None, 153 | ) -> bool: 154 | self._cached[key] = self._dumps(cache_object) 155 | for tag in tags: 156 | tag_keys = 
self._tag_to_keys.setdefault(tag, []) 157 | tag_keys.append(key) 158 | return True 159 | 160 | async def _invalidate_tag_impl(self, tag: str): 161 | try: 162 | keys = self._tag_to_keys[tag] 163 | except KeyError: 164 | return 165 | else: 166 | for key in keys: 167 | self._cached.pop(key, None) 168 | del self._tag_to_keys[tag] 169 | 170 | async def _invalidate_tags_impl(self, tags: Sequence[str]): 171 | for tag in tags: 172 | await self.invalidate_tag(tag) 173 | 174 | async def _reset_impl(self): 175 | """Reset cache completely""" 176 | self._cached = {} 177 | self._tag_to_keys = {} 178 | 179 | 180 | class RedisBackend(CacheBackendBase): 181 | def __init__( 182 | self, 183 | *, 184 | host: str = "127.0.0.1", 185 | port: int = 6379, 186 | password: str = None, 187 | prefix: str = "fastapi-caching", 188 | app_version: str = None, 189 | ttl: int = constants.DEFAULT_TTL, 190 | redis: Any = None, 191 | ): 192 | self._app_version = app_version 193 | self._host = host 194 | self._port = port 195 | self._password = password 196 | self._prefix = prefix 197 | self._ttl = ttl 198 | self._redis = redis 199 | self._setup_prefix(prefix) 200 | 201 | def setup( 202 | self, 203 | *, 204 | host: str = None, 205 | port: int = None, 206 | password: str = None, 207 | prefix: str = None, 208 | app_version: str = None, 209 | ttl: int = None, 210 | ): 211 | """Configure backend lazily, may be needed in advanced use cases""" 212 | if host is not None: 213 | self._host = host 214 | if port is not None: 215 | self._port = port 216 | if password is not None: 217 | self._password = password 218 | if prefix is not None: 219 | self._setup_prefix(prefix) 220 | if app_version is not None: 221 | self._app_version = app_version 222 | if ttl is not None: 223 | self._ttl = ttl 224 | 225 | def _setup_prefix(self, prefix: str): 226 | if not prefix: 227 | raise RuntimeError("`prefix` is required for redis backend") 228 | if prefix.endswith(":"): 229 | prefix = prefix[0:-1] 230 | self._prefix = prefix 231 
| 232 | def _get_full_prefix(self) -> str: 233 | if self._app_version is not None: 234 | return f"{self._prefix}:{self._app_version}" 235 | else: 236 | return self._prefix 237 | 238 | async def _get_impl(self, key: str) -> Optional[RawCacheObject]: 239 | redis = await self._get_redis() 240 | obj = await redis.get(self._prefixed(key)) 241 | if obj is None: 242 | return None 243 | else: 244 | return self._loads(obj) 245 | 246 | async def _set_impl( 247 | self, 248 | key: str, 249 | cache_object: RawCacheObject, 250 | *, 251 | tags: Sequence[str] = (), 252 | ttl: int = None, 253 | ) -> bool: 254 | redis = await self._get_redis() 255 | dumped = self._dumps(cache_object) 256 | tr = redis.multi_exec() 257 | 258 | tr.set(self._prefixed(key), dumped, expire=ttl or self._ttl) 259 | 260 | for tag in tags: 261 | logger.debug(f"Adding key {key} to tag {tag}") 262 | tr.sadd(self._prefixed(f"tags_to_keys:{tag}"), key) 263 | 264 | success, *rest = await tr.execute() 265 | return success 266 | 267 | async def _invalidate_tag_impl(self, tag: str): 268 | await self.invalidate_tags([tag]) 269 | 270 | async def _invalidate_tags_impl(self, tags: Sequence[str]): 271 | redis = await self._get_redis() 272 | 273 | all_keys = [] 274 | for tag in tags: 275 | tag_key = self._prefixed(f"tags_to_keys:{tag}") 276 | keys = [ 277 | *( 278 | self._prefixed(k) 279 | for k in await redis.smembers(tag_key, encoding="utf-8",) 280 | ), 281 | tag_key, 282 | ] 283 | logger.debug(f"Invalidating tag {tag}, containing the given keys: {keys}") 284 | all_keys.extend(keys) 285 | 286 | if len(all_keys) > 0: 287 | await redis.unlink(*all_keys) 288 | 289 | async def _reset_impl(self): 290 | await self._unlink_by_prefix(self._prefix) 291 | 292 | async def reset_version(self): 293 | """Delete all stored cache related keys for the current app version""" 294 | full_prefix = self._get_full_prefix() 295 | await self._unlink_by_prefix(full_prefix) 296 | 297 | async def _get_redis(self): 298 | if self._redis is None: 299 
| if aioredis is None: 300 | raise RuntimeError( 301 | "Cannot instantiate Redis backend without aioredis installed", 302 | ) 303 | else: 304 | self._redis = await aioredis.create_redis_pool( 305 | f"redis://{self._host}:{self._port}", password=self._password 306 | ) 307 | return self._redis 308 | 309 | async def _unlink_by_prefix(self, prefix: str): 310 | if prefix.endswith(":"): 311 | prefix = prefix[0:-1] 312 | redis = await self._get_redis() 313 | resp = await redis.eval( 314 | """local keys = redis.call('keys', ARGV[1]) 315 | local unpack = table.unpack or unpack 316 | for i=1,#keys,5000 do 317 | redis.call('unlink', unpack(keys, i, math.min(i+4999, #keys))) 318 | end 319 | return keys""", 320 | args=[f"{prefix}:*"], 321 | ) 322 | unlinked_keys = ", ".join(k.decode() for k in resp) 323 | logger.debug(f"Unlinked keys: {unlinked_keys}") 324 | 325 | def _prefixed(self, unprefixed_key: str) -> str: 326 | full_prefix = self._get_full_prefix() 327 | return f"{full_prefix}:{unprefixed_key}" 328 | -------------------------------------------------------------------------------- /fastapi_caching/constants.py: -------------------------------------------------------------------------------- 1 | DEFAULT_TTL: int = 60 * 60 * 24 # 1 day 2 | -------------------------------------------------------------------------------- /fastapi_caching/dependencies.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from starlette.requests import Request 4 | 5 | from .backends import CacheBackendBase 6 | from .objects import NoOpResponseCache, ResponseCache 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | __all__ = ("ResponseCacheDependency",) 11 | 12 | 13 | class ResponseCacheDependency: 14 | def __init__( 15 | self, 16 | backend: CacheBackendBase, 17 | *, 18 | no_cache_query_param: str = "no-cache", 19 | ttl: int = None, 20 | ): 21 | self._backend = backend 22 | self._no_cache_query_param = no_cache_query_param 23 | 
self._ttl = ttl 24 | 25 | async def __call__(self, request: Request) -> ResponseCache: 26 | cache = ResponseCache( 27 | self._backend, 28 | request, 29 | no_cache_query_param=self._no_cache_query_param, 30 | ttl=self._ttl, 31 | ) 32 | 33 | if not self._backend.is_enabled(): 34 | logger.debug( 35 | f"{cache.key}: Caching backend not enabled - returning no-op cache" 36 | ) 37 | return NoOpResponseCache() 38 | elif "authorization" in request.headers: 39 | logger.debug( 40 | f"{cache.key}: Authorization header set - not fetching from cache" 41 | ) 42 | return NoOpResponseCache() 43 | elif ( 44 | self._no_cache_query_param is not None 45 | and self._no_cache_query_param in request.query_params 46 | ): 47 | logger.debug( 48 | f"{cache.key}: The no-cache query parameter was specified - not " 49 | f"fetching from cache, but will update it afterwards." 50 | ) 51 | return cache 52 | else: 53 | await cache.fetch() 54 | if cache.data is None: 55 | logger.debug(f"{cache.key}: No cached response data found") 56 | else: 57 | logger.debug( 58 | f"{cache.key}: Found cached response data with " 59 | f"timestamp {cache.obj.timestamp}" 60 | ) 61 | return cache 62 | -------------------------------------------------------------------------------- /fastapi_caching/exceptions.py: -------------------------------------------------------------------------------- 1 | class CachingNotEnabled(Exception): 2 | """Raised when the caching backend is accessed while it's disabled""" 3 | -------------------------------------------------------------------------------- /fastapi_caching/hashers.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import pathlib 3 | from typing import Sequence, Union 4 | 5 | import pkg_resources 6 | 7 | __all__ = ("files_hash", "installed_packages_hash") 8 | 9 | 10 | def files_hash( 11 | *paths: Union[pathlib.Path, str], 12 | include: Sequence[str] = ("*",), 13 | digest_size: int = 4, 14 | ) -> str: 15 | """Create a 
reproducible hash (hex digest) from the file contents of the given paths 16 | 17 | Paths that point to directories will be traversed recursively. 18 | 19 | Mainly useful for cache busting scenarios. 20 | 21 | Args: 22 | paths: The directory/file paths to extract file contents from 23 | include: 24 | The types of files to include in the creation of the hash, 25 | specified as a list of file extensions. 26 | Defaults to all files. Example: ["py", "rs"] 27 | digest_size: 28 | The amount of bytes to use for the hex digest to return. 29 | Defaults to 4 which equals 8 characters. 30 | 31 | """ 32 | h = hashlib.blake2b(digest_size=digest_size) 33 | 34 | for path in paths: 35 | path = pathlib.Path(path) 36 | if path.is_dir(): 37 | for sub_path in sorted( 38 | p 39 | for file_ext in include 40 | for p in path.rglob(f"*.{file_ext}") 41 | if p.is_file() 42 | ): 43 | h.update(str(sub_path).encode()) 44 | h.update(sub_path.read_bytes()) 45 | elif path.is_file(): 46 | h.update(str(path).encode()) 47 | h.update(path.read_bytes()) 48 | else: 49 | raise RuntimeError(f"Unsupported path: {path} (does it exist?)") 50 | 51 | return h.hexdigest() 52 | 53 | 54 | def installed_packages_hash(digest_size: int = 4) -> str: 55 | """Return a reproducible hash (hex digest) of the installed python packages 56 | 57 | Args: 58 | digest_size: 59 | The amount of bytes to use for the hex digest to return. 60 | Defaults to 4 which equals 8 characters. 61 | 62 | """ 63 | packages = sorted( 64 | f"{d.project_name}=={d.version}".encode() for d in pkg_resources.working_set 65 | ) 66 | return hashlib.blake2b(b" ".join(packages), digest_size=digest_size).hexdigest() 67 | -------------------------------------------------------------------------------- /fastapi_caching/manager.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Sequence 3 | 4 | from fastapi import Depends 5 | 6 | from . 
import constants 7 | from .backends import CacheBackendBase 8 | from .dependencies import ResponseCacheDependency 9 | 10 | __all__ = ("CacheManager",) 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class CacheManager: 16 | def __init__( 17 | self, 18 | backend: CacheBackendBase, 19 | *, 20 | ttl: int = constants.DEFAULT_TTL, 21 | no_cache_query_param: str = "no-cache", 22 | ): 23 | self._backend = backend 24 | self._ttl = ttl 25 | self._no_cache_query_param = no_cache_query_param 26 | 27 | def setup( 28 | self, *, ttl: int = None, no_cache_query_param: str = None, 29 | ): 30 | if ttl is not None: 31 | self._ttl = ttl 32 | if no_cache_query_param is not None: 33 | self._no_cache_query_param = no_cache_query_param 34 | 35 | def enable(self): 36 | self._backend.enable() 37 | 38 | def disable(self): 39 | self._backend.disable() 40 | 41 | def is_enabled(self) -> bool: 42 | self._backend.is_enabled() 43 | 44 | @property 45 | def backend(self) -> CacheBackendBase: 46 | return self._backend 47 | 48 | def from_request(self, ttl: int = None) -> Depends: 49 | d = ResponseCacheDependency( 50 | self.backend, no_cache_query_param=self._no_cache_query_param, ttl=ttl, 51 | ) 52 | return Depends(d) 53 | 54 | async def invalidate_tag(self, tag: str): 55 | """Delete cache entries associated with the given tag""" 56 | await self.backend.invalidate_tag(tag) 57 | 58 | async def invalidate_tags(self, tags: Sequence[str]): 59 | """Delete cache entries associated with the given tags""" 60 | await self.backend.invalidate_tags(tags) 61 | -------------------------------------------------------------------------------- /fastapi_caching/objects.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Sequence 2 | 3 | from starlette.requests import Request 4 | 5 | from .backends import CacheBackendBase 6 | from .raw import RawCacheObject 7 | 8 | __all__ = ("ResponseCache",) 9 | 10 | 11 | class ResponseCache: 12 | """Object 
returned by ResponseCacheDependency 13 | 14 | Dependency ensures that an existing cached item for the given endpoint is 15 | automatically fetched. 16 | """ 17 | 18 | def __init__( 19 | self, 20 | backend: CacheBackendBase, 21 | request: Request, 22 | no_cache_query_param: str = "no-cache", 23 | ttl: int = None, 24 | ): 25 | self._backend = backend 26 | self._request = request 27 | self._no_cache_query_param = no_cache_query_param 28 | self._ttl = ttl 29 | self.key = self._make_key(request) 30 | self._obj = None 31 | 32 | @property 33 | def obj(self) -> RawCacheObject: 34 | return self._obj 35 | 36 | @property 37 | def data(self) -> Any: 38 | return None if self._obj is None else self._obj.data 39 | 40 | def exists(self) -> bool: 41 | """Return whether or not there's an existing cache for this response""" 42 | return self._obj is not None 43 | 44 | async def fetch(self): 45 | """Fetch and associate existing cache data""" 46 | self._obj = await self._backend.get(self.key) 47 | 48 | async def set( 49 | self, data: Any, *, ttl: int = None, tag: str = None, tags: Sequence[Any] = (), 50 | ) -> bool: 51 | tags = list(tags) 52 | if tag is not None: 53 | tags.append(tag) 54 | return await self._backend.set( 55 | key=self.key, 56 | obj=self._make_raw_cache_object(data), 57 | tags=tags, 58 | ttl=ttl or self._ttl, 59 | ) 60 | 61 | def _make_raw_cache_object(self, data: Any) -> RawCacheObject: 62 | return RawCacheObject(data) 63 | 64 | def _make_key(self, request: Request) -> str: 65 | parts = [request.method, request.url.path] 66 | for k in request.query_params.keys(): 67 | if k == self._no_cache_query_param: 68 | continue 69 | for v in request.query_params.getlist(k): 70 | parts.append(f"{k}={v}") 71 | return "|".join(sorted(parts)) 72 | 73 | 74 | class NoOpResponseCache(ResponseCache): 75 | """No-op version of the ResponseCache object returned by CacheDependency""" 76 | 77 | def __init__(self): 78 | self._obj = None 79 | 80 | async def fetch(self, *args, **kw): 81 | return 
82 | 83 | async def set(self, *args, **kw): 84 | return 85 | -------------------------------------------------------------------------------- /fastapi_caching/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobsvante/fastapi-caching/3c85f8f262813604587cc88a50cb9b22520c82f7/fastapi_caching/py.typed -------------------------------------------------------------------------------- /fastapi_caching/raw.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from datetime import datetime 3 | from typing import Any, Dict 4 | 5 | __all__ = ("RawCacheObject",) 6 | 7 | 8 | @dataclass 9 | class RawCacheObject: 10 | data: Any 11 | timestamp: datetime = field(default_factory=datetime.utcnow) 12 | meta: Dict[str, Any] = field(default_factory=dict) # For future usage 13 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | 4 | [tool.black] 5 | line-length = 88 6 | target-version = ['py37', 'py38'] 7 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_wheel] 2 | universal = 1 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | from setuptools import find_packages, setup 4 | 5 | install_requires = ["fastapi", "cachetools"] 6 | 7 | extras_require = { 8 | "redis": ["aioredis"], 9 | "examples": [ 10 | "uvicorn==0.11.5", # Logging issues with newer versions 11 | "databases[sqlite]", 12 | ], 13 | "test": [ 14 | "pytest>=5.4.0", 15 | "pytest-cov", 16 | 
"pytest-asyncio", 17 | "requests", 18 | "httpx", 19 | "fakeredis", 20 | "lupa", 21 | ], 22 | "dev": ["black", "isort", "watchgod>=0.6,<0.7"], 23 | } 24 | 25 | extras_require["all"] = [ 26 | dep for _, dependencies in extras_require.items() for dep in dependencies 27 | ] 28 | 29 | 30 | this_directory = os.path.abspath(os.path.dirname(__file__)) 31 | with open(os.path.join(this_directory, "README.md"), encoding="utf-8") as f: 32 | long_description = f.read() 33 | 34 | setup( 35 | name="fastapi-caching", 36 | version="0.3.0", 37 | description="Cache library for FastAPI with tag based invalidation", 38 | url="https://github.com/jmagnusson/fastapi-caching", 39 | author="Jacob Magnusson", 40 | author_email="m@jacobian.se", 41 | long_description=long_description, 42 | long_description_content_type="text/markdown", 43 | license="MIT", 44 | packages=find_packages(".", include=["fastapi_caching"]), 45 | python_requires=">=3.6", 46 | zip_safe=False, 47 | package_data={"fastapi_caching": ["py.typed"]}, 48 | install_requires=install_requires, 49 | extras_require=extras_require, 50 | classifiers=[ 51 | "Intended Audience :: Information Technology", 52 | "Intended Audience :: System Administrators", 53 | "Operating System :: OS Independent", 54 | "Programming Language :: Python :: 3", 55 | "Programming Language :: Python", 56 | "Topic :: Internet", 57 | "Topic :: Software Development :: Libraries :: Python Modules", 58 | "Topic :: Software Development :: Libraries", 59 | "Topic :: Software Development", 60 | "Typing :: Typed", 61 | "Development Status :: 2 - Pre-Alpha", 62 | "Environment :: Web Environment", 63 | "Framework :: AsyncIO", 64 | "Intended Audience :: Developers", 65 | "License :: OSI Approved :: MIT License", 66 | "Programming Language :: Python :: 3 :: Only", 67 | "Programming Language :: Python :: 3.6", 68 | "Programming Language :: Python :: 3.7", 69 | "Programming Language :: Python :: 3.8", 70 | "Topic :: Internet :: WWW/HTTP", 71 | ], 72 | ) 73 | 
-------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobsvante/fastapi-caching/3c85f8f262813604587cc88a50cb9b22520c82f7/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from fastapi import FastAPI 3 | from httpx import AsyncClient 4 | 5 | from . import helpers 6 | 7 | 8 | @pytest.fixture 9 | def app(): 10 | return FastAPI() 11 | 12 | 13 | @pytest.fixture 14 | async def async_client(app): 15 | async with AsyncClient(app=app, base_url="http://127.0.0.1") as ac: 16 | yield ac 17 | 18 | 19 | @pytest.fixture 20 | def redis_backend(): 21 | return helpers.make_redis_backend() 22 | 23 | 24 | @pytest.fixture 25 | def inmem_backend(): 26 | return helpers.make_inmemory_backend() 27 | -------------------------------------------------------------------------------- /tests/helpers.py: -------------------------------------------------------------------------------- 1 | import aioredis 2 | from fakeredis.aioredis import FakeConnectionsPool 3 | 4 | from fastapi_caching import InMemoryBackend, RedisBackend 5 | 6 | 7 | def make_inmemory_backend(): 8 | return InMemoryBackend() 9 | 10 | 11 | def make_redis_backend(): 12 | pool = FakeConnectionsPool(minsize=1, maxsize=10) 13 | return RedisBackend(redis=aioredis.Redis(pool)) 14 | 15 | 16 | def make_caching_backends(): 17 | return (make_inmemory_backend(), make_redis_backend()) 18 | -------------------------------------------------------------------------------- /tests/test_backends.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from fastapi_caching import CachingNotEnabled, InMemoryBackend, RedisBackend 4 | 5 | from . 
import helpers 6 | 7 | 8 | @pytest.mark.asyncio 9 | @pytest.mark.parametrize("cache_backend", helpers.make_caching_backends()) 10 | async def test_that_cache_can_be_set(cache_backend): 11 | was_set = await cache_backend.set("hello_key", "world") 12 | assert was_set is True 13 | 14 | cache_obj = await cache_backend.get("hello_key") 15 | 16 | assert cache_obj.data == "world" 17 | 18 | 19 | @pytest.mark.asyncio 20 | @pytest.mark.parametrize("cache_backend", helpers.make_caching_backends()) 21 | async def test_that_cache_does_not_return_data_on_non_existing_keys(cache_backend): 22 | data = await cache_backend.get("hello_key") 23 | assert data is None 24 | 25 | 26 | @pytest.mark.asyncio 27 | @pytest.mark.parametrize("cache_backend", helpers.make_caching_backends()) 28 | async def test_that_cache_can_be_cleared(cache_backend): 29 | await cache_backend.set("a", "b") 30 | await cache_backend.set("c", "d") 31 | 32 | await cache_backend.reset() 33 | 34 | a_obj = await cache_backend.get("a") 35 | b_obj = await cache_backend.get("b") 36 | 37 | assert a_obj is None 38 | assert b_obj is None 39 | 40 | 41 | def test_that_redis_backend_can_be_configured_lazily(): 42 | backend = RedisBackend() 43 | backend.setup(prefix="my-cool-app") 44 | assert backend._get_full_prefix() == "my-cool-app" 45 | 46 | 47 | def test_that_inmemory_backend_can_be_configured_lazily(): 48 | backend = InMemoryBackend() 49 | backend.setup(maxsize=2) 50 | 51 | 52 | @pytest.mark.asyncio 53 | async def test_that_exception_is_raised_for_disabled_backend(): 54 | cache_backend = InMemoryBackend() 55 | cache_backend.disable() 56 | with pytest.raises(CachingNotEnabled): 57 | await cache_backend.set("a", "b") 58 | with pytest.raises(CachingNotEnabled): 59 | await cache_backend.get("a") 60 | with pytest.raises(CachingNotEnabled): 61 | await cache_backend.invalidate_tag("foo") 62 | with pytest.raises(CachingNotEnabled): 63 | await cache_backend.invalidate_tags(["foo", "bar"]) 64 | with pytest.raises(CachingNotEnabled): 
65 | await cache_backend.invalidate_tags(["foo", "bar"]) 66 | with pytest.raises(CachingNotEnabled): 67 | await cache_backend.reset() 68 | -------------------------------------------------------------------------------- /tests/test_dependencies.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from fastapi_caching import CacheManager, InMemoryBackend, ResponseCache 4 | from fastapi_caching.objects import NoOpResponseCache 5 | 6 | 7 | @pytest.mark.asyncio 8 | async def test_that_response_cache_can_be_set(app, async_client): 9 | cache_backend = InMemoryBackend() 10 | cache_manager = CacheManager(cache_backend) 11 | 12 | cached_object = await cache_backend.get("/|GET") 13 | assert cached_object is None 14 | 15 | @app.get("/") 16 | async def get_products(rcache: ResponseCache = cache_manager.from_request()): 17 | assert rcache.__class__ is ResponseCache 18 | resp_data = [{"foo": "bar"}] 19 | await rcache.set(resp_data) 20 | return resp_data 21 | 22 | resp = await async_client.get("/") 23 | assert resp.status_code == 200 24 | 25 | cached_object = await cache_backend.get("/|GET") 26 | assert cached_object.data == [{"foo": "bar"}] 27 | 28 | 29 | @pytest.mark.asyncio 30 | async def test_that_cache_isnt_fetched_from_disabled_backend(app, async_client): 31 | cache_backend = InMemoryBackend() 32 | cache_manager = CacheManager(cache_backend) 33 | cache_manager.disable() 34 | 35 | @app.get("/") 36 | async def home(rcache: ResponseCache = cache_manager.from_request()): 37 | assert rcache.__class__ is NoOpResponseCache 38 | 39 | await async_client.get("/") 40 | -------------------------------------------------------------------------------- /tests/test_manager.py: -------------------------------------------------------------------------------- 1 | from fastapi_caching import CacheManager, InMemoryBackend 2 | 3 | 4 | def test_that_ttl_can_be_set_lazily(): 5 | cache_backend = InMemoryBackend() 6 | cache_manager = 
CacheManager(cache_backend) 7 | cache_manager.setup(ttl=1337) 8 | 9 | assert cache_manager.backend is not None 10 | --------------------------------------------------------------------------------