├── .gitignore
├── README.md
├── data_fetcher
    ├── __init__.py
    ├── core.py
    ├── extras.py
    ├── global_request_context.py
    ├── middleware.py
    ├── shorthand_fetcher_classes.py
    └── util.py
├── manage.py
├── pyproject.toml
├── pytest.ini
├── pytest_test_runner.py
├── requirements.txt
├── sample_app
    ├── __init__.py
    ├── data_factories.py
    ├── migrations
    │   ├── 0001_initial.py
    │   └── __init__.py
    ├── models.py
    ├── settings.py
    ├── urls.py
    ├── views.py
    └── wsgi.py
├── setup.py
└── tests
    ├── __init__.py
    ├── conftest.py
    ├── test_cache_within_request.py
    ├── test_data_fetchers.py
    ├── test_keyed_fetcher.py
    ├── test_middleware.py
    ├── test_model_fetchers.py
    └── test_singleton_fetcher.py


/.gitignore:
--------------------------------------------------------------------------------
venv/
*.pyc
__pycache__
*#
*.swp
*.swo
.DS_Store
githubtoken
.vscode/settings.json
.ipynb_checkpoints/
.coverage
htmlcov/
dist/
*.egg-info
publish.sh
notes.md

db.sqlite3

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# django-data-fetcher

## What is this?

This library contains 3 utilities:

1. a function to access the currently-executing django request from anywhere
    - This can help with a wide variety of fringe use-cases, e.g. wanting to access the current request from a decoupled log helper function
2. A caching abstraction to cache a function's return value on the request
    - This is the most frequently useful thing in this package
3. A _data-fetcher_ abstraction to ease batching
    - In addition to caching, this helps with _batching_, i.e. N+1 queries
    - This is the least important, most complicated and most specific use case. For historical reasons, the entire library is named after this part.
Naming things is hard :(



## Installation

```bash
pip install django-data-fetcher
```

After installing, you'll need to add our middleware:


```python
# settings.py
# ...
MIDDLEWARE = [
    # ...
    "data_fetcher.middleware.GlobalRequestMiddleware",
    # ...
]
```

## Usage


### Accessing the global request object

Thanks to the middleware, accessing the global request is simple:

```python
from data_fetcher.util import get_request

def some_random_util_function():
    request = get_request()
    do_something_with_request(request)
```


### Caching

If you'd just like to cache a function so you don't repeat it in different places, you can just use the `cache_within_request` decorator:

```python
from data_fetcher import cache_within_request

@cache_within_request
def get_most_recent_order(user_id):
    return Order.objects.filter(user_id=user_id).order_by('-created_at').first()
```

Now you can call `get_most_recent_order` as many times as you want within a request, e.g. in template helpers and in views, and it will only hit the database once (assuming you use the same user_id). This is a wrapper around `functools.cache`, so it will also cache across calls to the same function with the same arguments.

### Batching

This library also supports _batching_ fetching logic. You need to subclass our `DataFetcher` class and implement a `batch_load` (or `batch_load_dict`) method with the batching logic. Then you can use its factory method to get an instance of your fetcher class, and call its `get()`, `get_many()`, or `prefetch_keys()` methods.

For example, it's usually pretty difficult to efficiently fetch permissions for a list of objects without coupling your view to your templates/helpers.
With this library, we can offload the work to a data-fetcher instance and have a re-usable template-helper that checks permissions. When we notice performance problems, we simply add a `prefetch_keys` call to our view to pre-populate the cache.

```python
# my_app/fetchers.py

from data_fetcher import DataFetcher

class ArticlePermissionFetcher(DataFetcher):
    def batch_load_dict(self, article_ids):
        permissions = ArticlePermission.objects.filter(article_id__in=article_ids)
        return {p.article_id: p for p in permissions}

# my_app/template_helpers.py

from django import template

from my_app.fetchers import ArticlePermissionFetcher

register = template.Library()

@register.simple_tag(takes_context=True)
def can_read_article(context, article):
    """
    called in a loop in article_list.html, e.g. to conditionally render a link
    """
    fetcher = ArticlePermissionFetcher.get_instance()
    permission = fetcher.get(article.id)
    return permission.can_read

# my_app/views.py

from django.shortcuts import render

from my_app.fetchers import ArticlePermissionFetcher

def article_list(request):
    articles = Article.objects.all()
    fetcher = ArticlePermissionFetcher.get_instance()
    fetcher.prefetch_keys([a.id for a in articles])
    return render(request, 'article_list.html', {'articles': articles})

```

Behind the scenes, fetchers' `get_instance` will use the global-request middleware's request object to always return the same instance of the fetcher for the same request. This allows the fetcher to call your batch function once, when the view calls `prefetch_keys`, and then use the cached results for all subsequent calls to `get` or `get_many`.

Fetchers also cache values that were requested through `get` or `get_many`. If you request a key that isn't cached, the fetcher will call your batch method again for that single key.
It's recommended to monitor your queries while developing with a tool like [django-debug-toolbar](https://github.com/jazzband/django-debug-toolbar/).


#### Fetcher API


Public methods:

- `get(key)`: fetch a single resource by key
- `get_many(keys)`: fetch multiple resources by key, returns a list
- `get_many_as_dict(keys)`: like `get_many`, but returns a dict indexed by your requested keys
- `prefetch_keys(keys)`: like `get_many`, but returns nothing. Pre-populates the cache with a list of keys. This is useful when you know you're going to need a lot of objects, and you want to avoid N+1 queries.
- `prime(key, value)`: manually set a value in the cache. This isn't recommended, but it can be useful for performance in certain cases.
- `enqueue_keys(keys)`: adds keys to a queue, which gets fetched the next time `get`, `get_many` or `prefetch_keys` is called. It is often more convenient to use this than to collect all required keys and call `prefetch_keys`.
- `get_lazy(key)` / `get_many_lazy(keys)`: (experimental) enqueues the key(s) and returns a lazy object wrapper. The lazy object's `get()` method will return the value when called. This API might be replaced with smarter lazy objects in the future.

Subclass API:

You can implement `batch_load(keys)` OR `batch_load_dict(keys)`.
- `batch_load(keys)` needs to return a list of resources in the same order (and length) as the keys. If a resource is missing, you need an explicit `None` in the returned list.
- `batch_load_dict(keys)` should return a dict of resources, indexed by the keys. If a value is missing, `None` will be returned when that key is requested (it tolerates missing keys).


## Shortcuts

It's extremely common to want to fetch a single object by id, or a list of child objects by a parent's foreign key.
We provide a few base classes for this:

```python
from data_fetcher import AbstractModelByIdFetcher, AbstractChildModelByAttrFetcher

class ArticleByIdFetcher(AbstractModelByIdFetcher):
    model = Article

class ArticleByAuthorIdFetcher(AbstractChildModelByAttrFetcher):
    model = Article
    attr = 'author_id'

```

In fact, the ID fetcher was so common we have a factory for it. This factory returns the same class every time, so you can use it in multiple places without worrying about creating multiple classes with distinct caches.

```python
from data_fetcher import PrimaryKeyFetcherFactory

ArticleByIdFetcher = PrimaryKeyFetcherFactory.get_model_by_id_fetcher(Article)
ArticleByIdFetcher2 = PrimaryKeyFetcherFactory.get_model_by_id_fetcher(Article)
assert ArticleByIdFetcher == ArticleByIdFetcher2

article_1 = ArticleByIdFetcher.get_instance().get(1)
```


## Testing data-fetchers

Batch logic is often complex and error-prone. We recommend writing tests for your fetchers. This package provides a `GlobalRequest` context-manager that installs a mock request object you can use to test your fetchers. Without this context-manager, your fetchers won't be able to cache anything and might raise errors. Here's an example in pytest:

```python
from data_fetcher.util import GlobalRequest

def test_article_permission_fetcher(django_assert_num_queries):
    with GlobalRequest():
        with django_assert_num_queries(1):
            fetcher = ArticlePermissionFetcher.get_instance()
            fetcher.prefetch_keys([1, 2, 3])
            assert fetcher.get(1).can_read
            assert not fetcher.get(2).can_read
            assert not fetcher.get(3).can_read

```

Note that this context-manager also allows you to use the cache decorator and data-fetchers inside other scenarios, such as celery tasks.
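
As a rough illustration of what the single-query assertion in the test above is checking, here is a dependency-free sketch of the caching flow. `CountingFetcher` and its `batch_calls` list are hypothetical stand-ins, not part of this package; the class only mimics how `prefetch_keys`, `get_many` and `get` share one cache, so it runs without Django or a request context.

```python
class CountingFetcher:
    """Simplified, hypothetical stand-in for a data-fetcher's caching flow."""

    def __init__(self):
        self._cache = {}
        self.batch_calls = []  # records the keys passed to every batch call

    def batch_load_dict(self, keys):
        # A real fetcher would run one query here; key 3 is "missing" on purpose.
        self.batch_calls.append(sorted(keys))
        return {key: f"permission-{key}" for key in keys if key != 3}

    def get_many(self, keys):
        uncached = [key for key in keys if key not in self._cache]
        if uncached:
            loaded = self.batch_load_dict(uncached)
            for key in uncached:
                # Missing keys are cached as None, so they are not re-fetched.
                self._cache[key] = loaded.get(key)
        return [self._cache[key] for key in keys]

    def prefetch_keys(self, keys):
        self.get_many(keys)

    def get(self, key):
        return self.get_many([key])[0]


fetcher = CountingFetcher()
fetcher.prefetch_keys([1, 2, 3])

# These reads are served from the cache: still only one batch call.
assert fetcher.get(1) == "permission-1"
assert fetcher.get(3) is None
assert fetcher.batch_calls == [[1, 2, 3]]

# An uncached key triggers a second batch call for that single key.
fetcher.get(4)
assert fetcher.batch_calls == [[1, 2, 3], [4]]
```

Counting batch calls this way plays the same role as `django_assert_num_queries` when the fetcher under test does not touch the database.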
## How to provide non-key data to fetchers

Data-fetcher's main feature is not performance, but enabling decoupling. The view layer no longer has to be responsible for passing data to downstream consumers (e.g. utils, template-helpers, service-objects, etc.).

This paradigm shift can be a challenging adjustment. For instance, our `ArticlePermissionFetcher` above was naïve. Permission records should be fetched with respect to a user. How can we provide the user's ID to the fetcher?

It's tempting to subclass DataFetcher and add a user argument to its `get_instance()` method. Unfortunately, extending the factory pattern is rather complex. There are broadly 3 different ways to solve this problem:

1. Use the global request to get the user. This is the simplest solution, but it limits your data-fetcher to the current user. You couldn't, for example, build a view that shows a list of articles available to _other_ users.
2. Create composite-keys: instead of loading permissions by article id, you load them by `(user_id, article_id)` pairs. This is a good solution, but is often complex to implement and you usually don't need this flexibility.
3. Dynamically create a data-fetcher _class_ that has a reference to the user

The 3rd solution fulfills the OOP temptation of adding a user argument to the constructor, but it's a "higher-order" solution. Rather than attaching the user to the fetcher-class, we would dynamically create a class that has a reference to the user, and then use a factory to ensure we recycle the same class for the same user.

There's a builtin shortcut for this pattern, too, called `ValueBoundDataFetcher`. ValueBoundDataFetcher classes have a `bound_value` attribute available inside their batch-load methods.
```python
# my_app/fetchers.py
from data_fetcher import ValueBoundDataFetcher

class UserBoundArticlePermissionFetcher(ValueBoundDataFetcher):
    def batch_load_dict(self, article_ids):
        user_id = self.bound_value
        permissions = ArticlePermission.objects.filter(user_id=user_id, article_id__in=article_ids)
        return { p.article_id: p for p in permissions }

# my_app/views.py
from my_app.fetchers import UserBoundArticlePermissionFetcher

def article_list(request):
    # generate a class that has a reference to the user;
    # get_value_bound_class is a classmethod, so call it on the fetcher class itself
    BoundFetcherClass = UserBoundArticlePermissionFetcher.get_value_bound_class(
        request.user.id
    )
    fetcher = BoundFetcherClass.get_instance()
    articles = Article.objects.all()
    fetcher.prefetch_keys([a.id for a in articles])
    return render(request, 'article_list.html', {'articles': articles})

```

With this solution, we're still able to create fetchers for multiple users. However, it won't be as efficient as the composite-key solution (e.g. one query per user vs. one query for all users).

Note that `bound_value` can be anything, so you can use this pattern to provide more than a single piece of data to your fetcher. Just make sure it's hashable so it can be used as a key (otherwise, you'll want to pass separate `key` and `value` arguments to `get_value_bound_class`).


## Recipe: Caching a single data-structure with complex data

Batching logic often has a high cognitive load, so it may not be worth it to batch everything. Fortunately, the `@cache_within_request` decorator can cache anything; there's no need to restrict ourselves to a single resource. For instance, let's say we have a complex home-feed page that needs to fetch a lot of data for a particular user.
We can use the `cache_within_request` decorator to cache the entire data-structure.


```python
from django.db.models import Prefetch

from data_fetcher import cache_within_request

@cache_within_request
def get_home_feed_data(user_id):
    # .get() evaluates the queryset and returns the single user instance
    user = User.objects.filter(id=user_id).prefetch_related(
        Prefetch('articles', queryset=Article.objects.filter(deleted=False), to_attr='article_list'),
        # nested lookups must traverse the to_attr name of the parent prefetch
        Prefetch('article_list__comments', queryset=Comment.objects.filter(deleted=False), to_attr='comment_list'),
        Prefetch('article_list__author', queryset=User.objects.all(), to_attr='author_list'),
        # ...
    ).get()
    more_data = get_more_data(user)
    # assemble a rich data structure with convenient API
    return {
        'user': user,
        'articles': user.article_list,
        'comments': flatten([article.comment_list for article in user.article_list]),
        'articles_by_id': ...,
        'comments_by_article_id': ...,
        'comments_by_id': ...,
        # ...
    }

```

Now any function can request the entire structure and use its rich API. We can isolate the ugly fetching logic and don't need to pass data around (e.g. view -> template -> helpers) to remain efficient.

This is not a perfect approach, as it couples our consumers (e.g. views, helpers) to this data-structure. This makes it difficult to re-use those helpers, or parts of the data-structure. However, in a pinch, it may be preferable to setting up fetchers (e.g. article-by-id, comments-by-article-id) for every atomic piece of data. A neat compromise might be to split this up into multiple cache functions, or a class that executes other cached functions lazily.


## Async

Like most ORM-consuming code, data-fetcher is synchronous. You'll need to use `sync_to_async` to use it inside async views. Behind the scenes, the global-request middleware uses context-vars, which are both thread-safe and async-safe.
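
The context-var behaviour can be seen with the standard library alone. The following is a minimal sketch, not this package's code: `current_request`, `read_request_sync` and `handle` are made-up names, and `asyncio.to_thread` stands in for `sync_to_async` so that the example has no third-party dependencies.

```python
import asyncio
import contextvars

# Hypothetical stand-in for the module-level ContextVar that holds the request.
current_request = contextvars.ContextVar("request", default=None)


def read_request_sync():
    # Synchronous code (e.g. an ORM-consuming helper) reading the "global" request.
    return current_request.get()


async def handle(request_name):
    # Each asyncio task runs in its own copy of the context, so this set()
    # is invisible to the other handler running concurrently.
    current_request.set(request_name)
    await asyncio.sleep(0)  # yield control so the two handlers interleave
    # asyncio.to_thread propagates the current context to the worker thread.
    return await asyncio.to_thread(read_request_sync)


async def main():
    return await asyncio.gather(handle("request-a"), handle("request-b"))


results = asyncio.run(main())
print(results)
```

Each handler reads back its own value even though both run concurrently and hop to a worker thread, which is the property the request-scoped caches rely on.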
## Cache invalidation

You can probably ignore cache invalidation, since the cache is cleared at the end of each request. However, if you change data that has been cached and want updated data during the same request, you can use the `clear_request_caches` function. This will clear all data-fetchers and `@cache_within_request` caches.

```python
from data_fetcher.util import clear_request_caches

def update_article(request, article_id):
    article = ArticleFetcher.get_instance().get(article_id)
    article.title = 'new title'
    article.save()

    # in case render_page uses the article-fetcher,
    # we clear all data-fetchers
    clear_request_caches()
    return render_page(article_id)
```

--------------------------------------------------------------------------------
/data_fetcher/__init__.py:
--------------------------------------------------------------------------------
from data_fetcher.middleware import GlobalRequest

from .core import DataFetcher
from .extras import ValueBoundDataFetcher, cache_within_request
from .shorthand_fetcher_classes import (
    AbstractChildModelByAttrFetcher,
    AbstractModelByIdFetcher,
    PrimaryKeyFetcherFactory,
)
from .util import get_datafetcher_request_cache

--------------------------------------------------------------------------------
/data_fetcher/core.py:
--------------------------------------------------------------------------------
from collections import defaultdict

from .util import MissingRequestContextException, get_datafetcher_request_cache


class BaseDataFetcher:
    def __init__(self):
        self._cache = {}
        self._queue = set()

    def get(self, key):
        # fetch the requested key together with anything already enqueued
        all_keys = {key, *self._queue}

        self.prefetch_keys(all_keys)

        return self._cache[key]

    def get_many(self, keys):
        all_keys = {*keys, *self._queue}

        # only keys that are not cached yet reach the batch function
        uncached_keys = [key for key in all_keys if key not in self._cache]
        if uncached_keys:
            self._get_many_uncached_values(uncached_keys)

        return [self._cache.get(key) for key in keys]

    def prefetch_keys(self, keys):
        self.get_many(keys)

    def get_many_as_dict(self, keys):
        return dict(zip(keys, self.get_many(keys)))

    def _get_single_uncached_value(self, key):
        return self.batch_load_and_cache([key])[0]

    def _get_many_uncached_values(self, keys):
        return self.batch_load_and_cache(keys)

    def _batch_load_fn(self, keys):
        # subclasses implement either batch_load or batch_load_dict
        if getattr(self, "batch_load", None):
            return self.batch_load(keys)
        elif getattr(self, "batch_load_dict", None):
            value_dict = self.batch_load_dict(keys)
            # keys missing from the dict resolve to None
            return [value_dict.get(key) for key in keys]
        else:
            raise NotImplementedError(
                "must implement batch_load or batch_load_dict"
            )

    def batch_load_and_cache(self, keys):
        values = self._batch_load_fn(keys)
        for key, value in zip(keys, values):
            self._cache[key] = value
        return values

    def prime(self, key, value):
        self._cache[key] = value

    def enqueue_keys(self, keys):
        self._queue.update(set(keys))

    def fetch_queued(self):
        queued_keys = set(self._queue)
        self.get_many(queued_keys)

    def get_lazy(self, key):
        self.enqueue_keys([key])
        return LazyFetchedValue(lambda: self.get(key))

    def get_many_lazy(self, keys):
        self.enqueue_keys(keys)
        return LazyFetchedValue(lambda: self.get_many(keys))


class LazyFetchedValue:
    def __init__(self, get_val):
        self.get_val = get_val

    def get(self):
        return self.get_val()


class DataFetcher(BaseDataFetcher):
    """
    Factory for creating composable datafetchers
    """

    # this variable can be used for composition
    datafetcher_instance_cache = None

    __create_key = object()

    def __init__(self, create_key):
        # Hacky way to make constructor "private"
        assert (
            create_key == DataFetcher.__create_key
        ), "Never create data-fetcher instances directly, use get_instance"

        super().__init__()

    @classmethod
    def get_instance(cls, raise_on_no_context=False):
        try:
            fetcher_instance_cache = get_datafetcher_request_cache()
        except MissingRequestContextException as e:
            if raise_on_no_context:
                raise e
            else:
                # no request context: fall back to a throwaway cache
                fetcher_instance_cache = {}

        if cls not in fetcher_instance_cache:
            fetcher_instance_cache[cls] = cls(DataFetcher.__create_key)

        return fetcher_instance_cache[cls]

--------------------------------------------------------------------------------
/data_fetcher/extras.py:
--------------------------------------------------------------------------------
from functools import cache, wraps

from .core import DataFetcher
from .util import MissingRequestContextException, get_datafetcher_request_cache


class CacheDecoratorException(Exception):
    pass


def cache_within_request(fn):
    """
    ensure a function's values are cached for the duration of a request

    """

    if isinstance(fn, classmethod):
        raise CacheDecoratorException(
            "apply the classmethod decorator after (above) the cache_within_request decorator"
        )

    @wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            datafetcher_cache = get_datafetcher_request_cache()
        except MissingRequestContextException:
            print(
                f"WARNING: calling {fn.__name__} outside of a request context,"
                " caching is disabled"
            )
            return fn(*args, **kwargs)

        # use function itself as key
        if fn not in datafetcher_cache:
            datafetcher_cache[fn] = cache(fn)

        return datafetcher_cache[fn](*args, **kwargs)

    return wrapper


class ValueBoundFetcherFactory:
    datafetcher_classes_by_key = {}

    @staticmethod
    def _create_datafetcher_cls_for_keyval(
        parent_cls,
        key,
        value=None,
    ):
        if value is None:
            value = key

        return type(
            f"{parent_cls.__name__}__{key}",
            (parent_cls,),
            dict(
                bound_value=value,
            ),
        )

    @classmethod
    def get_fetcher_by_key(cls, parent_cls, key, value=None):
        """
        This ensures the same _class_ for a single key can only be created once

        datafetcher class will 'provide'
        the value. If key matches an already generated class, returns that class

        value argument only necessary if you want to attach an unhashable value
        """

        dict_key = (parent_cls, key)

        if dict_key in cls.datafetcher_classes_by_key:
            return cls.datafetcher_classes_by_key[dict_key]
        else:
            fetcher = cls._create_datafetcher_cls_for_keyval(
                parent_cls, key, value
            )
            cls.datafetcher_classes_by_key[dict_key] = fetcher
            return fetcher


class ValueBoundDataFetcher(DataFetcher):
    """
    To be used as a parent class for keyed-datafetchers

    The most common use case for ValueBoundDataFetcher
    is providing a user_id to a data-fetcher

    """

    def __init__(self, *args, **kwargs):
        if not getattr(self, "bound_value", None):
            raise MissingRequestContextException(
                "ValueBoundDataFetcher subclasses must be created "
                "through get_value_bound_class before being instantiated"
            )
        super().__init__(*args, **kwargs)

    @classmethod
    def get_value_bound_class(cls, key, value=None):
        """
        if value not hashable, provide key first, then value
        otherwise just provide the value as 'key'
        """
        # forward the caller's value; it was previously dropped (value=None)
        return ValueBoundFetcherFactory.get_fetcher_by_key(
            cls, key, value=value
        )

--------------------------------------------------------------------------------
/data_fetcher/global_request_context.py:
--------------------------------------------------------------------------------
import contextvars

from django.http import HttpRequest

storage = contextvars.ContextVar("request", default=None)


class GlobalRequest:
    """
    get_request() will return the same object
    within this ctx-manager's block

    """

    def __init__(self, request=None):
        self.new_request = request or HttpRequest()
        self.old_request = storage.get()

    def __enter__(self):
        storage.set(self.new_request)
        return storage.get()

    def __exit__(self, *args, **kwargs):
        # restore whatever request was active before entering the block
        storage.set(self.old_request)


def get_request():
    return storage.get()

--------------------------------------------------------------------------------
/data_fetcher/middleware.py:
--------------------------------------------------------------------------------
from .global_request_context import GlobalRequest


class GlobalRequestMiddleware:

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        with GlobalRequest(request=request):
            return self.get_response(request)

--------------------------------------------------------------------------------
/data_fetcher/shorthand_fetcher_classes.py:
--------------------------------------------------------------------------------
from collections import defaultdict

from .core import DataFetcher


class AbstractModelByIdFetcher(DataFetcher):
    model = None  # override this part

    @classmethod
    def batch_load_dict(cls, ids):
        records = list(cls.model.objects.filter(pk__in=ids))
        # key by pk so models with a custom primary key also work
        return {record.pk: record for record in records}

    def get_all(self, queryset=None):
        # fall back to every row when no queryset is supplied;
        # previously a supplied queryset left `records` undefined
        if queryset is None:
            queryset = self.model.objects.all()
        records = [*queryset]
        for r in records:
            self.prime(r.pk, r)

        return records


class PrimaryKeyFetcherFactory:
""" 25 | This ensures the same _class_ for a single model can only be created once. 26 | This is because some consumers dynamically create data-fetchers based on models not yet known 27 | """ 28 | 29 | datafetcher_classes_by_model = {} 30 | 31 | @staticmethod 32 | def _create_datafetcher_cls_for_model(model_cls): 33 | return type( 34 | f"{model_cls.__name__}ByIDFetcher", 35 | (AbstractModelByIdFetcher,), 36 | dict(model=model_cls), 37 | ) 38 | 39 | @classmethod 40 | def get_model_by_id_fetcher(cls, model_cls): 41 | if model_cls in cls.datafetcher_classes_by_model: 42 | return cls.datafetcher_classes_by_model[model_cls] 43 | else: 44 | fetcher = cls._create_datafetcher_cls_for_model(model_cls) 45 | cls.datafetcher_classes_by_model[model_cls] = fetcher 46 | return fetcher 47 | 48 | 49 | class AbstractChildModelByAttrFetcher(DataFetcher): 50 | """ 51 | Loads many records by a single attr, use this to create child-by-parent-id loaders 52 | """ 53 | 54 | model = None # override this part 55 | attr = None # override this part 56 | 57 | @classmethod 58 | def batch_load(cls, attr_values): 59 | records = list( 60 | cls.model.objects.filter(**{f"{cls.attr}__in": attr_values}) 61 | ) 62 | by_attr = defaultdict(list) 63 | for record in records: 64 | by_attr[getattr(record, cls.attr)].append(record) 65 | 66 | return [by_attr[attr_val] for attr_val in attr_values] 67 | -------------------------------------------------------------------------------- /data_fetcher/util.py: -------------------------------------------------------------------------------- 1 | # from data_fetcher.middleware import get_request 2 | from .global_request_context import GlobalRequest, get_request 3 | 4 | 5 | class MissingRequestContextException(Exception): 6 | pass 7 | 8 | 9 | def get_datafetcher_request_cache(): 10 | request = get_request() 11 | if not request: 12 | raise MissingRequestContextException( 13 | "No request is available, don't use datafetchers outside of a request context" 14 | ) 15 | 16 | if 
not hasattr(request, "datafetcher_cache"): 17 | request.datafetcher_cache = {} 18 | 19 | return request.datafetcher_cache 20 | 21 | 22 | def clear_request_caches(): 23 | """ 24 | Clears all cached values for datafetchers 25 | 26 | Only necessary when a request wants data it has modified 27 | 28 | Also clears the functions cached with cache_within_request decorator 29 | """ 30 | request = get_request() 31 | if request and hasattr(request, "datafetcher_cache"): 32 | # reset the cache to an empty dict 33 | request.datafetcher_cache = {} 34 | 35 | 36 | def clear_datafetchers(): 37 | """ 38 | clears request cache 39 | old API, prefer clear_request_caches() 40 | """ 41 | clear_request_caches() 42 | -------------------------------------------------------------------------------- /manage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Django's command-line utility for administrative tasks.""" 3 | import os 4 | import sys 5 | 6 | 7 | def main(): 8 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "sample_app.settings") 9 | try: 10 | from django.core.management import execute_from_command_line 11 | except ImportError as exc: 12 | raise ImportError( 13 | "Couldn't import Django. Are you sure it's installed and " 14 | "available on your PYTHONPATH environment variable? Did you " 15 | "forget to activate a virtual environment?" 
16 | ) from exc 17 | execute_from_command_line(sys.argv) 18 | 19 | 20 | if __name__ == "__main__": 21 | main() 22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 79 3 | target-version = ['py310'] 4 | include = '\.pyi?$' 5 | exclude = ''' 6 | 7 | ( 8 | /( 9 | \.eggs # exclude a few common directories in the 10 | | \.git # root of the project 11 | | \.mypy_cache 12 | | venv 13 | )/ 14 | # ignore auto-generated migrations, no one ever saves those manually 15 | | migrations/.*.py 16 | | migrations/.*_initial.py 17 | | migrations/.*_auto_.*.py 18 | ) 19 | ''' 20 | 21 | 22 | [tool.isort] 23 | profile = "black" 24 | line_length = 79 25 | known_django = "django" 26 | sections=[ "FUTURE", "STDLIB", "DJANGO", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER" ] 27 | skip_glob=['*/migrations/*.py'] 28 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | # -- FILE: pytest.ini (or tox.ini) 2 | [pytest] 3 | DJANGO_SETTINGS_MODULE = sample_app.settings 4 | python_files = tests.py test_*.py *_tests.py 5 | norecursedirs = .git env venv 6 | addopts = -p no:warnings -s 7 | junit_family = xunit2 -------------------------------------------------------------------------------- /pytest_test_runner.py: -------------------------------------------------------------------------------- 1 | class PytestTestRunner: 2 | """Runs pytest to discover and run tests.""" 3 | 4 | @classmethod 5 | def add_arguments(cls, parser): 6 | parser.set_defaults(keepdb=True) 7 | 8 | # remaps to the -k cli arg of pytest, useful for test selection 9 | parser.add_argument( 10 | "-s", 11 | "--select", 12 | help="remaps to -k test-selection argument in pytest", 13 | ) 14 | parser.add_argument( 15 | "--junit-xml", 16 | dest="junitXml", 17 
| help="remaps the junit xml argument for pytest", 18 | ) 19 | 20 | def __init__( 21 | self, 22 | verbosity=1, 23 | failfast=False, 24 | keepdb=True, 25 | select=None, 26 | junitXml=None, 27 | **kwargs, 28 | ): 29 | self.verbosity = verbosity 30 | self.failfast = failfast 31 | self.keepdb = keepdb 32 | self.select = select 33 | self.junitXml = junitXml 34 | 35 | def run_tests(self, test_labels): 36 | """Run pytest and return the exitcode. 37 | 38 | It translates some of Django's test command option to pytest's. 39 | """ 40 | import pytest 41 | 42 | argv = [] 43 | if self.select is not None: 44 | argv.append(f"-k {self.select}") 45 | if self.verbosity == 0: 46 | argv.append("--quiet") 47 | if self.verbosity == 2: 48 | argv.append("--verbose") 49 | if self.verbosity == 3: 50 | argv.append("-vv") 51 | if self.failfast: 52 | argv.append("--exitfirst") 53 | if self.keepdb: 54 | argv.append("--reuse-db") 55 | if self.junitXml: 56 | argv.append(f"--junit-xml={self.junitXml}") 57 | 58 | argv.extend(test_labels) 59 | return pytest.main(argv) 60 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | colored-traceback==0.3.0 # used for tests, but required my manage.py and therefore prod scripts 2 | Django==4.0.0 3 | django-extensions==3.1.1 4 | Faker==4.0.0 5 | ipython==8.10.0 6 | 7 | # dev 8 | coverage==5.1 9 | django-debug-toolbar==3.7 10 | django-graphiql-debug-toolbar==0.2.0 11 | black==22.3.0 12 | pytest==7.1.2 13 | pytest-django==4.5.2 14 | factory-boy===2.12.0 15 | isort===5.7.0 16 | 17 | 18 | # deployment 19 | twine==4.0.1 20 | -------------------------------------------------------------------------------- /sample_app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexCLeduc/django-data-fetcher/aada5f5630c25603af005a2e20a57bc1b6fa216e/sample_app/__init__.py 
-------------------------------------------------------------------------------- /sample_app/data_factories.py: -------------------------------------------------------------------------------- 1 | from django.contrib.auth.models import User 2 | 3 | import factory 4 | 5 | from .models import Author, Book, Tag 6 | 7 | 8 | class UserFactory(factory.django.DjangoModelFactory): 9 | class Meta: 10 | model = User 11 | 12 | username = factory.Faker("user_name") 13 | email = factory.Faker("email") 14 | 15 | 16 | class TagFactory(factory.django.DjangoModelFactory): 17 | class Meta: 18 | model = Tag 19 | 20 | name = factory.Faker("color") 21 | 22 | 23 | class AuthorFactory(factory.django.DjangoModelFactory): 24 | class Meta: 25 | model = Author 26 | 27 | first_name = factory.Faker("first_name") 28 | last_name = factory.Faker("last_name") 29 | 30 | 31 | class BookFactory(factory.django.DjangoModelFactory): 32 | class Meta: 33 | model = Book 34 | 35 | author = factory.SubFactory(AuthorFactory) 36 | 37 | @factory.post_generation 38 | def tags(self, create, extracted, **kwargs): 39 | if not create: 40 | # Simple build, do nothing. 
41 | return 42 | 43 | if extracted is not None: 44 | # A list of tags were passed in, use them 45 | self.tags.add(*extracted) 46 | -------------------------------------------------------------------------------- /sample_app/migrations/0001_initial.py: -------------------------------------------------------------------------------- 1 | # Generated by Django 4.0 on 2024-01-20 14:58 2 | 3 | from django.db import migrations, models 4 | import django.db.models.deletion 5 | 6 | 7 | class Migration(migrations.Migration): 8 | 9 | initial = True 10 | 11 | dependencies = [ 12 | ] 13 | 14 | operations = [ 15 | migrations.CreateModel( 16 | name='Author', 17 | fields=[ 18 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 19 | ('first_name', models.CharField(max_length=100)), 20 | ('last_name', models.CharField(max_length=100)), 21 | ], 22 | ), 23 | migrations.CreateModel( 24 | name='Tag', 25 | fields=[ 26 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 27 | ('name', models.CharField(max_length=250)), 28 | ], 29 | ), 30 | migrations.CreateModel( 31 | name='Book', 32 | fields=[ 33 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 34 | ('title', models.CharField(max_length=250)), 35 | ('author', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='books', to='sample_app.author')), 36 | ('tags', models.ManyToManyField(to='sample_app.Tag')), 37 | ], 38 | ), 39 | ] 40 | -------------------------------------------------------------------------------- /sample_app/migrations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexCLeduc/django-data-fetcher/aada5f5630c25603af005a2e20a57bc1b6fa216e/sample_app/migrations/__init__.py -------------------------------------------------------------------------------- /sample_app/models.py: 
-------------------------------------------------------------------------------- 1 | from django.conf import settings 2 | from django.db import models 3 | from django.utils import timezone 4 | 5 | 6 | class Author(models.Model): 7 | first_name = models.CharField(max_length=100) 8 | last_name = models.CharField(max_length=100) 9 | 10 | 11 | class Tag(models.Model): 12 | name = models.CharField(max_length=250) 13 | 14 | def __str__(self): 15 | return self.name 16 | 17 | 18 | class Book(models.Model): 19 | author = models.ForeignKey( 20 | Author, related_name="books", on_delete=models.CASCADE 21 | ) 22 | title = models.CharField(max_length=250) 23 | tags = models.ManyToManyField(Tag) 24 | -------------------------------------------------------------------------------- /sample_app/settings.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | """ 4 | Django settings for sample_app project. 5 | 6 | Generated by 'django-admin startproject' using Django 2.2.11. 7 | 8 | For more information on this file, see 9 | https://docs.djangoproject.com/en/2.2/topics/settings/ 10 | 11 | For the full list of settings and their values, see 12 | https://docs.djangoproject.com/en/2.2/ref/settings/ 13 | """ 14 | 15 | import os 16 | 17 | from django.urls import reverse_lazy 18 | 19 | # Build paths inside the project like this: os.path.join(BASE_DIR, ...) 
20 | BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 21 | 22 | # Quick-start development settings - unsuitable for production 23 | # See https://docs.djangoproject.com/en/2.2/howto/deployment/checklist/ 24 | 25 | SECRET_KEY = "abc123" 26 | 27 | DEBUG = True 28 | 29 | ALLOWED_HOSTS = [] 30 | 31 | LOGIN_URL = reverse_lazy("login") 32 | 33 | # Application definition 34 | 35 | INSTALLED_APPS = [ 36 | "django.contrib.auth", 37 | "django.contrib.contenttypes", 38 | "django.contrib.sessions", 39 | "django_extensions", 40 | "sample_app", 41 | ] 42 | 43 | MIDDLEWARE = [ 44 | "django.middleware.security.SecurityMiddleware", 45 | "django.contrib.sessions.middleware.SessionMiddleware", 46 | "django.middleware.common.CommonMiddleware", 47 | "django.middleware.csrf.CsrfViewMiddleware", 48 | "django.contrib.auth.middleware.AuthenticationMiddleware", 49 | "django.contrib.messages.middleware.MessageMiddleware", 50 | "django.middleware.clickjacking.XFrameOptionsMiddleware", 51 | "data_fetcher.middleware.GlobalRequestMiddleware", 52 | ] 53 | 54 | ROOT_URLCONF = "sample_app.urls" 55 | 56 | TEMPLATES = [ 57 | { 58 | "BACKEND": "django.template.backends.django.DjangoTemplates", 59 | "DIRS": [], 60 | "APP_DIRS": True, 61 | "OPTIONS": { 62 | "context_processors": [ 63 | "django.template.context_processors.debug", 64 | "django.template.context_processors.request", 65 | "django.contrib.auth.context_processors.auth", 66 | "django.contrib.messages.context_processors.messages", 67 | ], 68 | }, 69 | }, 70 | ] 71 | 72 | WSGI_APPLICATION = "sample_app.wsgi.application" 73 | 74 | 75 | # Database 76 | # https://docs.djangoproject.com/en/2.2/ref/settings/#databases 77 | 78 | DATABASES = { 79 | "default": { 80 | "ENGINE": "django.db.backends.sqlite3", 81 | "NAME": f"{BASE_DIR}/db.sqlite3", 82 | } 83 | } 84 | 85 | 86 | # Internationalization 87 | # https://docs.djangoproject.com/en/2.2/topics/i18n/ 88 | 89 | LANGUAGE_CODE = "en" 90 | 91 | TIME_ZONE = "UTC" 92 | 93 | USE_I18N = 
True 94 | 95 | USE_L10N = True 96 | 97 | USE_TZ = True 98 | 99 | 100 | # Static files (CSS, JavaScript, Images) 101 | # https://docs.djangoproject.com/en/2.2/howto/static-files/ 102 | 103 | STATIC_URL = "/static/" 104 | 105 | IS_TEST = False 106 | 107 | if "test" in sys.argv or any("pytest" in arg for arg in sys.argv): 108 | IS_TEST = True 109 | TEST_RUNNER = "pytest_test_runner.PytestTestRunner" 110 | -------------------------------------------------------------------------------- /sample_app/urls.py: -------------------------------------------------------------------------------- 1 | """sample_app URL Configuration 2 | 3 | The `urlpatterns` list routes URLs to views. For more information please see: 4 | https://docs.djangoproject.com/en/2.2/topics/http/urls/ 5 | Examples: 6 | Function views 7 | 1. Add an import: from my_app import views 8 | 2. Add a URL to urlpatterns: path('', views.home, name='home') 9 | Class-based views 10 | 1. Add an import: from other_app.views import Home 11 | 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') 12 | Including another URLconf 13 | 1. Import the include() function: from django.urls import include, path 14 | 2. 
Add a URL to urlpatterns: path('blog/', include('blog.urls')) 15 | """ 16 | 17 | from django.contrib import admin 18 | from django.contrib.auth.views import LoginView 19 | from django.urls import path 20 | 21 | from .views import async_view_with_loaders, edit_book, view_with_loaders 22 | 23 | urlpatterns = [ 24 | path("login/", LoginView.as_view(), name="login"), 25 | path("book//edit/", edit_book, name="edit-book"), 26 | path("view1", view_with_loaders, name="view1"), 27 | path("async_view1", async_view_with_loaders, name="async_view1"), 28 | ] 29 | -------------------------------------------------------------------------------- /sample_app/views.py: -------------------------------------------------------------------------------- 1 | from django.http.response import HttpResponse 2 | 3 | from asgiref.sync import sync_to_async 4 | 5 | from data_fetcher import PrimaryKeyFetcherFactory 6 | 7 | from .models import Author, Book, Tag 8 | 9 | 10 | def edit_book(request, pk=None): 11 | book = Book.objects.get(pk=pk) 12 | if request.POST: 13 | new_name = request.POST["title"] 14 | book.title = new_name 15 | book.save() 16 | 17 | return HttpResponse() 18 | 19 | 20 | def spyable_func(*args, **kwargs): 21 | # for testing purposes 22 | return None 23 | 24 | 25 | AuthorByIdFetcher = PrimaryKeyFetcherFactory.get_model_by_id_fetcher(Author) 26 | 27 | 28 | class WatchedAuthorByIdFetcher(AuthorByIdFetcher): 29 | def batch_load_dict(self, keys): 30 | spyable_func(keys) 31 | return super().batch_load_dict(keys) 32 | 33 | 34 | def get_author(author_id): 35 | return WatchedAuthorByIdFetcher.get_instance().get(author_id) 36 | 37 | 38 | def view_with_loaders(request): 39 | author_ids = Author.objects.values_list("id", flat=True) 40 | all_authors = WatchedAuthorByIdFetcher.get_instance().get_many(author_ids) 41 | 42 | for author in all_authors: 43 | refetched_author = get_author(author.id) 44 | assert refetched_author is author 45 | 46 | return HttpResponse("ok") 47 | 48 | 49 | async def 
async_view_with_loaders(request): 50 | resp = await sync_to_async(view_with_loaders)(request) 51 | return resp 52 | -------------------------------------------------------------------------------- /sample_app/wsgi.py: -------------------------------------------------------------------------------- 1 | """ 2 | WSGI config for sample_app project. 3 | 4 | It exposes the WSGI callable as a module-level variable named ``application``. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/2.2/howto/deployment/wsgi/ 8 | """ 9 | 10 | import os 11 | 12 | from django.core.wsgi import get_wsgi_application 13 | 14 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "sample_app.settings") 15 | 16 | application = get_wsgi_application() 17 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="django-data-fetcher", 8 | version="2.2.1", 9 | author="AlexCLeduc", 10 | # author_email="author@example.com", 11 | # description="A small example package", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/AlexCLeduc/django-data-fetcher", 15 | packages=[ 16 | # find_packages() also includes extraneous stuff, like testing and sample_app 17 | package 18 | for package in setuptools.find_packages() 19 | if package.startswith("data_fetcher") 20 | ], 21 | install_requires=[], 22 | tests_require=["django"], 23 | classifiers=[ 24 | "Programming Language :: Python :: 3", 25 | "License :: OSI Approved :: MIT License", 26 | "Operating System :: OS Independent", 27 | ], 28 | ) 29 | -------------------------------------------------------------------------------- /tests/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexCLeduc/django-data-fetcher/aada5f5630c25603af005a2e20a57bc1b6fa216e/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from django.db import transaction 2 | 3 | import pytest 4 | 5 | 6 | @pytest.fixture(autouse=True) 7 | def enable_db_access_for_all_tests(db): 8 | """ 9 | without this, tests (including old-style) have to explicitly declare db as a dependency 10 | https://pytest-django.readthedocs.io/en/latest/faq.html#how-can-i-give-database-access-to-all-my-tests-without-the-django-db-marker 11 | """ 12 | pass 13 | 14 | 15 | @pytest.fixture(scope="session") 16 | def globally_scoped_fixture_helper(django_db_setup, django_db_blocker): 17 | with django_db_blocker.unblock(): 18 | # Wrap in try + atomic block to do non crashing rollback 19 | # This means we don't have to re-create a test DB each time 20 | try: 21 | with transaction.atomic(): 22 | yield 23 | raise Exception 24 | except Exception: 25 | pass 26 | -------------------------------------------------------------------------------- /tests/test_cache_within_request.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from unittest.mock import MagicMock 3 | 4 | from django.contrib.auth import get_user_model 5 | 6 | import pytest 7 | 8 | from data_fetcher import cache_within_request, get_datafetcher_request_cache 9 | from data_fetcher.extras import CacheDecoratorException 10 | from data_fetcher.util import GlobalRequest, get_request 11 | 12 | 13 | def test_cache_decorator(): 14 | spy = MagicMock() 15 | 16 | def func_to_cache(): 17 | spy() 18 | return 1 19 | 20 | cached_func = cache_within_request(func_to_cache) 21 | 22 | with GlobalRequest(): 23 | result = cached_func() 24 | result2 = 
cached_func() 25 | 26 | assert result == result2 == 1 27 | spy.assert_called_once() 28 | 29 | with GlobalRequest(): 30 | cached_func() 31 | cached_func() 32 | 33 | assert spy.call_count == 2 34 | 35 | 36 | def test_cache_with_staticmethod(): 37 | """ 38 | static decorator works in either order 39 | """ 40 | 41 | spy = MagicMock() 42 | 43 | class TestClass: 44 | 45 | @cache_within_request 46 | @staticmethod 47 | def _things_by_id(): 48 | spy() 49 | return { 50 | 1: "a", 51 | 2: "b", 52 | } 53 | 54 | @staticmethod 55 | @cache_within_request 56 | def get_thing(id): 57 | return TestClass._things_by_id()[id] 58 | 59 | with GlobalRequest(): 60 | assert TestClass.get_thing(1) == "a" 61 | assert TestClass.get_thing(2) == "b" 62 | assert TestClass.get_thing(1) == "a" 63 | assert TestClass.get_thing(2) == "b" 64 | dict_one = TestClass._things_by_id() 65 | assert spy.call_count == 1 66 | 67 | with GlobalRequest(): 68 | dict_two = TestClass._things_by_id() 69 | 70 | assert dict_one is not dict_two 71 | 72 | 73 | def test_cache_with_undecorated_method(): 74 | """obviously not recommended, but it should work""" 75 | 76 | spy = MagicMock() 77 | 78 | class TestClass: 79 | 80 | @cache_within_request 81 | def _things_by_id(): 82 | spy() 83 | return { 84 | 1: "a", 85 | 2: "b", 86 | } 87 | 88 | def get_thing(id): 89 | return TestClass._things_by_id()[id] 90 | 91 | with GlobalRequest(): 92 | dict_one = TestClass._things_by_id() 93 | 94 | assert TestClass.get_thing(1) == "a" 95 | assert TestClass.get_thing(2) == "b" 96 | assert TestClass.get_thing(1) == "a" 97 | assert TestClass.get_thing(2) == "b" 98 | assert spy.call_count == 1 99 | 100 | with GlobalRequest(): 101 | dict_two = TestClass._things_by_id() 102 | 103 | assert dict_one is not dict_two 104 | 105 | 106 | def test_classmethod_correct_order(): 107 | 108 | inner_spy = MagicMock() 109 | spy = MagicMock() 110 | 111 | class TestClass: 112 | 113 | _cls_value = "c" 114 | 115 | @classmethod 116 | @cache_within_request 117 | def 
_other_value(cls): 118 | inner_spy() 119 | return cls._cls_value 120 | 121 | @classmethod 122 | @cache_within_request 123 | def _things_by_id(cls): 124 | spy() 125 | return { 126 | 1: "a", 127 | 2: "b", 128 | 3: cls._other_value(), 129 | } 130 | 131 | @classmethod 132 | def get_thing(cls, id): 133 | return cls._things_by_id()[id] 134 | 135 | with GlobalRequest(): 136 | assert TestClass.get_thing(1) == "a" 137 | assert TestClass.get_thing(2) == "b" 138 | assert TestClass.get_thing(1) == "a" 139 | assert TestClass.get_thing(2) == "b" 140 | assert TestClass.get_thing(3) == "c" 141 | dict_one = TestClass._things_by_id() 142 | assert inner_spy.call_count == 1 143 | assert spy.call_count == 1 144 | 145 | with GlobalRequest(): 146 | assert TestClass.get_thing(1) == "a" 147 | assert TestClass.get_thing(2) == "b" 148 | assert TestClass.get_thing(3) == "c" 149 | dict_two = TestClass._things_by_id() 150 | assert inner_spy.call_count == 2 151 | assert spy.call_count == 2 152 | 153 | assert dict_one is not dict_two 154 | 155 | 156 | def test_classmethod_wrong_order(): 157 | 158 | with pytest.raises(CacheDecoratorException): 159 | 160 | class TestClass: 161 | 162 | @cache_within_request 163 | @classmethod 164 | def _other_value(cls): 165 | pass 166 | -------------------------------------------------------------------------------- /tests/test_data_fetchers.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from unittest.mock import MagicMock 3 | 4 | from django.contrib.auth import get_user_model 5 | 6 | from data_fetcher import ( 7 | DataFetcher, 8 | PrimaryKeyFetcherFactory, 9 | get_datafetcher_request_cache, 10 | ) 11 | from data_fetcher.util import GlobalRequest, get_request 12 | 13 | 14 | def test_global_request_outside_request(): 15 | assert get_request() is None 16 | 17 | 18 | def test_global_request_context_processor(): 19 | with GlobalRequest(): 20 | assert get_request() is not None 21 | get_request().x = 1 22 | 
assert get_request().x == 1 23 | 24 | 25 | def test_global_request_returns_same_request(): 26 | with GlobalRequest(): 27 | r1 = get_request() 28 | r2 = get_request() 29 | assert r1 is r2 30 | with GlobalRequest(): 31 | r3 = get_request() 32 | 33 | assert r1 is not r3 34 | 35 | 36 | def test_user_datafetcher(django_assert_num_queries): 37 | users = [ 38 | get_user_model().objects.create(username=f"test_user_{i}") 39 | for i in range(10) 40 | ] 41 | user_ids = [user.id for user in users] 42 | 43 | UserByPKFetcher = PrimaryKeyFetcherFactory.get_model_by_id_fetcher( 44 | get_user_model() 45 | ) 46 | 47 | with GlobalRequest(): 48 | loader = UserByPKFetcher.get_instance() 49 | with django_assert_num_queries(1): 50 | # one batched query loads every user; the later lookups are served from the cache 51 | assert loader.get_many(user_ids) == users 52 | assert loader.get(user_ids[0]) == users[0] 53 | assert loader.get_many_as_dict(user_ids) == { 54 | user.id: user for user in users 55 | } 56 | 57 | loader2 = UserByPKFetcher.get_instance() 58 | assert loader is loader2 59 | 60 | with GlobalRequest(): 61 | # now check that a loader from a new request is brand new, without any cache 62 | loader3 = UserByPKFetcher.get_instance() 63 | assert loader != loader3 64 | assert loader3._cache == {} 65 | 66 | 67 | def test_composed_datafetcher(django_assert_max_num_queries): 68 | users = [ 69 | get_user_model().objects.create(username=f"test_user_{i}") 70 | for i in range(10) 71 | ] 72 | user_ids = [user.id for user in users] 73 | 74 | UserByPKFetcher = PrimaryKeyFetcherFactory.get_model_by_id_fetcher( 75 | get_user_model() 76 | ) 77 | 78 | spy = MagicMock() 79 | 80 | # Trivial example of dataloader composition 81 | class TrivialOtherFetcher(DataFetcher): 82 | def batch_load_dict(self, keys): 83 | spy(keys) 84 | user_fetcher = UserByPKFetcher.get_instance() 85 | return user_fetcher.get_many_as_dict(keys) 86 | 87 | with GlobalRequest(): 88 | loader = TrivialOtherFetcher.get_instance() 89 | with django_assert_max_num_queries(1):
90 | # the wrapped user fetcher batches every key, so at most 1 query is expected 91 | assert loader.get_many(user_ids) == users 92 | assert loader.get(user_ids[0]) == users[0] 93 | assert loader.get_many_as_dict(user_ids) == { 94 | user.id: user for user in users 95 | } 96 | 97 | assert spy.call_count == 1 98 | 99 | fetcher_cache = get_datafetcher_request_cache() 100 | assert fetcher_cache[UserByPKFetcher] is not None 101 | 102 | 103 | def test_priming(): 104 | with GlobalRequest(): 105 | user_fetcher = PrimaryKeyFetcherFactory.get_model_by_id_fetcher( 106 | get_user_model() 107 | ).get_instance() 108 | 109 | user_fetcher.prime(1, "test value") 110 | assert user_fetcher.get(1) == "test value" 111 | 112 | 113 | def test_pk_fetcher_fetch_all(django_assert_max_num_queries): 114 | users = [ 115 | get_user_model().objects.create(username=f"test_user_{i}") 116 | for i in range(10) 117 | ] 118 | u1 = users[0] 119 | user_ids = [user.id for user in users] 120 | 121 | user_fetcher = PrimaryKeyFetcherFactory.get_model_by_id_fetcher( 122 | get_user_model() 123 | ).get_instance() 124 | 125 | with GlobalRequest(): 126 | with django_assert_max_num_queries(1): 127 | records = user_fetcher.get_all() 128 | assert set(records) == set(users) 129 | u = user_fetcher.get(user_ids[0]) 130 | assert u == u1 131 | 132 | 133 | def test_batch_load_dict_none_value(): 134 | """ 135 | batch_load_dict is more tolerant than its list counterpart: 136 | it doesn't need an explicit None for missing keys 137 | 138 | Also check that omitted keys still get cached 139 | """ 140 | 141 | spy = MagicMock() 142 | 143 | class TestFetcher(DataFetcher): 144 | def batch_load_dict(self, keys): 145 | spy() 146 | return {"a": 1, "b": None} 147 | 148 | with GlobalRequest(): 149 | fetcher = TestFetcher.get_instance() 150 | 151 | fetcher.prefetch_keys(["a", "b", "c"]) 152 | assert fetcher.get("a") == 1 153 | assert fetcher.get("b") is None 154 | assert fetcher.get("c") is None 155 | 156 | assert spy.call_count == 1 157 | 158 | 159 | def
test_queued_fetch(): 160 | spy = MagicMock() 161 | 162 | class TestFetcher(DataFetcher): 163 | def batch_load_dict(self, keys): 164 | spy(keys) 165 | return {key: key * 2 for key in keys} 166 | 167 | with GlobalRequest(): 168 | fetcher = TestFetcher.get_instance() 169 | fetcher.prime(1, 2) 170 | 171 | fetcher.enqueue_keys([1, 2, 3, 4]) 172 | 173 | assert fetcher.get(1) == 2 174 | assert fetcher.get(2) == 4 175 | assert fetcher.get(3) == 6 176 | assert fetcher.get_many([2, 4, 5]) == [4, 8, 10] 177 | 178 | # primed/cached keys should not be fetched again, even if enqueued 179 | assert spy.call_args_list == [ 180 | (([2, 3, 4],),), 181 | (([5],),), 182 | ] 183 | 184 | # now check a regular .get() will also fetch queued keys 185 | fetcher.enqueue_keys([10, 11]) 186 | assert fetcher.get(12) == 24 187 | assert fetcher.get(10) == 20 188 | assert spy.call_args_list == [ 189 | (([2, 3, 4],),), 190 | (([5],),), 191 | (([10, 11, 12],),), 192 | ] 193 | 194 | # and clears the queue, so flushing again fetches nothing 195 | fetcher.fetch_queued() 196 | assert spy.call_count == 3 197 | 198 | 199 | def test_fetch_lazy(): 200 | spy = MagicMock() 201 | 202 | class TestFetcher(DataFetcher): 203 | def batch_load_dict(self, keys): 204 | spy(keys) 205 | return {key: key * 2 for key in keys} 206 | 207 | with GlobalRequest(): 208 | fetcher = TestFetcher.get_instance() 209 | 210 | # check lazy calls are flushed all at once 211 | l1 = fetcher.get_lazy(1) 212 | l2 = fetcher.get_lazy(2) 213 | l3 = fetcher.get_lazy(3) 214 | assert spy.call_count == 0 215 | 216 | assert l3.get() == 6 217 | assert spy.call_count == 1 218 | spy.assert_called_once_with([1, 2, 3]) 219 | 220 | # and that queue is cleared 221 | fetcher.fetch_queued() 222 | assert spy.call_count == 1 223 | spy.assert_called_once_with([1, 2, 3]) 224 | 225 | # and similarly for get_many_lazy 226 | l4_5 = fetcher.get_many_lazy([4, 5]) 227 | spy.assert_called_once_with([1, 2, 3]) 228 | l4_5.get() 229 | 230 | assert spy.call_count == 2 231 | assert spy.call_args_list == [ 232 | (([1,
2, 3],),), 233 | (([4, 5],),), 234 | ] 235 | -------------------------------------------------------------------------------- /tests/test_keyed_fetcher.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | from data_fetcher import GlobalRequest, ValueBoundDataFetcher 4 | 5 | 6 | def test_keyed_datafetcher_factory(): 7 | class DatetimeBoundDataFetcher(ValueBoundDataFetcher): 8 | bound_value = None 9 | 10 | def batch_load_dict(self, keys): 11 | return {key: (1, self.bound_value) for key in keys} 12 | 13 | dt1 = datetime.datetime(2021, 1, 1) 14 | dt2 = datetime.datetime(2021, 1, 2) 15 | 16 | cls1 = DatetimeBoundDataFetcher.get_value_bound_class(dt1) 17 | cls2 = DatetimeBoundDataFetcher.get_value_bound_class(dt2) 18 | cls3 = DatetimeBoundDataFetcher.get_value_bound_class(dt1) 19 | assert cls1 is not cls2 20 | assert cls1 is cls3 21 | 22 | with GlobalRequest(): 23 | loader_for_cls1 = cls1.get_instance() 24 | loader_for_cls2 = cls2.get_instance() 25 | loader_for_cls3 = cls3.get_instance() 26 | assert loader_for_cls1 is not loader_for_cls2 27 | assert loader_for_cls1 is loader_for_cls3 28 | -------------------------------------------------------------------------------- /tests/test_middleware.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock, patch 2 | 3 | from django.test.client import Client 4 | from django.urls import reverse 5 | 6 | from sample_app import data_factories 7 | 8 | 9 | def test_view_with_loader(): 10 | u = data_factories.UserFactory() 11 | client1 = Client() 12 | client1.force_login(u) 13 | 14 | u2 = data_factories.UserFactory() 15 | client2 = Client() 16 | client2.force_login(u2) 17 | 18 | data_factories.AuthorFactory.create_batch(20) 19 | 20 | url = reverse("view1") 21 | 22 | spy1 = MagicMock() 23 | 24 | with patch("sample_app.views.spyable_func", spy1): 25 | response = client1.get(url) 26 | assert response.status_code == 
200 27 | 28 | spy2 = MagicMock() 29 | with patch("sample_app.views.spyable_func", spy2): 30 | response = client2.get(url) 31 | assert response.status_code == 200 32 | 33 | assert spy1.call_count == 1 34 | assert spy2.call_count == 1 35 | 36 | 37 | def test_async_view_with_loader(): 38 | u = data_factories.UserFactory() 39 | client1 = Client() 40 | client1.force_login(u) 41 | 42 | u2 = data_factories.UserFactory() 43 | client2 = Client() 44 | client2.force_login(u2) 45 | 46 | data_factories.AuthorFactory.create_batch(20) 47 | 48 | url = reverse("async_view1") 49 | 50 | spy1 = MagicMock() 51 | 52 | with patch("sample_app.views.spyable_func", spy1): 53 | response = client1.get(url) 54 | assert response.status_code == 200 55 | 56 | spy2 = MagicMock() 57 | with patch("sample_app.views.spyable_func", spy2): 58 | response = client2.get(url) 59 | assert response.status_code == 200 60 | 61 | assert spy1.call_count == 1 62 | assert spy2.call_count == 1 63 | -------------------------------------------------------------------------------- /tests/test_model_fetchers.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexCLeduc/django-data-fetcher/aada5f5630c25603af005a2e20a57bc1b6fa216e/tests/test_model_fetchers.py -------------------------------------------------------------------------------- /tests/test_singleton_fetcher.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexCLeduc/django-data-fetcher/aada5f5630c25603af005a2e20a57bc1b6fa216e/tests/test_singleton_fetcher.py --------------------------------------------------------------------------------
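The queue-and-batch behaviour exercised by `test_queued_fetch` and `test_batch_load_dict_none_value` in tests/test_data_fetchers.py can be hard to picture from assertions alone. The sketch below is a toy stand-in written only for illustration: `ToyQueuedFetcher` and `double_all` are hypothetical names, and the class is not the library's `DataFetcher` implementation. It assumes only what those tests show: primed or cached keys are never re-fetched, enqueued keys ride along with the next fetch, and keys omitted from the batch result are cached as `None`.

```python
class ToyQueuedFetcher:
    """Hypothetical toy model of a batching fetcher (not the library's code)."""

    def __init__(self, batch_load_dict):
        # batch_load_dict: callable taking a list of keys, returning a dict
        self._batch_load_dict = batch_load_dict
        self._cache = {}
        self._queue = []

    def prime(self, key, value):
        # Seed the cache so this key never reaches the batch function.
        self._cache[key] = value

    def enqueue_keys(self, keys):
        # Remember keys for the next batch, skipping cached or queued ones.
        for key in keys:
            if key not in self._cache and key not in self._queue:
                self._queue.append(key)

    def fetch_queued(self):
        # Flush the queue in a single batch call; a no-op when it is empty.
        if not self._queue:
            return
        keys, self._queue = self._queue, []
        loaded = self._batch_load_dict(keys)
        for key in keys:
            # Keys missing from the result are cached as None.
            self._cache[key] = loaded.get(key)

    def get_many(self, keys):
        self.enqueue_keys(keys)
        self.fetch_queued()
        return [self._cache[key] for key in keys]

    def get(self, key):
        return self.get_many([key])[0]


calls = []


def double_all(keys):
    # Record each batch so the number of "queries" can be inspected.
    calls.append(list(keys))
    return {key: key * 2 for key in keys}


fetcher = ToyQueuedFetcher(double_all)
fetcher.prime(1, 2)
fetcher.enqueue_keys([1, 2, 3, 4])
values = [fetcher.get(1), fetcher.get(2), fetcher.get(3)]
more = fetcher.get_many([2, 4, 5])
```

Here `calls` ends up holding two batches, one for the queued keys and one for the single uncached key, which mirrors the `spy.call_args_list` assertion in the real test. Keeping the queue separate from the cache is what lets callers announce keys early and pay for them in one batch later.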