├── example ├── products │ ├── __init__.py │ ├── migrations │ │ ├── __init__.py │ │ └── 0001_initial.py │ ├── admin.py │ ├── apps.py │ ├── models.py │ ├── fixtures │ │ └── products.json │ ├── views.py │ └── serializers.py ├── requirements.txt ├── scripts │ ├── dev.sh │ └── devsetup.sh ├── wsgi.py ├── .gitignore ├── README.md ├── manage.py ├── urls.py └── settings.py ├── tests ├── migrations │ ├── __init__.py │ ├── 0002_alter_product_category_alter_product_partners_and_more.py │ └── 0001_initial.py ├── __init__.py ├── mixins.py ├── urls.py ├── settings.py ├── models.py ├── viewsets.py ├── serializers.py └── test_products_api.py ├── drf_sideloading ├── __init__.py ├── schema.py ├── serializers.py └── mixins.py ├── requirements.txt ├── .flake8 ├── MANIFEST.in ├── requirements_dev.txt ├── .coveragerc ├── pyproject.toml ├── manage.py ├── requirements_test.txt ├── setup.cfg ├── AUTHORS.md ├── .github └── ISSUE_TEMPLATE.md ├── .editorconfig ├── runtests.py ├── .travis.yml ├── .gitignore ├── LICENSE ├── Makefile ├── tox.ini ├── HISTORY.md ├── setup.py └── README.md /example/products/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/migrations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /example/products/migrations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /drf_sideloading/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "2.2.3" 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 
| # Additional requirements go here 2 | -------------------------------------------------------------------------------- /example/products/admin.py: -------------------------------------------------------------------------------- 1 | from django.contrib import admin 2 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | ignore = W503 4 | -------------------------------------------------------------------------------- /example/requirements.txt: -------------------------------------------------------------------------------- 1 | django>=2.1,<5.3 2 | djangorestframework>=3.9,<4.0 3 | django-debug-toolbar 4 | -------------------------------------------------------------------------------- /example/products/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | 4 | class ProductsConfig(AppConfig): 5 | name = "products" 6 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | 4 | class TestsConfig(AppConfig): 5 | name = "tests" 6 | verbose_name = "Tests" 7 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include AUTHORS.md 2 | include HISTORY.md 3 | include LICENSE 4 | include README.md 5 | recursive-include drf_sideloading *.html *.png *.gif *js *.css *jpg *jpeg *svg *py 6 | -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | bump2version==1.0.1 2 | codecov 3 | flake8 4 | sphinx 5 | tox 6 | 
twine 7 | wheel 8 | recommonmark 9 | 10 | django>=2.2,<5.0 11 | djangorestframework>=3.9,<4.0 12 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = true 3 | 4 | [report] 5 | omit = 6 | *site-packages* 7 | *tests* 8 | *.tox* 9 | show_missing = True 10 | exclude_lines = 11 | raise NotImplementedError 12 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | target-version = ["py37", "py38", "py39", "py310", "py311", "py312"] 4 | include = '\.pyi?$' 5 | exclude = ''' 6 | /( 7 | \.git 8 | | build 9 | | dist 10 | 11 | )/ 12 | ''' 13 | -------------------------------------------------------------------------------- /manage.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | if __name__ == "__main__": 5 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "tests.settings") 6 | from django.core.management import execute_from_command_line 7 | 8 | execute_from_command_line(sys.argv) 9 | -------------------------------------------------------------------------------- /example/scripts/dev.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Activate virtualenv 4 | source ./.env/bin/activate 5 | 6 | # Add drf-sideloading library to PYTHONPATH 7 | export PYTHONPATH=$PYTHONPATH:$(cd .. 
&& pwd) 8 | 9 | # Start development server 10 | python manage.py runserver 11 | -------------------------------------------------------------------------------- /requirements_test.txt: -------------------------------------------------------------------------------- 1 | setuptools 2 | 3 | coverage==4.3.4 4 | mock>=1.0.1 5 | flake8>=2.1.0 6 | tox>=1.7.0 7 | tox-travis 8 | codecov>=2.0.0 9 | sphinx 10 | recommonmark 11 | 12 | # supported django and DRF 13 | django>=2.2,<5.2 14 | djangorestframework>=3.9,<4.0 15 | 16 | # to test conflicts with filtering 17 | django-filter 18 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 2.2.3 3 | commit = True 4 | tag = True 5 | 6 | [bumpversion:file:drf_sideloading/__init__.py] 7 | 8 | [wheel] 9 | universal = 1 10 | 11 | [flake8] 12 | ignore = D203 13 | exclude = 14 | drf_sideloading/migrations, 15 | .git, 16 | .tox, 17 | docs/conf.py, 18 | build, 19 | dist 20 | max-line-length = 120 21 | -------------------------------------------------------------------------------- /tests/mixins.py: -------------------------------------------------------------------------------- 1 | class OtherMixin(object): 2 | """Mixin for testing purposes 3 | Check if `self.action` attribute is available 4 | """ 5 | 6 | def get_serializer_class(self): 7 | if not hasattr(self, "action"): 8 | raise AttributeError("Action is not available") 9 | return super(OtherMixin, self).get_serializer_class() 10 | -------------------------------------------------------------------------------- /AUTHORS.md: -------------------------------------------------------------------------------- 1 | # Credits 2 | 3 | ## Contributors 4 | 5 | - [Demur Nodia](https://github.com/demonno) 6 | - [Tõnis Väin](https://github.com/tonisvain) 7 | - [Madis Väin](https://github.com/madisvain) 8 | - [Lenno
Nagel](https://github.com/lnagel) 9 | 10 | Contributions are welcome, and they are greatly appreciated! Every little bit helps, and credit will always be given. 11 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | * drf-sideloading version: 2 | * Django version: 3 | * Python version: 4 | * Operating System: 5 | 6 | ### Description 7 | 8 | Describe what you were trying to get done. 9 | Tell us what happened, what went wrong, and what you expected to happen. 10 | 11 | ### What I Did 12 | 13 | ``` 14 | Paste the command(s) you ran and the output. 15 | If there was a crash, please include the traceback here. 16 | ``` 17 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | charset = utf-8 7 | end_of_line = lf 8 | insert_final_newline = true 9 | trim_trailing_whitespace = true 10 | 11 | [*.{py,rst,ini}] 12 | indent_style = space 13 | indent_size = 4 14 | 15 | [*.{html,css,scss,json,yml}] 16 | indent_style = space 17 | indent_size = 2 18 | 19 | [*.md] 20 | trim_trailing_whitespace = false 21 | 22 | [Makefile] 23 | indent_style = tab 24 | -------------------------------------------------------------------------------- /example/wsgi.py: -------------------------------------------------------------------------------- 1 | """ 2 | WSGI config for example project. 3 | 4 | It exposes the WSGI callable as a module-level variable named ``application``. 
5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/1.11/howto/deployment/wsgi/ 8 | """ 9 | 10 | import os 11 | 12 | from django.core.wsgi import get_wsgi_application 13 | 14 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "example.settings") 15 | 16 | application = get_wsgi_application() 17 | -------------------------------------------------------------------------------- /example/scripts/devsetup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | # Create venv if not there (use venv) 6 | python3 -m venv --prompt "${PROJECT_NAME}" .env 7 | 8 | # Activate virtualenv 9 | source ./.env/bin/activate 10 | 11 | # Add drf-sideloading library to PYTHONPATH 12 | export PYTHONPATH=$PYTHONPATH:$(cd .. && pwd) 13 | 14 | # Install requirements 15 | pip install -r requirements.txt 16 | 17 | # Run migrate 18 | python manage.py migrate 19 | 20 | # Load example data from fixtures 21 | python manage.py loaddata products/fixtures/products.json 22 | -------------------------------------------------------------------------------- /runtests.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import django 5 | from django.conf import settings 6 | from django.test.utils import get_runner 7 | 8 | 9 | def run_tests(*test_args): 10 | if not test_args: 11 | test_args = ["tests"] 12 | 13 | os.environ["DJANGO_SETTINGS_MODULE"] = "tests.settings" 14 | django.setup() 15 | TestRunner = get_runner(settings) 16 | test_runner = TestRunner() 17 | failures = test_runner.run_tests(test_args) 18 | sys.exit(bool(failures)) 19 | 20 | 21 | if __name__ == "__main__": 22 | run_tests(*sys.argv[1:]) 23 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: xenial 2 | language: python 3 | 4 | python: 5 | - 
"3.6" 6 | - "3.7" 7 | - "3.8" 8 | - "3.9" 9 | - "3.10" 10 | - "3.11" 11 | - "3.12" 12 | 13 | matrix: 14 | fast_finish: true 15 | include: 16 | - python: 3.12 17 | env: TOXENV=lint 18 | 19 | # command to install dependencies, e.g. pip install -r requirements.txt --use-mirrors 20 | install: 21 | - pip install -r requirements_test.txt 22 | 23 | # command to run tests using coverage, e.g. python setup.py test 24 | script: 25 | - tox 26 | 27 | after_success: 28 | - codecov -e TOX_ENV 29 | -------------------------------------------------------------------------------- /example/products/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | 3 | 4 | class Category(models.Model): 5 | name = models.CharField(max_length=255) 6 | 7 | 8 | class Supplier(models.Model): 9 | name = models.CharField(max_length=255) 10 | 11 | 12 | class Partner(models.Model): 13 | name = models.CharField(max_length=255) 14 | 15 | 16 | class Product(models.Model): 17 | name = models.CharField(max_length=255) 18 | category = models.ForeignKey(Category, on_delete=models.CASCADE, related_name="products") 19 | supplier = models.ForeignKey(Supplier, on_delete=models.CASCADE, related_name="products") 20 | partners = models.ManyToManyField(Partner, related_name="products") 21 | -------------------------------------------------------------------------------- /example/.gitignore: -------------------------------------------------------------------------------- 1 | *.py[cod] 2 | __pycache__ 3 | .env 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Packages 9 | *.egg 10 | *.egg-info 11 | dist 12 | build 13 | eggs 14 | parts 15 | bin 16 | var 17 | sdist 18 | develop-eggs 19 | .installed.cfg 20 | lib 21 | lib64 22 | 23 | # Installer logs 24 | pip-log.txt 25 | 26 | # Unit test / coverage reports 27 | .coverage 28 | .tox 29 | nosetests.xml 30 | htmlcov 31 | 32 | # Translations 33 | *.mo 34 | 35 | # Mr Developer 36 | .mr.developer.cfg 37 | .project 
38 | .pydevproject 39 | 40 | # Pycharm/Intellij 41 | .idea 42 | .DS_Store 43 | 44 | # Complexity 45 | output/*.html 46 | output/*/index.html 47 | 48 | # Sphinx 49 | docs/_build 50 | 51 | *.sqlite3 52 | 53 | 54 | .python-version 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | db.sqlite3 3 | 4 | *.py[cod] 5 | __pycache__ 6 | env 7 | venv 8 | .python-version 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Packages 14 | *.egg 15 | *.egg-info 16 | dist 17 | build 18 | eggs 19 | parts 20 | bin 21 | var 22 | sdist 23 | develop-eggs 24 | .installed.cfg 25 | lib 26 | lib64 27 | 28 | # Installer logs 29 | pip-log.txt 30 | 31 | # Unit test / coverage reports 32 | .coverage 33 | .tox 34 | nosetests.xml 35 | htmlcov 36 | 37 | # Translations 38 | *.mo 39 | 40 | # Mr Developer 41 | .mr.developer.cfg 42 | .project 43 | .pydevproject 44 | 45 | # Pycharm/Intellij 46 | .idea 47 | .DS_Store 48 | 49 | # Complexity 50 | output/*.html 51 | output/*/index.html 52 | 53 | # Sphinx 54 | docs/_build 55 | -------------------------------------------------------------------------------- /example/README.md: -------------------------------------------------------------------------------- 1 | # Example Project 2 | 3 | This is a very simple Django application using Django REST framework 4 | to demonstrate example use cases and test the `drf_sideloading` library. 5 | 6 | This version requires Python 3. 7 | 8 | ## Export PYTHONPATH 9 | 10 | To use the latest version of the cloned library, add its parent directory to PYTHONPATH: 11 | 12 | export PYTHONPATH=$PYTHONPATH:$(cd ..
&& pwd) 13 | 14 | Or install the desired release using pip: 15 | 16 | pip install drf-sideloading==0.1.7 17 | 18 | ## Set up using script 19 | 20 | bash scripts/devsetup.sh 21 | 22 | ## Run using script 23 | 24 | bash scripts/dev.sh 25 | 26 | Open in a browser: 27 | 28 | http://127.0.0.1:8000/ 29 | 30 | Test the sideloading products endpoint: 31 | 32 | http://127.0.0.1:8000/products/?sideload=categories,suppliers,partners 33 | -------------------------------------------------------------------------------- /tests/urls.py: -------------------------------------------------------------------------------- 1 | from django.urls import path, include 2 | from rest_framework import routers 3 | 4 | from tests import viewsets 5 | 6 | router = routers.DefaultRouter() 7 | router.register(r"product", viewsets.ProductViewSet) 8 | router.register(r"productlistonly", viewsets.ListOnlyProductViewSet, basename="productlistonly") 9 | router.register( 10 | r"productwrongmixinorder", viewsets.ProductViewSetSideloadingBeforeViews, basename="productwrongmixinorder" 11 | ) 12 | router.register(r"productretreiveonly", viewsets.RetreiveOnlyProductViewSet, basename="productretreiveonly") 13 | router.register(r"category", viewsets.CategoryViewSet) 14 | router.register(r"supplier", viewsets.SupplierViewSet) 15 | router.register(r"partner", viewsets.PartnerViewSet) 16 | 17 | urlpatterns = [path("", include(router.urls))] 18 | -------------------------------------------------------------------------------- /example/products/fixtures/products.json: -------------------------------------------------------------------------------- 1 | [ 2 | { "model": "products.category", "pk": 1, "fields": { "name": "Category1" } }, 3 | { "model": "products.category", "pk": 2, "fields": { "name": "Category2" } }, 4 | { "model": "products.supplier", "pk": 1, "fields": { "name": "Supplier1" } }, 5 | { "model": "products.supplier", "pk": 2, "fields": { "name": "Supplier2" } }, 6 | { "model": "products.partner", "pk": 1, "fields": { "name":
"Partner1" } }, 7 | { "model": "products.partner", "pk": 2, "fields": { "name": "Partner2" } }, 8 | { "model": "products.partner", "pk": 3, "fields": { "name": "Partner3" } }, 9 | { 10 | "model": "products.product", 11 | "pk": 1, 12 | "fields": { 13 | "name": "Product 1", 14 | "category": 1, 15 | "supplier": 1, 16 | "partners": [1, 2, 3] 17 | } 18 | } 19 | ] 20 | -------------------------------------------------------------------------------- /example/manage.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | if __name__ == "__main__": 5 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "settings") 6 | try: 7 | from django.core.management import execute_from_command_line 8 | except ImportError: 9 | # The above import may fail for some other reason. Ensure that the 10 | # issue is really that Django is missing to avoid masking other 11 | # exceptions on Python 2. 12 | try: 13 | import django 14 | except ImportError: 15 | raise ImportError( 16 | "Couldn't import Django. Are you sure it's installed and " 17 | "available on your PYTHONPATH environment variable? Did you " 18 | "forget to activate a virtual environment?"
19 | ) 20 | raise 21 | execute_from_command_line(sys.argv) 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | (The MIT License) 2 | 3 | Copyright (c) Namespace OÜ 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining 6 | a copy of this software and associated documentation files (the 7 | "Software"), to deal in the Software without restriction, including 8 | without limitation the rights to use, copy, modify, merge, publish, 9 | distribute, sublicense, and/or sell copies of the Software, and to 10 | permit persons to whom the Software is furnished to do so, subject to 11 | the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be 14 | included in all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 21 | TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 22 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 23 | -------------------------------------------------------------------------------- /tests/settings.py: -------------------------------------------------------------------------------- 1 | DEBUG = True 2 | USE_TZ = True 3 | XXXXXXX = True 4 | 5 | # SECURITY WARNING: keep the secret key used in production secret! 
6 | SECRET_KEY = "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz" 7 | 8 | DATABASES = {"default": {"ENGINE": "django.db.backends.sqlite3", "NAME": ":memory:"}} 9 | 10 | ROOT_URLCONF = "tests.urls" 11 | 12 | INSTALLED_APPS = [ 13 | "django.contrib.auth", 14 | "django.contrib.contenttypes", 15 | "django.contrib.staticfiles", 16 | "django.contrib.sites", 17 | "rest_framework", 18 | "drf_sideloading", 19 | "tests.TestsConfig", 20 | ] 21 | 22 | TEMPLATES = [ 23 | { 24 | "BACKEND": "django.template.backends.django.DjangoTemplates", 25 | "DIRS": [], 26 | "APP_DIRS": True, 27 | "OPTIONS": { 28 | "context_processors": [ 29 | "django.template.context_processors.debug", 30 | "django.template.context_processors.request", 31 | "django.contrib.auth.context_processors.auth", 32 | "django.contrib.messages.context_processors.messages", 33 | ] 34 | }, 35 | } 36 | ] 37 | 38 | STATIC_URL = "/static/" 39 | 40 | SITE_ID = 1 41 | 42 | MIDDLEWARE = () 43 | -------------------------------------------------------------------------------- /tests/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | 3 | 4 | class Category(models.Model): 5 | name = models.CharField(max_length=255) 6 | 7 | 8 | class Supplier(models.Model): 9 | name = models.CharField(max_length=255) 10 | 11 | 12 | class SupplierMetadata(models.Model): 13 | supplier = models.OneToOneField(Supplier, related_name="metadata", on_delete=models.CASCADE) 14 | properties = models.CharField(max_length=255) 15 | 16 | 17 | class Partner(models.Model): 18 | name = models.CharField(max_length=255) 19 | 20 | 21 | class Product(models.Model): 22 | name = models.CharField(max_length=255) 23 | category = models.ForeignKey(Category, related_name="products", on_delete=models.CASCADE) 24 | supplier = models.ForeignKey(Supplier, related_name="products", on_delete=models.CASCADE) 25 | backup_supplier = models.ForeignKey( 26 | Supplier, related_name="backup_products", 
on_delete=models.CASCADE, null=True, blank=True 27 | ) 28 | partners = models.ManyToManyField(Partner, related_name="products", blank=True) 29 | 30 | 31 | class ProductMetadata(models.Model): 32 | product = models.OneToOneField(Product, related_name="metadata", on_delete=models.CASCADE) 33 | properties = models.CharField(max_length=255) 34 | -------------------------------------------------------------------------------- /example/products/views.py: -------------------------------------------------------------------------------- 1 | from rest_framework import viewsets 2 | 3 | from drf_sideloading.mixins import SideloadableRelationsMixin 4 | from .models import Product, Category, Supplier, Partner 5 | from .serializers import ( 6 | ProductSerializer, 7 | CategorySerializer, 8 | SupplierSerializer, 9 | PartnerSerializer, 10 | ProductSideloadableSerializer, 11 | CategorySideloadableSerializer, 12 | ) 13 | 14 | 15 | class ProductViewSet(SideloadableRelationsMixin, viewsets.ModelViewSet): 16 | """ 17 | A simple ViewSet for viewing and editing products. 18 | """ 19 | 20 | queryset = Product.objects.all() 21 | serializer_class = ProductSerializer 22 | sideloading_serializer_class = ProductSideloadableSerializer 23 | 24 | 25 | class CategoryViewSet(SideloadableRelationsMixin, viewsets.ModelViewSet): 26 | """ 27 | A more complex ViewSet with reverse relations. 
28 | """ 29 | 30 | queryset = Category.objects.all() 31 | serializer_class = CategorySerializer 32 | sideloading_serializer_class = CategorySideloadableSerializer 33 | 34 | 35 | class SupplierViewSet(viewsets.ModelViewSet): 36 | queryset = Supplier.objects.all() 37 | serializer_class = SupplierSerializer 38 | 39 | 40 | class PartnerViewSet(viewsets.ModelViewSet): 41 | queryset = Partner.objects.all() 42 | serializer_class = PartnerSerializer 43 | -------------------------------------------------------------------------------- /example/urls.py: -------------------------------------------------------------------------------- 1 | """example URL Configuration 2 | 3 | The `urlpatterns` list routes URLs to views. For more information please see: 4 | https://docs.djangoproject.com/en/1.11/topics/http/urls/ 5 | Examples: 6 | Function views 7 | 1. Add an import: from my_app import views 8 | 2. Add a URL to urlpatterns: url(r'^$', views.home, name='home') 9 | Class-based views 10 | 1. Add an import: from other_app.views import Home 11 | 2. Add a URL to urlpatterns: url(r'^$', Home.as_view(), name='home') 12 | Including another URLconf 13 | 1. Import the include() function: from django.conf.urls import url, include 14 | 2. 
Add a URL to urlpatterns: url(r'^blog/', include('blog.urls')) 15 | """ 16 | 17 | from django.conf import settings 18 | from django.urls import path, include 19 | from django.contrib import admin 20 | from rest_framework import routers 21 | 22 | from products.views import ( 23 | ProductViewSet, 24 | CategoryViewSet, 25 | SupplierViewSet, 26 | PartnerViewSet, 27 | ) 28 | 29 | router = routers.DefaultRouter() 30 | router.register(r"products", ProductViewSet) 31 | router.register(r"categorys", CategoryViewSet) 32 | router.register(r"suppliers", SupplierViewSet) 33 | router.register(r"partners", PartnerViewSet) 34 | 35 | 36 | urlpatterns = [path("admin/", admin.site.urls), path("", include(router.urls))] 37 | 38 | 39 | if settings.DEBUG: 40 | import debug_toolbar 41 | 42 | urlpatterns = [path("__debug__/", include(debug_toolbar.urls))] + urlpatterns 43 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: clean-pyc clean-build docs help 2 | .DEFAULT_GOAL := help 3 | define BROWSER_PYSCRIPT 4 | import os, webbrowser, sys 5 | try: 6 | from urllib import pathname2url 7 | except: 8 | from urllib.request import pathname2url 9 | 10 | webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1]))) 11 | endef 12 | export BROWSER_PYSCRIPT 13 | BROWSER := python -c "$$BROWSER_PYSCRIPT" 14 | 15 | help: 16 | @perl -nle'print $& if m{^[a-zA-Z_-]+:.*?## .*$$}' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-25s\033[0m %s\n", $$1, $$2}' 17 | 18 | clean: clean-build clean-pyc 19 | 20 | clean-build: ## remove build artifacts 21 | rm -fr build/ 22 | rm -fr dist/ 23 | rm -fr *.egg-info 24 | 25 | clean-pyc: ## remove Python file artifacts 26 | find . -name '*.pyc' -exec rm -f {} + 27 | find . -name '*.pyo' -exec rm -f {} + 28 | find . 
-name '*~' -exec rm -f {} + 29 | 30 | lint: ## check style with flake8 31 | flake8 drf_sideloading tests 32 | 33 | test: ## run tests quickly with the default Python 34 | python runtests.py tests 35 | 36 | test-all: ## run tests on every Python version with tox 37 | tox 38 | 39 | test-watch: ## run test in watch mode dependency entr and ag http://entrproject.org/ 40 | ag -l | entr make test 41 | 42 | coverage: ## check code coverage quickly with the default Python 43 | coverage run --source drf_sideloading runtests.py tests 44 | coverage report -m 45 | coverage html 46 | open htmlcov/index.html 47 | 48 | release: clean ## package and upload a release 49 | python setup.py sdist bdist_wheel 50 | twine upload dist/* 51 | -------------------------------------------------------------------------------- /example/products/serializers.py: -------------------------------------------------------------------------------- 1 | from rest_framework import serializers 2 | 3 | from drf_sideloading.serializers import SideLoadableSerializer 4 | from .models import Product, Category, Supplier, Partner 5 | 6 | 7 | class SupplierSerializer(serializers.ModelSerializer): 8 | class Meta: 9 | model = Supplier 10 | fields = "__all__" 11 | 12 | 13 | class PartnerSerializer(serializers.ModelSerializer): 14 | class Meta: 15 | model = Partner 16 | fields = "__all__" 17 | 18 | 19 | class CategorySerializer(serializers.ModelSerializer): 20 | class Meta: 21 | model = Category 22 | fields = "__all__" 23 | 24 | 25 | class ProductSerializer(serializers.ModelSerializer): 26 | class Meta: 27 | model = Product 28 | fields = "__all__" 29 | 30 | 31 | class CategorySideloadableSerializer(SideLoadableSerializer): 32 | categories = CategorySerializer(many=True) 33 | products = ProductSerializer(many=True) 34 | suppliers = SupplierSerializer(source="products__supplier", many=True) 35 | partners = PartnerSerializer(source="products__partners", many=True) 36 | 37 | class Meta: 38 | primary = "categories" 39 | 
prefetches = { 40 | "products": "products", 41 | "suppliers": "products__supplier", 42 | "partners": "products__partners", 43 | } 44 | 45 | 46 | class ProductSideloadableSerializer(SideLoadableSerializer):  # every declared field must be a many=True ModelSerializer (enforced by SideLoadableSerializer.check_setup) 47 | products = ProductSerializer(many=True) 48 | categories = CategorySerializer(source="category", many=True)  # response key "categories", filled from the "category" source 49 | suppliers = SupplierSerializer(source="supplier", many=True) 50 | partners = PartnerSerializer(many=True) 51 | 52 | class Meta: 53 | primary = "products"  # main payload; always first in fields_to_load 54 | prefetches = {  # sideload key -> prefetch lookup(s) used for that relation 55 | "categories": "category", 56 | "suppliers": "supplier", 57 | "partners": "partners", 58 | } 59 | -------------------------------------------------------------------------------- /example/products/migrations/0001_initial.py: -------------------------------------------------------------------------------- 1 | # Generated by Django 1.11.3 on 2017-07-26 13:13 2 | 3 | from django.db import migrations, models 4 | import django.db.models.deletion 5 | 6 | 7 | class Migration(migrations.Migration): 8 | 9 | initial = True 10 | 11 | dependencies = [] 12 | 13 | operations = [ 14 | migrations.CreateModel( 15 | name="Category", 16 | fields=[ 17 | ("id", models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), 18 | ("name", models.CharField(max_length=255)), 19 | ], 20 | ), 21 | migrations.CreateModel( 22 | name="Partner", 23 | fields=[ 24 | ("id", models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), 25 | ("name", models.CharField(max_length=255)), 26 | ], 27 | ), 28 | migrations.CreateModel( 29 | name="Product", 30 | fields=[ 31 | ("id", models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), 32 | ("name", models.CharField(max_length=255)), 33 | ("category", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="products.Category")), 34 | ("partners", models.ManyToManyField(to="products.Partner")), 35 | ], 36 | ), 37 | migrations.CreateModel( 38 | name="Supplier", 39 
| fields=[ 40 | ("id", models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), 41 | ("name", models.CharField(max_length=255)), 42 | ], 43 | ), 44 | migrations.AddField( 45 | model_name="product", 46 | name="supplier", 47 | field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="products.Supplier"), 48 | ), 49 | ] 50 | -------------------------------------------------------------------------------- /tests/migrations/0002_alter_product_category_alter_product_partners_and_more.py: -------------------------------------------------------------------------------- 1 | # Generated by Django 4.2.16 on 2024-10-28 12:53 2 | 3 | from django.db import migrations, models 4 | import django.db.models.deletion 5 | 6 | 7 | class Migration(migrations.Migration): 8 | dependencies = [ 9 | ("tests", "0001_initial"), 10 | ] 11 | 12 | operations = [ 13 | migrations.AlterField( 14 | model_name="product", 15 | name="category", 16 | field=models.ForeignKey( 17 | on_delete=django.db.models.deletion.CASCADE, related_name="products", to="tests.category" 18 | ), 19 | ), 20 | migrations.AlterField( 21 | model_name="product", 22 | name="partners", 23 | field=models.ManyToManyField(blank=True, related_name="products", to="tests.partner"), 24 | ), 25 | migrations.CreateModel( 26 | name="SupplierMetadata", 27 | fields=[ 28 | ("id", models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), 29 | ("properties", models.CharField(max_length=255)), 30 | ( 31 | "supplier", 32 | models.OneToOneField( 33 | on_delete=django.db.models.deletion.CASCADE, related_name="metadata", to="tests.supplier" 34 | ), 35 | ), 36 | ], 37 | ), 38 | migrations.CreateModel( 39 | name="ProductMetadata", 40 | fields=[ 41 | ("id", models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), 42 | ("properties", models.CharField(max_length=255)), 43 | ( 44 | "product", 45 | models.OneToOneField( 46 |
on_delete=django.db.models.deletion.CASCADE, related_name="metadata", to="tests.product" 47 | ), 48 | ), 49 | ], 50 | ), 51 | ] 52 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = 3 | # django 2 4 | py36-django22-drf{39,310,311,312}, 5 | py37-django22-drf{39,310,311,312}, 6 | py38-django22-drf{39,310,311,312}, 7 | py39-django22-drf{39,310,311,312}, 8 | 9 | # django 3 10 | py36-django{31,32}-drf312, 11 | py37-django{31,32}-drf312, 12 | py38-django{31,32}-drf312, 13 | py39-django{31,32}-drf312, 14 | py310-django32-drf312, 15 | 16 | # django 4.0 17 | py38-django40-drf{313,314}, 18 | py39-django40-drf{313,314}, 19 | py310-django40-drf{313,314}, 20 | 21 | # django 4.1+ 22 | py38-django41-drf314, 23 | py39-django41-drf314, 24 | py310-django41-drf314, 25 | py311-django{41,42}-drf314, 26 | 27 | # Django 5.0 28 | # * Python < 3.10 no longer supported 29 | # * DRF 3.15 first to support django 5 30 | py310-django{50,51}-drf315, 31 | py311-django{50,51}-drf315, 32 | py312-django{50,51}-drf315, 33 | 34 | lint 35 | 36 | [testenv] 37 | setenv = 38 | PYTHONPATH = {toxinidir}:{toxinidir}/drf_sideloading 39 | PYTHONDONTWRITEBYTECODE=1 40 | allowlist_externals = coverage 41 | commands = 42 | coverage run --source drf_sideloading runtests.py 43 | 44 | deps = 45 | # Django 46 | django22: Django>=2.2,<2.3 47 | django31: Django>=3.1,<3.2 48 | django32: Django>=3.2,<3.3 49 | django40: Django>=4.0,<4.1 50 | django41: Django>=4.1,<4.2 51 | django42: Django>=4.2,<4.3 52 | django50: Django>=5.0,<5.1 53 | django51: Django>=5.1,<5.2 54 | # Django rest framework 55 | drf39: djangorestframework>=3.9,<3.10 56 | drf310: djangorestframework>=3.10,<3.11 57 | drf311: djangorestframework>=3.11,<3.12 58 | drf312: djangorestframework>=3.12,<3.13 59 | drf313: djangorestframework>=3.13,<3.14 60 | drf314: djangorestframework>=3.14,<3.15 61 | drf315:
djangorestframework>=3.15,<3.16 62 | 63 | -r{toxinidir}/requirements_test.txt 64 | 65 | basepython = 66 | py36: python3.6 67 | py37: python3.7 68 | py38: python3.8 69 | py39: python3.9 70 | py310: python3.10 71 | py311: python3.11 72 | py312: python3.12 73 | 74 | passenv = 75 | PYTHONPATH 76 | 77 | [testenv:lint] 78 | basepython = 79 | python3.12 80 | deps = 81 | flake8 82 | allowlist_externals = make 83 | commands = make lint 84 | -------------------------------------------------------------------------------- /tests/migrations/0001_initial.py: -------------------------------------------------------------------------------- 1 | # Generated by Django 1.11.3 on 2017-07-26 13:13 2 | 3 | from django.db import migrations, models 4 | import django.db.models.deletion 5 | 6 | 7 | class Migration(migrations.Migration): 8 | 9 | initial = True 10 | 11 | dependencies = [] 12 | 13 | operations = [ 14 | migrations.CreateModel( 15 | name="Category", 16 | fields=[ 17 | ("id", models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), 18 | ("name", models.CharField(max_length=255)), 19 | ], 20 | ), 21 | migrations.CreateModel( 22 | name="Partner", 23 | fields=[ 24 | ("id", models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), 25 | ("name", models.CharField(max_length=255)), 26 | ], 27 | ), 28 | migrations.CreateModel( 29 | name="Product", 30 | fields=[ 31 | ("id", models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), 32 | ("name", models.CharField(max_length=255)), 33 | ("category", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="tests.Category")), 34 | ("partners", models.ManyToManyField(to="tests.Partner")), 35 | ], 36 | ), 37 | migrations.CreateModel( 38 | name="Supplier", 39 | fields=[ 40 | ("id", models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), 41 | ("name", models.CharField(max_length=255)), 42 | ], 43 | ),
44 | migrations.AddField( 45 | model_name="product", 46 | name="supplier", 47 | field=models.ForeignKey( 48 | on_delete=django.db.models.deletion.CASCADE, to="tests.Supplier", related_name="products" 49 | ), 50 | ), 51 | migrations.AddField( 52 | model_name="product", 53 | name="backup_supplier", 54 | field=models.ForeignKey( 55 | on_delete=django.db.models.deletion.CASCADE, 56 | to="tests.Supplier", 57 | related_name="backup_products", 58 | null=True, 59 | blank=True, 60 | ), 61 | ), 62 | ] 63 | -------------------------------------------------------------------------------- /HISTORY.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## 2.2.2 (2024-10-28) 4 | - fix ReverseManyToOne reverse prefetch model selection 5 | 6 | ## 2.2.1 (2024-10-28) 7 | - fix ReverseManyToOne through prefetch model selection 8 | 9 | ## 2.2.0 (2024-10-22) 10 | - Support for Django 5 11 | - Django supported versions `5.0 -> 5.1` 12 | - Python supported versions `3.10 -> 3.12` 13 | - Django-rest-framework supported versions. `3.15` 14 | - Add support for request dependant prefetch filtering 15 | - refactored code for better readability 16 | - prefetches from view are also used when determining sideloading prefetches 17 | - Add support for drf_spectacular documentation 18 | - Add prefetch related support for multi source fields 19 | 20 | ## 2.1.0 (2024-01-26) 21 | 22 | - Support for Django 4 23 | - Django supported versions `4.0 -> 4.2` 24 | - Python supported versions `3.10 -> 3.11` 25 | - Django-rest-framework supported versions. `3.13 -> 3.14` 26 | - Fix issue with prefetch ordering 27 | 28 | ## 2.0.1 (2021-12-16) 29 | 30 | - Ensure that only allowed methods are sideloaded 31 | 32 | ## 2.0.0 (2021-12-10) 33 | 34 | Major refactoring to allow for multi source fields.
35 | 36 | - Add support for multi source fields 37 | - Add support for detail view sideloading 38 | - Dropped formless BrowsableAPIRenderer enforcement 39 | - Raises error in case invalid fields are requested for sideloading 40 | 41 | ## 1.4.2 (2021-04-12) 42 | 43 | - Add support for lists in filter_related_objects 44 | 45 | ## 1.4.1 (2021-04-09) 46 | 47 | - Fix sideloadable prefetches 48 | 49 | ## 1.4.0 (2021-04-07) 50 | 51 | - Python supported versions `3.6 -> 3.9` 52 | - Django supported versions `2.2`, `3.1`, `3.2` 53 | - Django-rest-framework supported versions. `3.9 -> 3.12` 54 | 55 | ## 1.3.1 (2021-04-07) 56 | 57 | Added support for `django.db.models.Prefetch` 58 | 59 | ## 1.3.0 (2019-04-23) 60 | 61 | Fix empty related fields sideloading bug 62 | 63 | - Support for Django 2.2 64 | 65 | ## 1.2.0 (2018-10-29) 66 | 67 | Completely refactored sideloading configuration via a custom serializer. 68 | 69 | - Support for Django 2.1 70 | - Support for Django-rest-framework 3.9 71 | 72 | ## 0.1.10 (2017-07-20) 73 | 74 | - Support for Django 2.0 75 | 76 | ## 0.1.8 (2017-07-20) 77 | 78 | - change sideloadable_relations dict 79 | - always required to define 'serializer' 80 | - key is referenced to url and serialized in as rendered json 81 | - add `source` which specifies original model field name 82 | 83 | ## 0.1.0 (2017-07-20) 84 | 85 | - First release on PyPI. 
86 | -------------------------------------------------------------------------------- /drf_sideloading/schema.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | from django.utils.translation import gettext_lazy as _ 3 | 4 | from drf_spectacular.utils import ( 5 | OpenApiParameter, 6 | OpenApiExample, 7 | OpenApiTypes, 8 | ) 9 | from drf_spectacular.openapi import AutoSchema 10 | 11 | 12 | class SideloadingAutoSchema(AutoSchema): 13 | override_parameters = [] 14 | 15 | def get_override_parameters(self): 16 | if self.method == "GET": 17 | # return self.override_parameters 18 | self.view.initialize_serializer(request=getattr(self.view, "request", None)) 19 | sideloading_keys_sources: Dict[str, Union[str, Dict[str, str]]] = self.view.get_sideloading_field_sources() 20 | sideloading_keys = list(k for k, v in sideloading_keys_sources.items() if isinstance(v, str)) 21 | multi_source_sideloading_items = { 22 | k: list(v.keys()) for k, v in sideloading_keys_sources.items() if isinstance(v, dict) 23 | } 24 | examples = [] 25 | if sideloading_keys: 26 | examples.append( 27 | OpenApiExample( 28 | name=_("Regular sideloading"), 29 | value=",".join(sideloading_keys[:2]), 30 | request_only=True, 31 | ) 32 | ) 33 | for k, v in multi_source_sideloading_items.items(): 34 | examples.append( 35 | OpenApiExample( 36 | name=_(f"Multi source sideloading for {k}"), 37 | value=f"{k}[{','.join(v)}]", 38 | request_only=True, 39 | ) 40 | ) 41 | return [ 42 | OpenApiParameter( 43 | name="sideload", 44 | type=OpenApiTypes.STR, 45 | location=OpenApiParameter.QUERY, 46 | many=True, 47 | description=_( 48 | "This option allows you to fetch related obejcts for all of the relations with a signle query. " 49 | "Multi-source sideloadable fields can be filtered by the sources by declaring the required " 50 | "sources in square brackets after the sideloading key. 
All available Mutli-source fields will " 51 | "have an example provided with all available sources. The comma separated sources can be " 52 | "ommited with the square brackets if all sources are to be sideloaded." 53 | ), 54 | enum=sideloading_keys_sources.keys(), 55 | examples=examples, 56 | ) 57 | ] 58 | return [] 59 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | 5 | try: 6 | from setuptools import setup 7 | except ImportError: 8 | from distutils.core import setup 9 | 10 | 11 | def get_version(*file_paths): 12 | """Retrieves the version from drf_sideloading/__init__.py""" 13 | filename = os.path.join(os.path.dirname(__file__), *file_paths) 14 | version_file = open(filename).read() 15 | version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M) 16 | if version_match: 17 | return version_match.group(1) 18 | raise RuntimeError("Unable to find version string.") 19 | 20 | 21 | version = get_version("drf_sideloading", "__init__.py") 22 | 23 | 24 | if sys.argv[-1] == "publish": 25 | try: 26 | import wheel 27 | 28 | print("Wheel version: ", wheel.__version__) 29 | except ImportError: 30 | print('Wheel library missing. 
Please run "pip install wheel"') 31 | sys.exit() 32 | os.system("python setup.py sdist upload") 33 | os.system("python setup.py bdist_wheel upload") 34 | sys.exit() 35 | 36 | if sys.argv[-1] == "tag": 37 | print("Tagging the version on git:") 38 | os.system("git tag -a %s -m 'version %s'" % (version, version)) 39 | os.system("git push --tags") 40 | sys.exit() 41 | 42 | readme = open("README.md").read() 43 | history = open("HISTORY.md").read() 44 | 45 | setup( 46 | name="drf-sideloading", 47 | version=version, 48 | description="""Extension for Django Rest Framework to enable simple sideloading""", 49 | long_description=readme + "\n\n" + history, 50 | long_description_content_type="text/markdown", 51 | author="Namespace OÜ", 52 | author_email="info@namespace.ee", 53 | url="https://github.com/namespace-ee/drf-sideloading", 54 | packages=["drf_sideloading"], 55 | include_package_data=True, 56 | install_requires=["Django>=2.0", "djangorestframework>=3.7.0"], 57 | license="MIT", 58 | zip_safe=False, 59 | keywords="drf-sideloading", 60 | classifiers=[ 61 | "Development Status :: 5 - Production/Stable", 62 | "Framework :: Django", 63 | "Framework :: Django :: 2.1", 64 | "Framework :: Django :: 2.2", 65 | "Framework :: Django :: 3.0", 66 | "Framework :: Django :: 3.1", 67 | "Framework :: Django :: 3.2", 68 | "Framework :: Django :: 4.0", 69 | "Framework :: Django :: 4.1", 70 | "Framework :: Django :: 4.2", 71 | "Intended Audience :: Developers", 72 | "License :: OSI Approved :: MIT License", 73 | "Natural Language :: English", 74 | "Programming Language :: Python", 75 | "Programming Language :: Python :: 3", 76 | "Programming Language :: Python :: 3.6", 77 | "Programming Language :: Python :: 3.7", 78 | "Programming Language :: Python :: 3.8", 79 | "Programming Language :: Python :: 3.9", 80 | "Programming Language :: Python :: 3.10", 81 | "Programming Language :: Python :: 3.11", 82 | "Programming Language :: Python :: 3.12", 83 | ], 84 | ) 85 | 
-------------------------------------------------------------------------------- /tests/viewsets.py: -------------------------------------------------------------------------------- 1 | from rest_framework import viewsets, filters, versioning 2 | from rest_framework.mixins import RetrieveModelMixin, ListModelMixin 3 | from rest_framework.viewsets import GenericViewSet 4 | 5 | from drf_sideloading.mixins import SideloadableRelationsMixin 6 | from tests.mixins import OtherMixin 7 | from tests.models import Product, Category, Supplier, Partner 8 | from tests.serializers import ( 9 | ProductSerializer, 10 | CategorySerializer, 11 | SupplierSerializer, 12 | PartnerSerializer, 13 | ProductSideloadableSerializer, 14 | CategorySideloadableSerializer, 15 | NewProductSideloadableSerializer, 16 | ) 17 | 18 | 19 | class ProductViewSet(SideloadableRelationsMixin, OtherMixin, viewsets.ModelViewSet): 20 | """ 21 | A simple ViewSet for viewing and editing products. 22 | """ 23 | 24 | queryset = Product.objects.all() 25 | serializer_class = ProductSerializer 26 | sideloading_serializer_class = ProductSideloadableSerializer 27 | versioning_class = versioning.AcceptHeaderVersioning 28 | filter_backends = [ 29 | filters.SearchFilter, 30 | # django_filters.rest_framework.DjangoFilterBackend, 31 | ] 32 | search_fields = ["name"] 33 | 34 | # filter_fields = ["confirmed"] 35 | 36 | def get_serializer_class(self): 37 | # if no super is called sideloading should still work 38 | return self.serializer_class 39 | 40 | def get_sideloading_serializer_class(self, request=None): 41 | # if no super is called sideloading should still work 42 | if self.request.version == "2.0.0": 43 | return NewProductSideloadableSerializer 44 | return super().get_sideloading_serializer_class(request=request) 45 | 46 | 47 | class ListOnlyProductViewSet(SideloadableRelationsMixin, OtherMixin, ListModelMixin, GenericViewSet): 48 | queryset = Product.objects.all() 49 | serializer_class = ProductSerializer 50 | 
sideloading_serializer_class = ProductSideloadableSerializer 51 | 52 | 53 | class RetreiveOnlyProductViewSet(SideloadableRelationsMixin, OtherMixin, RetrieveModelMixin, GenericViewSet): 54 | queryset = Product.objects.all() 55 | serializer_class = ProductSerializer 56 | sideloading_serializer_class = ProductSideloadableSerializer 57 | 58 | 59 | class ProductViewSetSideloadingBeforeViews(viewsets.ModelViewSet, SideloadableRelationsMixin, OtherMixin): 60 | queryset = Product.objects.all() 61 | serializer_class = ProductSerializer 62 | sideloading_serializer_class = ProductSideloadableSerializer 63 | 64 | 65 | class CategoryViewSet(SideloadableRelationsMixin, viewsets.ModelViewSet): 66 | queryset = Category.objects.all() 67 | serializer_class = CategorySerializer 68 | sideloading_serializer_class = CategorySideloadableSerializer 69 | 70 | 71 | class SupplierViewSet(viewsets.ModelViewSet): 72 | queryset = Supplier.objects.all() 73 | serializer_class = SupplierSerializer 74 | 75 | 76 | class PartnerViewSet(viewsets.ModelViewSet): 77 | queryset = Partner.objects.all() 78 | serializer_class = PartnerSerializer 79 | -------------------------------------------------------------------------------- /drf_sideloading/serializers.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | from rest_framework import serializers 4 | from rest_framework.fields import SkipField, empty 5 | 6 | 7 | class SideLoadableSerializer(serializers.Serializer): 8 | fields_to_load = None 9 | relations_to_sideload = None 10 | 11 | def __init__(self, instance=None, data=empty, relations_to_sideload=None, **kwargs): 12 | self.relations_to_sideload = relations_to_sideload 13 | self.fields_to_load = [self.Meta.primary] + list(relations_to_sideload.keys()) 14 | super(SideLoadableSerializer, self).__init__(instance=instance, data=data, **kwargs) 15 | 16 | @classmethod 17 | def many_init(cls, *args, **kwargs): 18 | raise 
NotImplementedError("Sideloadable serializer with many=True has not been implemented") 19 | 20 | @classmethod 21 | def check_setup(cls): 22 | # Check Meta class 23 | if not cls._declared_fields: 24 | raise ValueError("Setup error, no cls._declared_fields") 25 | if not getattr(cls, "Meta", None) or not getattr(cls.Meta, "primary", None): 26 | raise ValueError("Sideloadable serializer must have a Meta class defined with the 'primary' field name!") 27 | if cls.Meta.primary not in cls._declared_fields: 28 | raise ValueError("Sideloadable serializer Meta.primary must point to a field in the serializer!") 29 | if getattr(cls.Meta, "prefetches", None): 30 | if not isinstance(cls.Meta.prefetches, dict): 31 | raise ValueError("Sideloadable serializer Meta attribute 'prefetches' must be a dict.") 32 | 33 | # check serializer fields: 34 | for name, field in cls._declared_fields.items(): 35 | if not getattr(field, "many", None): 36 | raise ValueError(f"SideLoadable field '{name}' must be set as many=True") 37 | if not isinstance(field.child, serializers.ModelSerializer): 38 | raise ValueError(f"SideLoadable field '{name}' serializer must be inherited from ModelSerializer") 39 | 40 | def to_representation(self, instance): 41 | """ 42 | Object instance -> Dict of primitive datatypes. 43 | """ 44 | ret = OrderedDict() 45 | fields = [ 46 | f 47 | for f in self.fields.values() 48 | if not f.write_only and f.source in instance.keys() and f.field_name in self.fields_to_load 49 | ] 50 | 51 | for field in fields: 52 | try: 53 | attribute = field.get_attribute(instance) 54 | except SkipField: 55 | continue 56 | 57 | # We skip `to_representation` for `None` values so that fields do 58 | # not have to explicitly deal with that case. 59 | # 60 | # For related fields with `use_pk_only_optimization` we need to 61 | # resolve the pk value. 
62 | if getattr(attribute, "pk", attribute) is None: 63 | ret[field.field_name] = None 64 | else: 65 | ret[field.field_name] = field.to_representation(attribute) 66 | 67 | return ret 68 | -------------------------------------------------------------------------------- /example/settings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Django settings for example project. 3 | 4 | Generated by 'django-admin startproject' using Django 1.11.3. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/1.11/topics/settings/ 8 | 9 | For the full list of settings and their values, see 10 | https://docs.djangoproject.com/en/1.11/ref/settings/ 11 | """ 12 | 13 | import os 14 | 15 | # Build paths inside the project like this: os.path.join(BASE_DIR, ...) 16 | BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 17 | 18 | 19 | # Quick-start development settings - unsuitable for production 20 | # See https://docs.djangoproject.com/en/1.11/howto/deployment/checklist/ 21 | 22 | # SECURITY WARNING: keep the secret key used in production secret! 23 | SECRET_KEY = "4hj+_4(v!gde+uwr^9_skgl6y9t+3*rbob2__0-by0(vf%*_qk" 24 | 25 | # SECURITY WARNING: don't run with debug turned on in production! 
26 | DEBUG = True 27 | 28 | ALLOWED_HOSTS = [] 29 | 30 | 31 | # Application definition 32 | 33 | INSTALLED_APPS = [ 34 | "django.contrib.admin", 35 | "django.contrib.auth", 36 | "django.contrib.contenttypes", 37 | "django.contrib.sessions", 38 | "django.contrib.messages", 39 | "django.contrib.staticfiles", 40 | "rest_framework", 41 | "example", 42 | "products", 43 | # 'debug_toolbar', 44 | ] 45 | 46 | MIDDLEWARE = [ 47 | "django.middleware.security.SecurityMiddleware", 48 | "django.contrib.sessions.middleware.SessionMiddleware", 49 | "django.middleware.common.CommonMiddleware", 50 | "django.middleware.csrf.CsrfViewMiddleware", 51 | "django.contrib.auth.middleware.AuthenticationMiddleware", 52 | "django.contrib.messages.middleware.MessageMiddleware", 53 | "django.middleware.clickjacking.XFrameOptionsMiddleware", 54 | # 'debug_toolbar.middleware.DebugToolbarMiddleware', 55 | ] 56 | 57 | ROOT_URLCONF = "urls" 58 | 59 | TEMPLATES = [ 60 | { 61 | "BACKEND": "django.template.backends.django.DjangoTemplates", 62 | "DIRS": [], 63 | "APP_DIRS": True, 64 | "OPTIONS": { 65 | "context_processors": [ 66 | "django.template.context_processors.debug", 67 | "django.template.context_processors.request", 68 | "django.contrib.auth.context_processors.auth", 69 | "django.contrib.messages.context_processors.messages", 70 | ] 71 | }, 72 | } 73 | ] 74 | 75 | WSGI_APPLICATION = "example.wsgi.application" 76 | 77 | 78 | # Database 79 | # https://docs.djangoproject.com/en/1.11/ref/settings/#databases 80 | 81 | DATABASES = { 82 | "default": { 83 | "ENGINE": "django.db.backends.sqlite3", 84 | "NAME": os.path.join(BASE_DIR, "db.sqlite3"), 85 | } 86 | } 87 | 88 | 89 | # Password validation 90 | # https://docs.djangoproject.com/en/1.11/ref/settings/#auth-password-validators 91 | 92 | AUTH_PASSWORD_VALIDATORS = [ 93 | {"NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator"}, 94 | {"NAME": "django.contrib.auth.password_validation.MinimumLengthValidator"}, 95 | {"NAME": 
"django.contrib.auth.password_validation.CommonPasswordValidator"}, 96 | {"NAME": "django.contrib.auth.password_validation.NumericPasswordValidator"}, 97 | ] 98 | 99 | 100 | # Internationalization 101 | # https://docs.djangoproject.com/en/1.11/topics/i18n/ 102 | 103 | LANGUAGE_CODE = "en-us" 104 | 105 | TIME_ZONE = "UTC" 106 | 107 | USE_I18N = True 108 | 109 | USE_L10N = True 110 | 111 | USE_TZ = True 112 | 113 | 114 | # Static files (CSS, JavaScript, Images) 115 | # https://docs.djangoproject.com/en/1.11/howto/static-files/ 116 | 117 | STATIC_URL = "/static/" 118 | -------------------------------------------------------------------------------- /tests/serializers.py: -------------------------------------------------------------------------------- 1 | from rest_framework import serializers 2 | 3 | from drf_sideloading.serializers import SideLoadableSerializer 4 | from tests.models import Supplier, Category, Product, Partner, ProductMetadata, SupplierMetadata 5 | 6 | 7 | class SupplierMetadataSerializer(serializers.ModelSerializer): 8 | class Meta: 9 | model = SupplierMetadata 10 | fields = ["supplier", "properties"] 11 | 12 | 13 | class SupplierSerializer(serializers.ModelSerializer): 14 | metadata = SupplierMetadataSerializer(read_only=True) 15 | 16 | class Meta: 17 | model = Supplier 18 | fields = ["name", "metadata"] 19 | 20 | 21 | class PartnerSerializer(serializers.ModelSerializer): 22 | class Meta: 23 | model = Partner 24 | fields = ["name"] 25 | 26 | 27 | class CategorySerializer(serializers.ModelSerializer): 28 | class Meta: 29 | model = Category 30 | fields = ["name"] 31 | 32 | 33 | class ProductMetadataSerializer(serializers.ModelSerializer): 34 | class Meta: 35 | model = ProductMetadata 36 | fields = ["product", "properties"] 37 | 38 | 39 | class ProductSerializer(serializers.ModelSerializer): 40 | metadata = ProductMetadataSerializer(read_only=True) 41 | 42 | class Meta: 43 | model = Product 44 | fields = ["name", "category", "supplier", "partners", 
"metadata"] 45 | 46 | 47 | class CategorySideloadableSerializer(SideLoadableSerializer): 48 | categories = CategorySerializer(many=True) 49 | products = ProductSerializer(many=True) 50 | suppliers = SupplierSerializer(source="products__supplier", many=True) 51 | partners = PartnerSerializer(source="products__partners", many=True) 52 | 53 | class Meta: 54 | primary = "categories" 55 | prefetches = { 56 | "products": "products", 57 | "suppliers": "products__supplier", 58 | "partners": "products__partners", 59 | } 60 | 61 | 62 | class ProductSideloadableSerializer(SideLoadableSerializer): 63 | products = ProductSerializer(many=True) 64 | categories = CategorySerializer(source="category", many=True) 65 | main_suppliers = SupplierSerializer(source="supplier", many=True) 66 | backup_suppliers = SupplierSerializer(source="backup_supplier", many=True) 67 | partners = PartnerSerializer(source="partner", many=True) 68 | combined_suppliers = SupplierSerializer(many=True) 69 | metadata = ProductMetadataSerializer(many=True, read_only=True) 70 | 71 | class Meta: 72 | primary = "products" 73 | prefetches = { 74 | "categories": "category", 75 | "main_suppliers": ["supplier", "supplier__metadata"], 76 | "backup_suppliers": ["backup_supplier", "backup_supplier__metadata"], 77 | "partners": "partners", 78 | # These can be defined to always load them, else they will be 79 | # copied over form all sources or selected sources only. 
80 | "combined_suppliers": { 81 | "suppliers": ["supplier", "supplier__metadata"], 82 | "backup_supplier": ["backup_supplier", "backup_supplier__metadata"], 83 | }, 84 | "metadata": "metadata", 85 | } 86 | 87 | 88 | class NewProductSideloadableSerializer(SideLoadableSerializer): 89 | products = ProductSerializer(many=True) 90 | new_categories = CategorySerializer(source="category", many=True) 91 | new_main_suppliers = SupplierSerializer(source="supplier", many=True) 92 | new_backup_suppliers = SupplierSerializer(source="backup_supplier", many=True) 93 | new_partners = PartnerSerializer(source="partner", many=True) 94 | combined_suppliers = SupplierSerializer(many=True) 95 | metadata = ProductMetadataSerializer(many=True, read_only=True) 96 | 97 | class Meta: 98 | primary = "products" 99 | prefetches = { 100 | "new_categories": "category", 101 | "new_main_suppliers": ["supplier", "supplier__metadata"], 102 | "new_backup_suppliers": ["backup_supplier", "backup_supplier__metadata"], 103 | "new_partners": "partners", 104 | # These can be defined to always load them, else they will be 105 | # copied over form all sources or selected sources only. 
106 | "combined_suppliers": { 107 | "suppliers": ["supplier", "supplier__metadata"], 108 | "backup_supplier": ["backup_supplier", "backup_supplier__metadata"], 109 | }, 110 | "metadata": "metadata", 111 | } 112 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Package Index](https://badge.fury.io/py/drf-sideloading.svg)](https://badge.fury.io/py/drf-sideloading) 2 | [![Build Status](https://travis-ci.org/namespace-ee/django-rest-framework-sideloading.svg?branch=master)](https://travis-ci.org/namespace-ee/django-rest-framework-sideloading) 3 | [![Code Coverage](https://codecov.io/gh/namespace-ee/django-rest-framework-sideloading/branch/master/graph/badge.svg)](https://codecov.io/gh/namespace-ee/django-rest-framework-sideloading) 4 | [![License is MIT](https://img.shields.io/github/license/mashape/apistatus.svg?maxAge=2592000)](https://github.com/namespace-ee/drf-sideloading/blob/master/LICENSE) 5 | [![Code style Black](https://img.shields.io/badge/code%20style-black-000000.svg?maxAge=2592000)](https://github.com/ambv/black) 6 | 7 | :warning: Note that there are major API changes since version 0.1.1 that have to be taken into account when upgrading! 8 | 9 | :warning: Python 2 and Django 1.11 are no longer supported from version 1.4.0! 10 | 11 | # Django rest framework sideloading 12 | 13 | DRF-sideloading is an extension to provide side-loading functionality of related resources. Side-loading allows related resources to be optionally included in a single API response minimizing requests to the API. 14 | 15 | ## Quickstart 16 | 17 | 1. Install drf-sideloading: 18 | 19 | ```shell 20 | pip install drf-sideloading 21 | ``` 22 | 23 | 2. Import `SideloadableRelationsMixin`: 24 | 25 | ```python 26 | from drf_sideloading.mixins import SideloadableRelationsMixin 27 | ``` 28 | 29 | 3. 
Write your SideLoadableSerializer: 30 | 31 | You need to define the **primary** serializer in the Meta data and can define prefetching rules. 32 | Also notice the **many=True** on the sideloadable relationships. 33 | 34 | ```python 35 | from drf_sideloading.serializers import SideLoadableSerializer 36 | 37 | class ProductSideloadableSerializer(SideLoadableSerializer): 38 | products = ProductSerializer(many=True) 39 | categories = CategorySerializer(source="category", many=True) 40 | primary_suppliers = SupplierSerializer(source="primary_supplier", many=True) 41 | secondary_suppliers = SupplierSerializer(many=True) 42 | suppliers = SupplierSerializer(many=True) 43 | partners = PartnerSerializer(many=True) 44 | 45 | class Meta: 46 | primary = "products" 47 | prefetches = { 48 | "categories": "category", 49 | "primary_suppliers": "primary_supplier", 50 | "secondary_suppliers": "secondary_suppliers", 51 | "suppliers": { 52 | "primary_suppliers": "primary_supplier", 53 | "secondary_suppliers": "secondary_suppliers", 54 | }, 55 | "partners": "partners", 56 | } 57 | ``` 58 | 59 | 4. Prefetches 60 | 61 | For fields where the source is provided or where the source matches the field name, prefetches are not strictly required 62 | 63 | Multiple prefetches can be added to a single sideloadable field, but when using Prefetch object check that they don't clash with prefetches made in the get_queryset() method 64 | ```python 65 | from django.db.models import Prefetch 66 | 67 | prefetches = { 68 | "categories": "category", 69 | "primary_suppliers": ["primary_supplier", "primary_supplier__some_related_object"], 70 | "secondary_suppliers": Prefetch( 71 | lookup="secondary_suppliers", 72 | queryset=Supplier.objects.prefetch_related("some_related_object") 73 | ), 74 | "partners": Prefetch( 75 | lookup="partners", 76 | queryset=Partner.objects.select_related("some_related_object") 77 | ) 78 | } 79 | ``` 80 | 81 | Multiple sources can be added to a field using a dict. 
82 | Each key is a source_key that can be used to filter which sources should be sideloaded. 83 | The values set the source and prefetches for this source. 84 | 85 | Note that this prefetch reuses `primary_supplier` and `secondary_suppliers` if **suppliers** is sideloaded together with **primary_suppliers** or **secondary_suppliers**. 86 | ```python 87 | prefetches = { 88 | "primary_suppliers": "primary_supplier", 89 | "secondary_suppliers": "secondary_suppliers", 90 | "suppliers": { 91 | "primary_suppliers": "primary_supplier", 92 | "secondary_suppliers": "secondary_suppliers" 93 | } 94 | } 95 | ``` 96 | 97 | Usage of Prefetch() objects is supported. 98 | Prefetch() objects can be used to filter a subset of some relations or just to prefetch or select complicated related objects. 99 | In case there are prefetch conflicts, `to_attr` can be set, but be aware that this prefetch will then be a duplicate of similar prefetches. 100 | Prefetch conflicts can also come from prefetches made in the `ViewSet.get_queryset()` method. 101 | 102 | Note that this prefetch does not reuse `primary_supplier` and `secondary_suppliers` if **suppliers** and **primary_suppliers** or **secondary_suppliers** are sideloaded at the same time. 103 | ```python 104 | from django.db.models import Prefetch 105 | 106 | prefetches = { 107 | "categories": "category", 108 | "primary_suppliers": "primary_supplier", 109 | "secondary_suppliers": "secondary_suppliers", 110 | "suppliers": { 111 | "primary_suppliers": Prefetch( 112 | lookup="secondary_suppliers", 113 | queryset=Supplier.objects.select_related("some_related_object"), 114 | to_attr="secondary_suppliers_with_preselected_relation" 115 | ), 116 | "secondary_suppliers": Prefetch( 117 | lookup="secondary_suppliers", 118 | queryset=Supplier.objects.filter(created_at__gt=pendulum.now().subtract(days=10)).order_by("created_at"), 119 | to_attr="latest_secondary_suppliers" 120 | ) 121 | }, 122 | } 123 | ``` 124 |
Configure sideloading in ViewSet: 126 | 127 | Include the **SideloadableRelationsMixin** mixin in the ViewSet and define **sideloading_serializer_class** as shown in the example below. 128 | Everything else stays just like a regular ViewSet. 129 | Since version 2.0.0 there are 3 new methods that allow overriding the sideloading serializer, for example based on the request version. 130 | Since version 2.1.0 an additional method allows adding request-dependent filters to sideloaded relations. 131 | 132 | ```python 133 | from drf_sideloading.mixins import SideloadableRelationsMixin 134 | 135 | class ProductViewSet(SideloadableRelationsMixin, viewsets.ModelViewSet): 136 | """ 137 | A simple ViewSet for viewing and editing products. 138 | """ 139 | 140 | queryset = Product.objects.all() 141 | serializer_class = ProductSerializer 142 | sideloading_serializer_class = ProductSideloadableSerializer 143 | 144 | def get_queryset(self): 145 | # Add prefetches for the viewset as normal 146 | return super().get_queryset().prefetch_related("created_by") 147 | 148 | def get_sideloading_serializer_class(self, request=None): 149 | # use a different sideloadable serializer for older versions 150 | if self.request.version < "1.0.0": 151 | return OldProductSideloadableSerializer 152 | return super().get_sideloading_serializer_class(request=request) 153 | 154 | def get_sideloading_serializer(self, *args, **kwargs): 155 | # if modifications are required to the serializer initialization this method can be used. 156 | return super().get_sideloading_serializer(*args, **kwargs) 157 | 158 | def get_sideloading_serializer_context(self): 159 | # Extra context provided to the serializer class.
160 | return {"request": self.request, "format": self.format_kwarg, "view": self} 161 | 162 | def add_sideloading_prefetch_filter(self, source, queryset, request): 163 | # return the (possibly filtered) queryset and whether a filter was applied 164 | if source == "model1__relation1": 165 | return queryset.filter(is_active=True), True 166 | if hasattr(queryset, "readable"): 167 | return queryset.readable(user=request.user), True 168 | return queryset, False 169 | ``` 170 | 171 | 6. Enjoy your API with sideloading support 172 | 173 | Example request and response when fetching all possible values. The primary field (`products`) is always returned and must not be listed in `sideload`. 174 | ```http 175 | GET /api/products/?sideload=categories,partners,primary_suppliers,secondary_suppliers,suppliers 176 | ``` 177 | 178 | ```json 179 | { 180 | "products": [ 181 | { 182 | "id": 1, 183 | "name": "Product 1", 184 | "category": 1, 185 | "primary_supplier": 1, 186 | "secondary_suppliers": [2, 3], 187 | "partners": [1, 2, 3] 188 | } 189 | ], 190 | "categories": [ 191 | { 192 | "id": 1, 193 | "name": "Category1" 194 | } 195 | ], 196 | "primary_suppliers": [ 197 | { 198 | "id": 1, 199 | "name": "Supplier1" 200 | } 201 | ], 202 | "secondary_suppliers": [ 203 | { 204 | "id": 2, 205 | "name": "Supplier2" 206 | }, 207 | { 208 | "id": 3, 209 | "name": "Supplier3" 210 | } 211 | ], 212 | "suppliers": [ 213 | { 214 | "id": 1, 215 | "name": "Supplier1" 216 | }, 217 | { 218 | "id": 2, 219 | "name": "Supplier2" 220 | }, 221 | { 222 | "id": 3, 223 | "name": "Supplier3" 224 | } 225 | ], 226 | "partners": [ 227 | { 228 | "id": 1, 229 | "name": "Partner1" 230 | }, 231 | { 232 | "id": 2, 233 | "name": "Partner2" 234 | }, 235 | { 236 | "id": 3, 237 | "name": "Partner3" 238 | } 239 | ] 240 | } 241 | ``` 242 | 243 | The user can also select which sources to load for multi-source fields. 244 | Omitting the brackets will load all the prefetched sources; empty brackets (e.g. `suppliers[]`) are rejected with a validation error.
245 | 246 | Example: 247 | 248 | ```http 249 | GET /api/products/?sideload=suppliers[primary_suppliers] 250 | ``` 251 | ```json 252 | { 253 | "products": [ 254 | { 255 | "id": 1, 256 | "name": "Product 1", 257 | "category": 1, 258 | "primary_supplier": 1, 259 | "secondary_suppliers": [2, 3], 260 | "partners": [1, 2, 3] 261 | } 262 | ], 263 | "suppliers": [ 264 | { 265 | "id": 1, 266 | "name": "Supplier1" 267 | } 268 | ] 269 | } 270 | ``` 271 | ## Example Project 272 | 273 | The `example` directory contains an example project using the Django REST framework sideloading library. You can set it up and run it locally using the following commands: 274 | 275 | ```shell 276 | cd example 277 | sh scripts/devsetup.sh 278 | sh scripts/dev.sh 279 | ``` 280 | 281 | ## Contributing 282 | 283 | Contributions are welcome, and they are greatly appreciated! Every little bit helps, and credit will always be given. 284 | 285 | #### Setup for contribution 286 | 287 | ```shell 288 | $ source myenv/bin/activate 289 | (myenv) $ pip install -r requirements_dev.txt 290 | ``` 291 | 292 | ### Test 293 | 294 | ```shell 295 | $ make test 296 | ``` 297 | 298 | #### Run tests with environment matrix 299 | 300 | ```shell 301 | $ make tox 302 | ``` 303 | 304 | #### Run tests with specific environment 305 | 306 | ```shell 307 | $ tox --listenvs 308 | py37-django22-drf39 309 | py38-django31-drf311 310 | py39-django32-drf312 311 | # ... 312 | $ tox -e py39-django32-drf312 313 | ``` 314 | 315 | #### Test coverage 316 | 317 | ```shell 318 | $ make coverage 319 | ``` 320 | 321 | Use [pyenv](https://github.com/pyenv/pyenv) to test against different Python versions locally.
322 | 323 | ## License 324 | 325 | [MIT](https://github.com/namespace-ee/drf-sideloading/blob/master/LICENSE) 326 | 327 | ## Credits 328 | 329 | - [Demur Nodia](https://github.com/demonno) 330 | - [Tõnis Väin](https://github.com/tonisvain) 331 | - [Madis Väin](https://github.com/madisvain) 332 | - [Lenno Nagel](https://github.com/lnagel) 333 | -------------------------------------------------------------------------------- /drf_sideloading/mixins.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import importlib 3 | import re 4 | from itertools import chain 5 | from typing import Dict, Optional, Union, Set, List 6 | 7 | from django.core.exceptions import ValidationError as DjangoValidationError 8 | from django.db import models 9 | from django.db.models import Prefetch 10 | from django.db.models.fields.related_descriptors import ( 11 | ForwardManyToOneDescriptor, 12 | ForwardOneToOneDescriptor, 13 | ReverseOneToOneDescriptor, 14 | ReverseManyToOneDescriptor, 15 | ) 16 | from django.db.models.sql.where import WhereNode, AND 17 | from django.http import Http404 18 | from django.utils.translation import gettext_lazy as _ 19 | from rest_framework.exceptions import ValidationError 20 | from rest_framework.generics import get_object_or_404 21 | from rest_framework.mixins import RetrieveModelMixin, ListModelMixin 22 | from rest_framework.response import Response 23 | from rest_framework.serializers import ListSerializer 24 | 25 | from drf_sideloading.serializers import SideLoadableSerializer 26 | 27 | 28 | RELATION_DESCRIPTORS = [ 29 | ForwardManyToOneDescriptor, 30 | ForwardOneToOneDescriptor, 31 | ReverseOneToOneDescriptor, 32 | ReverseManyToOneDescriptor, 33 | ] 34 | 35 | 36 | def contains_where_node(existing_node: WhereNode, new_node: WhereNode) -> bool: 37 | """ 38 | Checks if the existing_node contains the new_node. 39 | It will no check OR conditions however! 
40 | """ 41 | if not isinstance(new_node, WhereNode): 42 | raise ValueError("new_node has to be a WhereNode instance") 43 | if not isinstance(existing_node, WhereNode): 44 | return False 45 | if not set(new_node.children) - set(existing_node.children): # all new node children applied 46 | return True 47 | if existing_node.connector == AND: 48 | for child_node in existing_node.children: 49 | exists = contains_where_node(child_node, new_node) 50 | if exists: 51 | return True 52 | return False 53 | 54 | 55 | class SideloadableRelationsMixin(object): 56 | sideloading_query_param_name = "sideload" 57 | sideloading_serializer_class = None 58 | primary_field_name: str = None 59 | sideloadable_fields: Dict = {} 60 | user_defined_prefetches: Dict = {} 61 | primary_field = None 62 | sideloadable_field_sources: Dict = {} 63 | if importlib.util.find_spec("drf_spectacular") is not None: 64 | from drf_sideloading.schema import SideloadingAutoSchema 65 | 66 | # note: if required, the user can overwrite the schema 67 | schema = SideloadingAutoSchema() 68 | 69 | def __init__(self, **kwargs): 70 | super().__init__(**kwargs) 71 | self.check_sideloading_serializer_class(self.sideloading_serializer_class) 72 | 73 | def initialize_serializer(self, request): 74 | sideloading_serializer_class = self.get_sideloading_serializer_class(request=request) 75 | self.check_sideloading_serializer_class(sideloading_serializer_class) 76 | 77 | # sideloadable fields 78 | self.sideloadable_fields = copy.deepcopy(sideloading_serializer_class._declared_fields) 79 | self.primary_field_name = sideloading_serializer_class.Meta.primary 80 | self.primary_field = self.sideloadable_fields.pop(self.primary_field_name) 81 | self.primary_model = self.primary_field.child.Meta.model 82 | 83 | # fetch sideloading sources and prefetches 84 | self.user_defined_prefetches = getattr(sideloading_serializer_class.Meta, "prefetches", {}) 85 | self.sideloadable_field_sources = self.get_sideloading_field_sources() 86 | 87 | 
def get_source_from_prefetch(self, prefetches: Union[str, List, Dict]): 88 | if isinstance(prefetches, str): 89 | return prefetches 90 | if isinstance(prefetches, Prefetch): 91 | return prefetches.to_attr or prefetches.prefetch_through 92 | elif isinstance(prefetches, dict): 93 | if any(isinstance(v, dict) for v in prefetches.values()): 94 | raise ValueError("Can't find source to_attr from dict.") 95 | return {k: self.get_source_from_prefetch(v) for k, v in prefetches.items()} 96 | elif isinstance(prefetches, list): 97 | if not all(isinstance(v, (str, Prefetch)) for v in prefetches): 98 | raise ValueError("Can't find source to_attr from list not containing only strings or prefetches.") 99 | return sorted(self.get_source_from_prefetch(v) for v in prefetches)[0] 100 | 101 | def get_sideloading_field_sources(self) -> Dict: 102 | if not self.sideloadable_fields: 103 | raise ValueError("Sideloading serializer has not been initialized") 104 | 105 | relations_sources = {} 106 | for relation, field in self.sideloadable_fields.items(): 107 | relation_prefetches = self.user_defined_prefetches.get(relation) 108 | if isinstance(relation_prefetches, dict) and any(isinstance(v, dict) for v in relation_prefetches.values()): 109 | relation_prefetches = { 110 | k: self._clean_prefetches(field=field, relation=relation, value=v) 111 | for k, v in relation_prefetches.items() 112 | } 113 | 114 | sideloadable_field_source = field.child.source 115 | 116 | # its a MultiSource field, fetch values from sources defined with prefetches. 
117 | if isinstance(relation_prefetches, dict) and sideloadable_field_source: 118 | raise ValueError("Multi source field with source defined in serializer.") 119 | 120 | if relation_prefetches: 121 | data_source = self.get_source_from_prefetch(relation_prefetches) 122 | elif sideloadable_field_source: 123 | data_source = sideloadable_field_source 124 | elif isinstance(getattr(self.primary_model, relation), tuple(RELATION_DESCRIPTORS)): 125 | data_source = relation 126 | else: 127 | raise ValueError(f"Could not determine source for field '{relation}'.") 128 | 129 | relations_sources[relation] = data_source 130 | 131 | return relations_sources 132 | 133 | def get_relations_to_sideload(self, request) -> Optional[Dict]: 134 | """ 135 | Parse query param and take validated names 136 | 137 | :param sideload_parameter string 138 | :return valid relation names list 139 | 140 | comma separated relation names may contain invalid or unusable characters. 141 | This function finds string match between requested names and defined relation in view 142 | 143 | new: 144 | 145 | response changed to dict as the sources for multi source fields must be selectable. 146 | 147 | """ 148 | if request.method != "GET": 149 | return None 150 | 151 | if self.sideloading_query_param_name not in request.query_params: 152 | return None 153 | 154 | sideload_parameter = request.query_params[self.sideloading_query_param_name] 155 | if not sideload_parameter: 156 | return None 157 | # raise ValidationError({self.sideloading_query_param_name: [_(f"'{relation}' Can not be blank.")]}) 158 | 159 | # This fetches the correct serializer and prepares sideloadable_fields ect. 
160 | self.initialize_serializer(request=request) 161 | 162 | relations_to_sideload = {} 163 | for param in re.split(r",\s*(?![^\[\]]*\])", sideload_parameter): 164 | if "[" in param: 165 | fieldname, sources_str = param.split("[", 1) 166 | if not sources_str.strip("]"): 167 | msg = _(f"'{fieldname}' source can not be empty.") 168 | raise ValidationError({self.sideloading_query_param_name: [msg]}) 169 | relations = set(sources_str.strip("]").split(",")) 170 | else: 171 | fieldname = param 172 | relations = None 173 | 174 | if fieldname not in self.sideloadable_fields: 175 | msg = _(f"'{fieldname}' is not one of the available choices.") 176 | raise ValidationError({self.sideloading_query_param_name: [msg]}) 177 | 178 | # check for source selection. select all if nothing given 179 | if isinstance(self.user_defined_prefetches.get(fieldname), dict): 180 | source_relations = sorted(self.user_defined_prefetches[fieldname].keys()) 181 | if relations is None: 182 | relations = source_relations 183 | else: 184 | # Check if all requested sources are defined 185 | invalid_sources = set(relations) - set(source_relations) 186 | if invalid_sources: 187 | msg = _(f"'{fieldname}' sources {', '.join(invalid_sources)} are not defined.") 188 | raise ValidationError({self.sideloading_query_param_name: [msg]}) 189 | elif relations: 190 | msg = _(f"'{fieldname}' is not a multi source field.") 191 | raise ValidationError({self.sideloading_query_param_name: [msg]}) 192 | 193 | # everything checks out. 
194 | relations_to_sideload[fieldname] = relations 195 | 196 | return relations_to_sideload 197 | 198 | def check_sideloading_serializer_class(self, sideloading_serializer_class): 199 | if not sideloading_serializer_class: 200 | raise ValueError(f"'{self.__class__.__name__}' sideloading_serializer_class not found") 201 | if not issubclass(sideloading_serializer_class, SideLoadableSerializer): 202 | raise ValueError( 203 | f"'{self.__class__.__name__}' sideloading_serializer_class must be a SideLoadableSerializer subclass" 204 | ) 205 | sideloading_serializer_class.check_setup() 206 | 207 | def get_sideloading_serializer(self, *args, **kwargs): 208 | """ 209 | Return the sideloading_serializer instance that should be used for serializing output. 210 | """ 211 | sideloading_serializer_class = self.get_sideloading_serializer_class() 212 | kwargs["context"] = self.get_sideloading_serializer_context() 213 | return sideloading_serializer_class(*args, **kwargs) 214 | 215 | def get_sideloading_serializer_class(self, request=None): 216 | """ 217 | Return the class to use for the sideloading_serializer. 218 | Defaults to using `self.sideloading_serializer_class`. 219 | 220 | You may want to override this if you need to provide different 221 | serializations depending on the incoming request. 222 | 223 | (Eg. admins get full serialization, others get basic serialization) 224 | """ 225 | assert self.sideloading_serializer_class is not None, ( 226 | f"'{self.__class__.__name__}' should either include a `sideloading_serializer_class` attribute, " 227 | f"or override the `get_sideloading_serializer_class()` method." 228 | ) 229 | 230 | return self.sideloading_serializer_class 231 | 232 | def get_sideloading_serializer_context(self): 233 | """ 234 | Extra context provided to the serializer class. 
235 | """ 236 | return {"request": self.request, "format": self.format_kwarg, "view": self} 237 | 238 | def get_sideloadable_queryset(self, prefetch): 239 | if isinstance(prefetch, str): 240 | model = self.primary_model 241 | for x in prefetch.split("__"): 242 | descriptor = getattr(model, x) 243 | if isinstance(descriptor, ForwardManyToOneDescriptor): 244 | model = descriptor.field.remote_field.model 245 | elif isinstance(descriptor, ForwardOneToOneDescriptor): 246 | model = descriptor.field.remote_field.model 247 | elif isinstance(descriptor, ReverseOneToOneDescriptor): 248 | model = descriptor.related.related_model 249 | elif isinstance(descriptor, ReverseManyToOneDescriptor): 250 | if getattr(descriptor, "reverse", None): 251 | model = descriptor.field.model 252 | elif getattr(descriptor, "through", None): 253 | model = descriptor.field.related_model 254 | else: 255 | model = descriptor.field.model 256 | else: 257 | raise NotImplementedError(f"Descriptor {descriptor.__class__.__name__} has not been implemented") 258 | return model.objects.all() 259 | elif isinstance(prefetch, Prefetch): 260 | return prefetch.queryset 261 | else: 262 | raise NotImplementedError(f"finding queryset for prefetch type {type(prefetch)} has not been implemented") 263 | 264 | def add_sideloading_prefetches(self, queryset, request, relations_to_sideload): 265 | # Iterate over the prefetches of the original queryset and modify them 266 | view_prefetches = {} 267 | for prefetch in queryset._prefetch_related_lookups: 268 | self._add_prefetch(prefetches=view_prefetches, prefetch=prefetch, request=request) 269 | original_prefetches = [v for k, v in sorted(view_prefetches.items())] 270 | 271 | # find applicable prefetches 272 | gathered_prefetches = self._get_relevant_prefetches( 273 | relations_to_sideload=relations_to_sideload, 274 | gathered_prefetches=view_prefetches, 275 | request=request, 276 | ) 277 | 278 | # replace prefetches if any change made 279 | prefetches = [v for k, v in 
sorted(gathered_prefetches.items())] 280 | if prefetches != original_prefetches: 281 | if original_prefetches: 282 | queryset = queryset.prefetch_related(None) 283 | queryset = queryset.prefetch_related(*prefetches) 284 | return queryset 285 | 286 | # modified DRF methods 287 | 288 | def retrieve(self, request, *args, **kwargs): 289 | if not isinstance(self, RetrieveModelMixin): 290 | # The viewset does not have RetrieveModelMixin and therefore the method is not allowed 291 | return self.http_method_not_allowed(request, *args, **kwargs) 292 | 293 | relations_to_sideload = self.get_relations_to_sideload(request=request) 294 | if not relations_to_sideload: 295 | try: 296 | return super().retrieve(request=request, *args, **kwargs) 297 | except AttributeError as exc: 298 | if "super' object has no attribute 'retrieve'" in exc.args[0]: 299 | # self.retrieve() method was not declared before this mixin. 300 | # Make sure the SideloadableRelationsMixin is defined higher than RetrieveModelMixin. 
301 | return self.http_method_not_allowed(request, *args, **kwargs) 302 | raise exc 303 | 304 | # return object with sideloading serializer 305 | queryset = self.get_sideloadable_object_as_queryset( 306 | request=request, 307 | relations_to_sideload=relations_to_sideload, 308 | ) 309 | sideloadable_page = self.get_sideloadable_page_from_queryset( 310 | queryset=queryset, 311 | relations_to_sideload=relations_to_sideload, 312 | ) 313 | serializer = self.get_sideloading_serializer( 314 | instance=sideloadable_page, 315 | relations_to_sideload=relations_to_sideload, 316 | context={"request": request}, 317 | ) 318 | return Response(serializer.data) 319 | 320 | def list(self, request, *args, **kwargs): 321 | if not isinstance(self, ListModelMixin): 322 | # The viewset does not have ListModelMixin and therefore the method is not allowed 323 | return self.http_method_not_allowed(request, *args, **kwargs) 324 | 325 | relations_to_sideload = self.get_relations_to_sideload(request=request) 326 | if not relations_to_sideload: 327 | try: 328 | return super().list(request=request, *args, **kwargs) 329 | except AttributeError as exc: 330 | if "super' object has no attribute 'list'" in exc.args[0]: 331 | # self.list() method was not declared before this mixin. 332 | # Make sure the SideloadableRelationsMixin is defined higher than ListModelMixin. 
333 | return self.http_method_not_allowed(request, *args, **kwargs) 334 | raise exc 335 | 336 | # After this `relations_to_sideload` is safe to use 337 | queryset = self.get_queryset() 338 | queryset = self.add_sideloading_prefetches( 339 | queryset=queryset, 340 | request=request, 341 | relations_to_sideload=relations_to_sideload, 342 | ) 343 | queryset = self.filter_queryset(queryset) 344 | 345 | # Create page 346 | page = self.paginate_queryset(queryset) 347 | if page is not None: 348 | sideloadable_page = self.get_sideloadable_page( 349 | page=page, 350 | relations_to_sideload=relations_to_sideload, 351 | ) 352 | serializer = self.get_sideloading_serializer( 353 | instance=sideloadable_page, 354 | relations_to_sideload=relations_to_sideload, 355 | context={"request": request}, 356 | ) 357 | return self.get_paginated_response(serializer.data) 358 | else: 359 | sideloadable_page = self.get_sideloadable_page_from_queryset( 360 | queryset=queryset, 361 | relations_to_sideload=relations_to_sideload, 362 | ) 363 | serializer = self.get_sideloading_serializer( 364 | instance=sideloadable_page, 365 | relations_to_sideload=relations_to_sideload, 366 | context={"request": request}, 367 | ) 368 | return Response(serializer.data) 369 | 370 | def get_sideloadable_page_from_queryset(self, queryset, relations_to_sideload: Dict): 371 | """ 372 | Populates page with sideloaded data by collecting ids form sideloaded values and then making into a query 373 | """ 374 | 375 | if not relations_to_sideload: 376 | raise ValueError("relations_to_sideload is required") 377 | # this works wonders, but can't be used when page is paginated... 
378 | sideloadable_page = {self.primary_field_name: queryset} 379 | 380 | for relation, source_keys in relations_to_sideload.items(): 381 | field = self.sideloadable_fields[relation] 382 | field_source = field.child.source 383 | source_model = field.child.Meta.model 384 | relation_key = field_source or relation 385 | 386 | related_ids = set() 387 | sideloadable_field_source = self.sideloadable_field_sources.get(relation) 388 | if isinstance(sideloadable_field_source, dict): 389 | for src_key, src in sideloadable_field_source.items(): 390 | if src_key in source_keys or source_keys is None or src_key == "__all__": 391 | related_ids |= set(queryset.values_list(src, flat=True)) 392 | else: 393 | prefetch_key = field_source or self.sideloadable_field_sources[relation] 394 | prefetch_object = next( 395 | (x for x in queryset._prefetch_related_lookups if getattr(x, "prefetch_to", None) == prefetch_key), 396 | None, 397 | ) 398 | if prefetch_key in queryset._prefetch_related_lookups: 399 | related_ids |= set(queryset.values_list(prefetch_key, flat=True)) 400 | elif prefetch_object: 401 | if prefetch_object.queryset: 402 | # performance thing? 
403 | # related_ids |= set( 404 | # prefetch_object.queryset.filter( 405 | # id__in=(queryset.values_list(prefetch_key, flat=True)) 406 | # ).values_list("id", flat=True) 407 | # ) 408 | 409 | for obj in queryset.all(): 410 | prefetched_data = getattr(obj, prefetch_key) 411 | if prefetched_data.__class__.__name__ in [ 412 | "ManyRelatedManager", 413 | "RelatedManager", 414 | ]: 415 | related_ids |= set(prefetched_data.values_list("id", flat=True)) 416 | elif isinstance(prefetched_data, models.Model): 417 | related_ids.add(prefetched_data.id) 418 | elif isinstance(prefetched_data, list): 419 | try: 420 | related_ids |= set(x.id for x in prefetched_data) 421 | except AttributeError: 422 | related_ids |= set(prefetched_data) 423 | elif prefetched_data: 424 | raise ValueError("???") 425 | 426 | else: 427 | related_ids |= set(queryset.values_list(prefetch_key, flat=True)) 428 | else: 429 | raise ValueError(f"No prefetch for {prefetch_key} found!") 430 | 431 | sideloadable_page[relation_key] = source_model.objects.filter(id__in=related_ids) 432 | 433 | return sideloadable_page 434 | 435 | def get_sideloadable_page(self, page, relations_to_sideload: Dict): 436 | """ 437 | Populates page with sideloaded data by collecting distinct values form sideloaded data 438 | """ 439 | sideloadable_page = {self.primary_field_name: page} 440 | for relation, source_keys in relations_to_sideload.items(): 441 | field = self.sideloadable_fields[relation] 442 | field_source = field.child.source 443 | relation_key = field_source or relation 444 | 445 | if not isinstance(field, ListSerializer): 446 | raise RuntimeError("SideLoadable field '{}' must be set as many=True".format(relation)) 447 | 448 | if relation not in sideloadable_page: 449 | sideloadable_page[relation_key] = set() 450 | 451 | if isinstance(self.sideloadable_field_sources.get(relation), dict): 452 | # Multi source relation 453 | for src_key, source_prefetch in self.sideloadable_field_sources[relation].items(): 454 | if not 
source_keys or src_key in source_keys: 455 | sideloadable_page[relation_key] |= self.filter_related_objects( 456 | related_objects=page, lookup=source_prefetch 457 | ) 458 | else: 459 | sideloadable_page[relation_key] |= self.filter_related_objects( 460 | related_objects=page, lookup=field_source or self.sideloadable_field_sources[relation] 461 | ) 462 | 463 | return sideloadable_page 464 | 465 | def get_sideloadable_object_as_queryset(self, request, relations_to_sideload): 466 | """ 467 | mimics DRF original method get_object() 468 | Returns the object the view is displaying with sideloaded models prefetched. 469 | 470 | You may want to override this if you need to provide non-standard 471 | queryset lookups. Eg if objects are referenced using multiple 472 | keyword arguments in the url conf. 473 | """ 474 | # Add prefetches if applicable 475 | queryset = self.get_queryset() 476 | queryset = self.add_sideloading_prefetches( 477 | queryset=queryset, 478 | request=request, 479 | relations_to_sideload=relations_to_sideload, 480 | ) 481 | queryset = self.filter_queryset(queryset) 482 | 483 | # Perform the lookup filtering. 484 | lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field 485 | 486 | assert lookup_url_kwarg in self.kwargs, ( 487 | "Expected view %s to be called with a URL keyword argument " 488 | 'named "%s". Fix your URL conf, or set the `.lookup_field` ' 489 | "attribute on the view correctly." 
% (self.__class__.__name__, lookup_url_kwarg) 490 | ) 491 | 492 | filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]} 493 | try: 494 | queryset = queryset.filter(**filter_kwargs) 495 | except (TypeError, ValueError, DjangoValidationError): 496 | raise Http404 497 | 498 | # check single object fetched 499 | obj = get_object_or_404(queryset) 500 | # May raise a permission denied 501 | self.check_object_permissions(self.request, obj) 502 | 503 | return queryset 504 | 505 | def filter_related_objects(self, related_objects, lookup: Optional[str]) -> Set: 506 | current_lookup, remaining_lookup = lookup.split("__", 1) if "__" in lookup else (lookup, None) 507 | lookup_values = [ 508 | getattr(r, current_lookup) for r in related_objects if getattr(r, current_lookup, None) is not None 509 | ] 510 | 511 | if lookup_values: 512 | if lookup_values[0].__class__.__name__ in ["ManyRelatedManager", "RelatedManager"]: 513 | # FIXME: apply filtering here! 514 | related_objects_set = set(chain(*[related_queryset.all() for related_queryset in lookup_values])) 515 | elif isinstance(lookup_values[0], list): 516 | related_objects_set = set(chain(*[related_list for related_list in lookup_values])) 517 | else: 518 | related_objects_set = set(lookup_values) 519 | else: 520 | related_objects_set = set() 521 | 522 | if remaining_lookup: 523 | return self.filter_related_objects(related_objects=related_objects_set, lookup=remaining_lookup) 524 | return set(related_objects_set) - {"", None} 525 | 526 | # internal_methods: 527 | 528 | def _clean_prefetches(self, field, relation, value, ensure_list=False): 529 | if not value: 530 | raise ValueError(f"Sideloadable field '{relation}' prefetch or source must be set!") 531 | elif isinstance(value, str): 532 | cleaned_value = value 533 | elif isinstance(value, list): 534 | cleaned_value = [self._clean_prefetches(field=field, relation=relation, value=val) for val in value] 535 | # filter out empty values 536 | cleaned_value = [val for val 
in cleaned_value if val] 537 | elif isinstance(value, dict): 538 | if "lookup" not in value: 539 | raise ValueError(f"Sideloadable field '{relation}' Prefetch 'lookup' must be set!") 540 | if value.get("to_attr") and field.child.source and field.child.source != value.get("to_attr"): 541 | raise ValueError( 542 | f"Sideloadable field '{relation}' Prefetch 'to_attr' can't be used with source defined. " 543 | f"Remove source from field serializer." 544 | ) 545 | if value.get("queryset") or value.get("to_attr"): 546 | if not value.get("to_attr"): 547 | value["to_attr"] = relation 548 | cleaned_value = Prefetch(**value) 549 | else: 550 | cleaned_value = value["lookup"] 551 | elif isinstance(value, Prefetch): 552 | # check that Prefetch.to_attr is set the same as the field.source! 553 | if value.to_attr and field.child.source and field.child.source != value.to_attr: 554 | raise ValueError( 555 | f"Sideloadable field '{relation}' Prefetch 'to_attr' can't be different from source defined. " 556 | f"Tip: Remove source from field serializer." 
557 | ) 558 | cleaned_value = value 559 | else: 560 | raise ValueError("Sideloadable prefetch values must be a list of strings or Prefetch objects") 561 | 562 | if ensure_list: 563 | if not cleaned_value: 564 | return [] 565 | elif not isinstance(cleaned_value, list): 566 | return [cleaned_value] 567 | 568 | return cleaned_value 569 | 570 | def _gather_all_prefetches(self) -> Dict: 571 | """ 572 | this method finds all prefetches required and checks if they are correctly defined 573 | """ 574 | cleaned_prefetches = {} 575 | 576 | if not self.sideloadable_fields: 577 | raise ValueError("Sideloading serializer has not been initialized") 578 | 579 | # find prefetches for all sideloadable relations 580 | for relation, field in self.sideloadable_fields.items(): 581 | user_prefetches = self.user_defined_prefetches.get(relation) 582 | field_source = field.child.source 583 | if relation in self.user_defined_prefetches and not user_prefetches: 584 | raise ValueError(f"prefetches for field '{relation}' have been left empty") 585 | elif not user_prefetches: 586 | if field_source: 587 | # default to field source if not defined by user 588 | cleaned_prefetches[relation] = [field_source] 589 | elif getattr(self.primary_field.child.Meta.model, relation, None): 590 | # default to parent serializer model field with the relation name if it exists 591 | cleaned_prefetches[relation] = [relation] 592 | else: 593 | raise ValueError(f"Either source or prefetches must be set for sideloadable field '{relation}'") 594 | elif isinstance(user_prefetches, (str, list, Prefetch)): 595 | cleaned_prefetches[relation] = self._clean_prefetches( 596 | field=field, relation=relation, value=user_prefetches, ensure_list=True 597 | ) 598 | elif isinstance(user_prefetches, dict): 599 | # This is a multi source field! 
600 | # make prefetches for all relations separately 601 | cleaned_prefetches[relation] = {} 602 | for rel, rel_prefetches in user_prefetches.items(): 603 | relation_prefetches = self._clean_prefetches( 604 | field=field, relation=rel, value=rel_prefetches, ensure_list=True 605 | ) 606 | cleaned_prefetches[relation][rel] = relation_prefetches 607 | else: 608 | raise NotImplementedError(f"prefetch with type '{type(user_prefetches)}' is not implemented") 609 | 610 | return cleaned_prefetches 611 | 612 | def add_sideloading_prefetch_filter(self, source, queryset, request): 613 | """ 614 | This method is intended to e overwritten in case the user wants to implement 615 | their own filters based on the related model or the relationship to the base model 616 | 617 | source - string path to the value that is sideloded. 618 | queryset - QuerySet that you can add filtering to 619 | 620 | Example: 621 | 622 | add_sideloading_prefetch_filter(self, source, queryset, request): 623 | if source == "model1__relation1": 624 | return queryset.filter(is_active=True), True 625 | if hasattr(queryset, "readable"): 626 | return queryset.readable(user=request.user), True 627 | return queryset, False 628 | 629 | """ 630 | 631 | return queryset, False 632 | 633 | def _add_sideloading_filter(self, prefetch: Union[str, Prefetch], request) -> Union[str, Prefetch]: 634 | # fetch sideloadable source and queryset 635 | prefetch_source = self.get_source_from_prefetch(prefetches=prefetch) 636 | prefetch_queryset = self.get_sideloadable_queryset(prefetch) 637 | filtered_queryset, added = self.add_sideloading_prefetch_filter( 638 | source=prefetch_source, queryset=prefetch_queryset, request=request 639 | ) 640 | if added: 641 | filter_node = self.add_sideloading_prefetch_filter( 642 | source=prefetch_source, queryset=prefetch_queryset.model.objects.all(), request=request 643 | )[0].query.where 644 | if filter_node: # check if any filtering is actually applied 645 | if isinstance(prefetch, str): 646 | 
# Replace string prefetch with a filtered one 647 | prefetch = Prefetch(lookup=prefetch, queryset=filtered_queryset) 648 | elif isinstance(prefetch, Prefetch): 649 | # add filters if not already applied (avoids adding the same WHERE clause twice) 650 | if not contains_where_node(existing_node=prefetch_queryset.query.where, new_node=filter_node): 651 | prefetch.queryset = filtered_queryset 652 | else: 653 | raise NotImplementedError( 654 | f"Adding filters to prefetch type {type(prefetch)} has not been implemented" 655 | ) 656 | 657 | return prefetch 658 | 659 | def _add_prefetch(self, prefetches: Dict, prefetch: Union[str, Prefetch], request) -> str: 660 | # add prefetch to prefetches dict (one entry per target attribute) and return the prefetch_attr 661 | if not isinstance(prefetch, (str, Prefetch)): 662 | raise ValueError(f"Adding prefetch of type '{type(prefetch)}' has not been implemented") 663 | if isinstance(prefetch, str) and len(prefetch) == 1: # NOTE(review): presumably guards against a string iterated char by char upstream - confirm 664 | raise ValueError("single letter prefetches are not allowed") 665 | 666 | prefetch = self._add_sideloading_filter(prefetch=prefetch, request=request) 667 | 668 | prefetch_attr = self.get_source_from_prefetch(prefetch) 669 | existing_prefetch = prefetches.get(prefetch_attr) 670 | if not existing_prefetch: # first prefetch targeting this attribute 671 | prefetches[prefetch_attr] = prefetch 672 | elif isinstance(existing_prefetch, str): # existing entry is a plain lookup string without a queryset 673 | if isinstance(prefetch, str): 674 | if prefetch != existing_prefetch: 675 | raise ValueError("Got different string prefetches to the same attribute name") 676 | elif isinstance(prefetch, Prefetch): 677 | if prefetch.queryset.query.where: 678 | raise ValueError( 679 | f"Can't add filtered Prefetch '{prefetch_attr}'. Existing prefetch does not have filters. " 680 | "APIView might have an unfiltered prefetch_related that sideloading is trying to filter."
681 | ) 682 | # Do nothing, as no filters where applied, leave the prefetch as a string 683 | else: 684 | raise NotImplementedError(f"overwriting existing string prefetch wit type {type(prefetch)}") 685 | elif isinstance(existing_prefetch, Prefetch): 686 | if isinstance(prefetch, str): 687 | if existing_prefetch.queryset.query.where: 688 | raise ValueError( 689 | f"Can't add non-filtered prefetch '{prefetch_attr}'. Existing Prefetch has filters applied. " 690 | "Sideloading serializer tries to apply a non-filtered prefetch to a previously filtered " 691 | "prefetch" 692 | ) 693 | # Don't make any changes as the Prefetch does not have filters 694 | elif isinstance(prefetch, Prefetch): 695 | if prefetch.queryset.model != existing_prefetch.queryset.model: 696 | raise ValueError( 697 | f"Can't add filtered Prefetch '{prefetch_attr}'. Existing Prefetch has a different model." 698 | ) 699 | if set(prefetch.queryset.query.where.children) != set(existing_prefetch.queryset.query.where.children): 700 | raise ValueError( 701 | f"Can't add filtered Prefetch '{prefetch_attr}'. " 702 | "Existing Prefetch has different filters applied. " 703 | "Check that sideloading serializer and view prefetch_related values don't clash" 704 | ) 705 | # Don't make any changes as the filters have to match each other 706 | else: 707 | raise NotImplementedError(f"overwriting existing Prefetch with type {type(prefetch)}") 708 | else: 709 | raise NotImplementedError(f"Adding prefetch of type '{type(prefetch)}' has not been implemented") 710 | 711 | return prefetch_attr 712 | 713 | def _get_relevant_prefetches(self, relations_to_sideload: Dict, request, gathered_prefetches: Dict = None) -> Dict: 714 | """ 715 | Collects all relevant prefetches and returns 716 | compressed prefetches and sources per relation to be used later. 
717 | """ 718 | 719 | if gathered_prefetches is None: 720 | gathered_prefetches = {} 721 | 722 | # normalized prefetches per relation (presumably from the sideloading serializer Meta.prefetches - confirm) 723 | cleaned_prefetches = self._gather_all_prefetches() 724 | 725 | if not relations_to_sideload: 726 | raise ValueError("'relations_to_sideload' is a required argument") 727 | if not cleaned_prefetches: 728 | raise ValueError("'cleaned_prefetches' is a required argument") 729 | 730 | for relation, requested_sources in relations_to_sideload.items(): 731 | relation_prefetches = cleaned_prefetches.get(relation) # NOTE(review): None here would raise TypeError below - confirm every relation has prefetches 732 | if requested_sources: # only explicitly requested sources, e.g. "combined_suppliers[backup_suppliers]" 733 | for source in requested_sources: 734 | for source_prefetch in relation_prefetches[source]: 735 | self._add_prefetch(prefetches=gathered_prefetches, prefetch=source_prefetch, request=request) 736 | elif isinstance(relation_prefetches, dict): # no sources requested: load every source of a multi-source relation 737 | for source_prefetches in relation_prefetches.values(): 738 | for source_prefetch in source_prefetches: 739 | self._add_prefetch(prefetches=gathered_prefetches, prefetch=source_prefetch, request=request) 740 | else: # single-source relation: a plain list of prefetches 741 | for relation_prefetch in relation_prefetches: 742 | self._add_prefetch(prefetches=gathered_prefetches, prefetch=relation_prefetch, request=request) 743 | 744 | return gathered_prefetches 745 | -------------------------------------------------------------------------------- /tests/test_products_api.py: -------------------------------------------------------------------------------- 1 | from django.db.models import Prefetch 2 | from django.test import TestCase 3 | from django.urls import reverse 4 | from rest_framework import status, serializers 5 | from rest_framework.permissions import BasePermission 6 | from rest_framework.renderers import BrowsableAPIRenderer, JSONRenderer 7 | from rest_framework.settings import api_settings 8 | 9 | from drf_sideloading.serializers import SideLoadableSerializer 10 | from tests.models import Category, Supplier, Product, Partner, ProductMetadata, SupplierMetadata 11 | from tests.serializers import ( 12 |
ProductSerializer, 13 | CategorySerializer, 14 | SupplierSerializer, 15 | PartnerSerializer, 16 | ProductMetadataSerializer, 17 | ) 18 | from tests.viewsets import ProductViewSet 19 | 20 | 21 | class BaseTestCase(TestCase): 22 | """Shared fixtures: one category, four suppliers with metadata, four partners and four products""" 23 | 24 | DEFAULT_HEADERS = { 25 | # "content_type": "application/json", # defaults to "application/octet-stream" 26 | "HTTP_ACCEPT": "application/json", 27 | } 28 | 29 | @classmethod 30 | def setUpClass(cls): 31 | super(BaseTestCase, cls).setUpClass() 32 | 33 | def setUp(self): 34 | self.category = Category.objects.create(name="Category") 35 | self.supplier1 = Supplier.objects.create(name="Supplier1") 36 | self.supplier_metadata_1 = SupplierMetadata.objects.create( 37 | supplier=self.supplier1, properties="Supplier1 metadata" 38 | ) 39 | self.supplier2 = Supplier.objects.create(name="Supplier2") 40 | self.supplier_metadata_2 = SupplierMetadata.objects.create( 41 | supplier=self.supplier2, properties="Supplier2 metadata" 42 | ) 43 | self.supplier3 = Supplier.objects.create(name="Supplier3") 44 | self.supplier_metadata_3 = SupplierMetadata.objects.create( 45 | supplier=self.supplier3, properties="Supplier3 metadata" 46 | ) 47 | self.supplier4 = Supplier.objects.create(name="Supplier4") 48 | self.supplier_metadata_4 = SupplierMetadata.objects.create( 49 | supplier=self.supplier4, properties="Supplier4 metadata" 50 | ) 51 | self.partner1 = Partner.objects.create(name="Partner1") 52 | self.partner2 = Partner.objects.create(name="Partner2") 53 | self.partner3 = Partner.objects.create(name="Partner3") 54 | self.partner4 = Partner.objects.create(name="Partner4") 55 | 56 | self.product1 = Product.objects.create(name="Product1", category=self.category, supplier=self.supplier1) 57 | self.product1_metadata = ProductMetadata.objects.create(product=self.product1, properties="value 1") 58 | self.product1.partners.add(self.partner1) 59 | self.product1.partners.add(self.partner2) 60 |
self.product1.partners.add(self.partner4) 61 | self.product1.save() 62 | 63 | self.product2 = Product.objects.create(name="Product2", category=self.category, supplier=self.supplier2) 64 | self.product2_metadata = ProductMetadata.objects.create(product=self.product2, properties="value 2") 65 | self.product2.partners.add(self.partner2) 66 | self.product2.save() 67 | 68 | self.product3 = Product.objects.create(name="Product3", category=self.category, supplier=self.supplier3) 69 | self.product3_metadata = ProductMetadata.objects.create(product=self.product3, properties="value 3") 70 | self.product3.partners.add(self.partner3) 71 | self.product3.save() 72 | 73 | self.product4 = Product.objects.create(name="Product4", category=self.category, supplier=self.supplier4) 74 | self.product4_metadata = ProductMetadata.objects.create(product=self.product4, properties="value 4") 75 | 76 | 77 | ################################### 78 | # Different Correct usages of API # 79 | ################################### 80 | class ProductSideloadTestCase(BaseTestCase): 81 | @classmethod 82 | def setUpClass(cls): 83 | super(ProductSideloadTestCase, cls).setUpClass() 84 | 85 | class TempProductSideloadableSerializer(SideLoadableSerializer): 86 | products = ProductSerializer(many=True) 87 | categories = CategorySerializer(source="category", many=True) 88 | suppliers = SupplierSerializer(source="supplier", many=True) 89 | partners = PartnerSerializer(many=True) 90 | partners = PartnerSerializer(many=True) 91 | metadata = ProductMetadataSerializer(many=True) 92 | 93 | class Meta: 94 | primary = "products" 95 | prefetches = { 96 | "categories": "category", 97 | "suppliers": ["supplier", "supplier__metadata"], 98 | "partners": "partners", 99 | "metadata": "metadata", 100 | } 101 | 102 | ProductViewSet.sideloading_serializer_class = TempProductSideloadableSerializer 103 | 104 | def test_list(self): 105 | response = self.client.get(path=reverse("product-list"), **self.DEFAULT_HEADERS) 106 | 
self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 107 | self.assertIsInstance(response.json(), list) 108 | self.assertEqual(4, len(response.json())) 109 | self.assertEqual("Product1", response.json()[0]["name"]) 110 | 111 | def test_list_sideloading(self): 112 | """Test sideloading for all defined relations""" 113 | response = self.client.get( 114 | path=reverse("product-list"), 115 | data={"sideload": "categories,suppliers,partners,metadata"}, 116 | **self.DEFAULT_HEADERS, 117 | ) 118 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 119 | self.assertIsInstance(response.json(), dict) 120 | self.assertListEqual( 121 | ["products", "categories", "suppliers", "partners", "metadata"], list(response.json().keys()) 122 | ) 123 | 124 | def test_list_sideloading_with_direct_missing_one_to_one_relation(self): 125 | """Test sideloading metadata when no product has a ProductMetadata row""" 126 | ProductMetadata.objects.all().delete() 127 | response = self.client.get( 128 | path=reverse("product-list"), 129 | data={"sideload": "metadata"}, 130 | **self.DEFAULT_HEADERS, 131 | ) 132 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 133 | self.assertIsInstance(response.json(), dict) 134 | self.assertListEqual(["products", "metadata"], list(response.json().keys())) 135 | 136 | def test_list_sideloading_with_indirect_missing_one_to_one_relation(self): 137 | """Test sideloading suppliers (which also prefetches supplier__metadata) when no SupplierMetadata exists""" 138 | SupplierMetadata.objects.all().delete() 139 | response = self.client.get( 140 | path=reverse("product-list"), 141 | data={"sideload": "suppliers"}, 142 | **self.DEFAULT_HEADERS, 143 | ) 144 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 145 | self.assertIsInstance(response.json(), dict) 146 | self.assertListEqual(["products", "suppliers"], list(response.json().keys())) 147 | 148 | def test_list_partial_sideloading(self): 149 | """Test sideloading for selected relations""" 150 | response
= self.client.get( 151 | path=reverse("product-list"), data={"sideload": "suppliers,partners"}, **self.DEFAULT_HEADERS 152 | ) 153 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 154 | self.assertIsInstance(response.json(), dict) 155 | self.assertListEqual(["products", "suppliers", "partners"], list(response.json().keys())) 156 | 157 | def test_detail(self): 158 | response = self.client.get(path=reverse("product-detail", args=[self.product1.id]), **self.DEFAULT_HEADERS) 159 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 160 | self.assertIsInstance(response.json(), dict) 161 | self.assertListEqual(["name", "category", "supplier", "partners", "metadata"], list(response.json().keys())) 162 | # TODO: check details 163 | 164 | def test_detail_sideloading(self): 165 | """Test sideloading for all defined relations in detail view""" 166 | response = self.client.get( 167 | path=reverse("product-detail", args=[self.product1.id]), 168 | data={"sideload": "categories,suppliers,partners,metadata"}, 169 | **self.DEFAULT_HEADERS, 170 | ) 171 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 172 | self.assertIsInstance(response.json(), dict) 173 | self.assertListEqual( 174 | ["products", "categories", "suppliers", "partners", "metadata"], list(response.json().keys()) 175 | ) 176 | self.assertEqual(1, len(response.json().get("products"))) 177 | # TODO: check details 178 | 179 | def test_detail_sideloading_with_direct_missing_one_to_one_relation(self): 180 | """Test sideloading metadata in detail view when no product has a ProductMetadata row""" 181 | ProductMetadata.objects.all().delete() 182 | response = self.client.get( 183 | path=reverse("product-detail", args=[self.product1.id]), 184 | data={"sideload": "metadata"}, 185 | **self.DEFAULT_HEADERS, 186 | ) 187 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 188 | self.assertIsInstance(response.json(), dict) 189 | self.assertListEqual(["products",
"metadata"], list(response.json().keys())) 190 | 191 | def test_detail_sideloading_with_indirect_missing_one_to_one_relation(self): 192 | """Test sideloading suppliers in detail view when no SupplierMetadata exists""" 193 | SupplierMetadata.objects.all().delete() 194 | response = self.client.get( 195 | path=reverse("product-detail", args=[self.product1.id]), 196 | data={"sideload": "suppliers"}, 197 | **self.DEFAULT_HEADERS, 198 | ) 199 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 200 | self.assertIsInstance(response.json(), dict) 201 | self.assertListEqual(["products", "suppliers"], list(response.json().keys())) 202 | 203 | def test_detail_partial_sideloading(self): 204 | """Test sideloading for selected relations in detail view""" 205 | response = self.client.get( 206 | path=reverse("product-detail", args=[self.product1.id]), 207 | data={"sideload": "suppliers,partners"}, 208 | **self.DEFAULT_HEADERS, 209 | ) 210 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 211 | self.assertIsInstance(response.json(), dict) 212 | self.assertListEqual(["products", "suppliers", "partners"], list(response.json().keys())) 213 | self.assertEqual(1, len(response.json().get("products"))) 214 | # TODO: check details 215 | 216 | # edge cases and invalid values of the "sideload" query parameter below 217 | def test_sideload_param_empty_string(self): 218 | response = self.client.get(path=reverse("product-list"), data={"sideload": ""}, **self.DEFAULT_HEADERS) 219 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 220 | self.assertIsInstance(response.json(), list) 221 | self.assertEqual(4, len(response.json())) 222 | self.assertEqual("Product1", response.json()[0]["name"]) 223 | 224 | def test_sideload_param_nonexistent_relation(self): 225 | response = self.client.get( 226 | path=reverse("product-list"), data={"sideload": "nonexistent"}, **self.DEFAULT_HEADERS 227 | ) 228 | self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) 229 |
self.assertDictEqual({"sideload": ["'nonexistent' is not one of the available choices."]}, response.json()) 230 | 231 | def test_sideload_param_nonexistent_mixed_existing_relation(self): 232 | response = self.client.get( 233 | path=reverse("product-list"), data={"sideload": "nonexistent,suppliers"}, **self.DEFAULT_HEADERS 234 | ) 235 | self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) 236 | self.assertDictEqual({"sideload": ["'nonexistent' is not one of the available choices."]}, response.json()) 237 | 238 | def test_sideloading_param_wrongly_formed_query(self): 239 | response = self.client.get( 240 | path=reverse("product-list"), 241 | data={"sideload": "@,123,categories,123,.unexisting,123,,,,suppliers,!@"}, 242 | **self.DEFAULT_HEADERS, 243 | ) 244 | self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) 245 | self.assertDictEqual({"sideload": ["'@' is not one of the available choices."]}, response.json()) 246 | 247 | 248 | class ProductMultiSourceSideloadTestCase(BaseTestCase): 249 | """Test sideloading multiple related fields into a single field""" 250 | 251 | @classmethod 252 | def setUpClass(cls): 253 | super(ProductMultiSourceSideloadTestCase, cls).setUpClass() 254 | 255 | class TempProductSideloadableSerializer(SideLoadableSerializer): 256 | products = ProductSerializer(many=True) 257 | categories = CategorySerializer(source="category", many=True) 258 | main_suppliers = SupplierSerializer(source="supplier", many=True) 259 | backup_suppliers = SupplierSerializer(source="backup_supplier", many=True) 260 | partners = PartnerSerializer(many=True) 261 | combined_suppliers = SupplierSerializer(many=True) 262 | 263 | class Meta: 264 | primary = "products" 265 | prefetches = { 266 | "categories": "category", 267 | "main_suppliers": "supplier", 268 | "backup_suppliers": "backup_supplier", 269 | "partners": "partners", 270 | # These can be defined to always load them, else they will be 271 | # copied over from all sources or selected
sources only. 272 | "combined_suppliers": { 273 | "suppliers": {"lookup": "supplier"}, 274 | "backup_suppliers": {"lookup": "backup_supplier"}, 275 | }, 276 | } 277 | 278 | ProductViewSet.sideloading_serializer_class = TempProductSideloadableSerializer 279 | 280 | def setUp(self): 281 | super().setUp() # give products 1-3 a backup supplier that differs from their main supplier 282 | self.product1.backup_supplier = self.supplier4 283 | self.product1.save() 284 | self.product2.backup_supplier = self.supplier3 285 | self.product2.save() 286 | self.product3.backup_supplier = self.supplier2 287 | self.product3.save() 288 | 289 | def test_list_sideloading_all(self): 290 | """Test sideloading for all defined relations""" 291 | response = self.client.get( 292 | path=reverse("product-list"), 293 | data={ 294 | "sideload": "categories,main_suppliers,backup_suppliers,combined_suppliers,partners", 295 | "search": self.product1.name, 296 | }, 297 | **self.DEFAULT_HEADERS, 298 | ) 299 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 300 | self.assertIsInstance(response.json(), dict) 301 | self.assertListEqual( 302 | ["products", "categories", "main_suppliers", "backup_suppliers", "partners", "combined_suppliers"], 303 | list(response.json().keys()), 304 | ) 305 | self.assertEqual(1, len(response.json()["products"])) 306 | self.assertSetEqual( 307 | {self.product1.supplier.name}, 308 | {supplier["name"] for supplier in response.json()["main_suppliers"]}, 309 | ) 310 | self.assertSetEqual( 311 | {self.product1.backup_supplier.name}, 312 | {supplier["name"] for supplier in response.json()["backup_suppliers"]}, 313 | ) 314 | self.assertSetEqual( 315 | {self.product1.supplier.name, self.product1.backup_supplier.name}, 316 | {supplier["name"] for supplier in response.json()["combined_suppliers"]}, 317 | ) 318 | 319 | def test_list_sideloading_backup_suppliers(self): 320 | """Test sideloading only the backup_suppliers source of combined_suppliers""" 321 | response = self.client.get( 322 | path=reverse("product-list"), 323 | data={"sideload":
"combined_suppliers[backup_suppliers]", "search": self.product1.name}, 324 | **self.DEFAULT_HEADERS, 325 | ) 326 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 327 | self.assertIsInstance(response.json(), dict) 328 | self.assertListEqual(["products", "combined_suppliers"], list(response.json().keys())) 329 | self.assertEqual(1, len(response.json()["products"])) 330 | self.assertSetEqual( 331 | {self.product1.backup_supplier.name}, # regular supplier should not end up here. 332 | {supplier["name"] for supplier in response.json()["combined_suppliers"]}, 333 | ) 334 | 335 | def test_list_sideloading_combined_supplier(self): 336 | """Test sideloading all sources of combined_suppliers (main and backup supplier)""" 337 | response = self.client.get( 338 | path=reverse("product-list"), 339 | data={"sideload": "combined_suppliers", "search": self.product1.name}, 340 | **self.DEFAULT_HEADERS, 341 | ) 342 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 343 | self.assertIsInstance(response.json(), dict) 344 | self.assertListEqual(["products", "combined_suppliers"], list(response.json().keys())) 345 | self.assertEqual(1, len(response.json()["products"])) 346 | self.assertSetEqual( 347 | { 348 | self.product1.supplier.name, 349 | self.product1.backup_supplier.name, 350 | }, # both the main and the backup supplier are expected here.
351 | {supplier["name"] for supplier in response.json()["combined_suppliers"]}, 352 | ) 353 | 354 | def test_list_sideloading_combined_supplier_with_filtered_prefetch(self): 355 | """Test sideloading all sources of combined_suppliers. NOTE(review): identical to test_list_sideloading_combined_supplier; no filtered Prefetch is configured in this test case - confirm intent""" 356 | response = self.client.get( 357 | path=reverse("product-list"), 358 | data={"sideload": "combined_suppliers", "search": self.product1.name}, 359 | **self.DEFAULT_HEADERS, 360 | ) 361 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 362 | self.assertIsInstance(response.json(), dict) 363 | self.assertListEqual(["products", "combined_suppliers"], list(response.json().keys())) 364 | self.assertEqual(1, len(response.json()["products"])) 365 | self.assertSetEqual( 366 | { 367 | self.product1.supplier.name, 368 | self.product1.backup_supplier.name, 369 | }, # both the main and the backup supplier are expected here. 370 | {supplier["name"] for supplier in response.json()["combined_suppliers"]}, 371 | ) 372 | 373 | 374 | ################################### 375 | # Different Correct usages of API # 376 | ################################### 377 | class CategorySideloadTestCase(BaseTestCase): 378 | def test_list(self): 379 | response = self.client.get(path=reverse("category-list"), data={}, **self.DEFAULT_HEADERS) 380 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 381 | self.assertIsInstance(response.json(), list) 382 | self.assertEqual(1, len(response.json())) 383 | self.assertEqual("Category", response.json()[0]["name"]) 384 | 385 | def test_list_sideloading_with_reverse_relations_and_its_relations(self): 386 | """Test sideloading the reverse relation (products) together with the products' own relations""" 387 | response = self.client.get( 388 | path=reverse("category-list"), data={"sideload": "products,suppliers,partners"}, **self.DEFAULT_HEADERS 389 | ) 390 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 391 | self.assertIsInstance(response.json(), dict) 392 | self.assertListEqual(["categories", "products",
"suppliers", "partners"], list(response.json().keys())) 393 | 394 | def test_list_sideloading_with_reverse_relations_relations_without_the_reverse_relation_itself(self): 395 | """Test sideloading items related to the categories' products 396 | without sideloading the products list itself""" 397 | response = self.client.get( 398 | path=reverse("category-list"), data={"sideload": "suppliers,partners"}, **self.DEFAULT_HEADERS 399 | ) 400 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 401 | self.assertIsInstance(response.json(), dict) 402 | self.assertListEqual(["categories", "suppliers", "partners"], list(response.json().keys())) 403 | 404 | 405 | ###################################################################################### 406 | # Incorrect definitions sideloadable_relations in ViewSet and SideloadableSerializer # 407 | ###################################################################################### 408 | class TestDrfSideloadingNoMetaClassDefined(BaseTestCase): 409 | """Run tests with invalid sideloadable serializer setup (no Meta class defined)""" 410 | 411 | @classmethod 412 | def setUpClass(cls): 413 | super(TestDrfSideloadingNoMetaClassDefined, cls).setUpClass() 414 | 415 | class TempProductSideloadableSerializer(SideLoadableSerializer): 416 | products = ProductSerializer(many=True) 417 | categories = CategorySerializer(many=True) 418 | suppliers = SupplierSerializer(many=True) 419 | partners = PartnerSerializer(many=True) 420 | 421 | ProductViewSet.sideloading_serializer_class = TempProductSideloadableSerializer 422 | 423 | def test_correct_exception_raised(self): 424 | expected_error_message = "Sideloadable serializer must have a Meta class defined with the 'primary' field name!"
425 | with self.assertRaisesMessage(ValueError, expected_error_message): 426 | self.client.get( 427 | path=reverse("product-list"), 428 | data={"sideload": "categories,suppliers,partners"}, 429 | **self.DEFAULT_HEADERS, 430 | ) 431 | 432 | 433 | class TestDrfSideloadingNoPrimaryDefined(BaseTestCase): 434 | """Run tests with invalid sideloadable serializer setup (Meta.primary not set)""" 435 | 436 | @classmethod 437 | def setUpClass(cls): 438 | super(TestDrfSideloadingNoPrimaryDefined, cls).setUpClass() 439 | 440 | class TempProductSideloadableSerializer(SideLoadableSerializer): 441 | products = ProductSerializer(many=True) 442 | categories = CategorySerializer(many=True) 443 | suppliers = SupplierSerializer(many=True) 444 | partners = PartnerSerializer(many=True) 445 | 446 | class Meta: 447 | pass 448 | 449 | ProductViewSet.sideloading_serializer_class = TempProductSideloadableSerializer 450 | 451 | def test_correct_exception_raised(self): 452 | expected_error_message = "Sideloadable serializer must have a Meta class defined with the 'primary' field name!"
453 | with self.assertRaisesMessage(ValueError, expected_error_message): 454 | self.client.get( 455 | path=reverse("product-list"), 456 | data={"sideload": "categories,suppliers,partners"}, 457 | **self.DEFAULT_HEADERS, 458 | ) 459 | 460 | 461 | class TestDrfSideloadingRelationsNotListSerializers(BaseTestCase): 462 | """Run tests with invalid sideloadable serializer setup (fields not set as many=True)""" 463 | 464 | @classmethod 465 | def setUpClass(cls): 466 | super(TestDrfSideloadingRelationsNotListSerializers, cls).setUpClass() 467 | 468 | class TempProductSideloadableSerializer(SideLoadableSerializer): 469 | products = ProductSerializer(many=True) 470 | categories = CategorySerializer() 471 | suppliers = SupplierSerializer() 472 | partners = PartnerSerializer(many=True) 473 | 474 | class Meta: 475 | primary = "products" 476 | 477 | ProductViewSet.sideloading_serializer_class = TempProductSideloadableSerializer 478 | 479 | def test_correct_exception_raised(self): 480 | expected_error_message = "SideLoadable field 'categories' must be set as many=True" 481 | with self.assertRaisesMessage(ValueError, expected_error_message): 482 | self.client.get( 483 | path=reverse("product-list"), data={"sideload": "categories,suppliers,partners"}, **self.DEFAULT_HEADERS 484 | ) 485 | 486 | 487 | class TestDrfSideloadingInvalidPrimary(BaseTestCase): 488 | """Run tests with invalid sideloadable serializer setup (Meta.primary points to a non-existent field)""" 489 | 490 | @classmethod 491 | def setUpClass(cls): 492 | super(TestDrfSideloadingInvalidPrimary, cls).setUpClass() 493 | 494 | class TempProductSideloadableSerializer(SideLoadableSerializer): 495 | products = ProductSerializer(many=True) 496 | categories = CategorySerializer(many=True) 497 | suppliers = SupplierSerializer(many=True) 498 | partners = PartnerSerializer(many=True) 499 | 500 | class Meta: 501 | primary = "other" 502 | 503 | ProductViewSet.sideloading_serializer_class = TempProductSideloadableSerializer 504 | 505 | def
test_correct_exception_raised(self): 506 | expected_error_message = "Sideloadable serializer Meta.primary must point to a field in the serializer!" 507 | with self.assertRaisesMessage(ValueError, expected_error_message): 508 | self.client.get( 509 | path=reverse("product-list"), 510 | data={"sideload": "categories,suppliers,partners"}, 511 | **self.DEFAULT_HEADERS, 512 | ) 513 | 514 | 515 | class TestDrfSideloadingInvalidPrefetchesType(BaseTestCase): 516 | """Run tests with invalid sideloadable serializer setup (prefetches not described as a dict)""" 517 | 518 | @classmethod 519 | def setUpClass(cls): 520 | super(TestDrfSideloadingInvalidPrefetchesType, cls).setUpClass() 521 | 522 | class TempProductSideloadableSerializer(SideLoadableSerializer): 523 | products = ProductSerializer(many=True) 524 | categories = CategorySerializer(many=True) 525 | suppliers = SupplierSerializer(many=True) 526 | partners = PartnerSerializer(many=True) 527 | 528 | class Meta: 529 | primary = "products" 530 | prefetches = ( 531 | ("categories", "category"), 532 | ("suppliers", "supplier"), 533 | ("partners", "partners"), 534 | ) 535 | 536 | ProductViewSet.sideloading_serializer_class = TempProductSideloadableSerializer 537 | 538 | def test_correct_exception_raised(self): 539 | expected_error_message = "Sideloadable serializer Meta attribute 'prefetches' must be a dict."
540 | with self.assertRaisesMessage(ValueError, expected_error_message): 541 | self.client.get( 542 | path=reverse("product-list"), 543 | data={"sideload": "categories,suppliers,partners"}, 544 | **self.DEFAULT_HEADERS, 545 | ) 546 | 547 | 548 | class TestDrfSideloadingInvalidPrefetchesValuesType(BaseTestCase): 549 | """Run tests with invalid sideloadabale serializer setup (invalid prefetch types)""" 550 | 551 | @classmethod 552 | def setUpClass(cls): 553 | super(TestDrfSideloadingInvalidPrefetchesValuesType, cls).setUpClass() 554 | 555 | class TempProductSideloadableSerializer(SideLoadableSerializer): 556 | products = ProductSerializer(many=True) 557 | categories = CategorySerializer(many=True) 558 | suppliers = SupplierSerializer(many=True) 559 | partners = PartnerSerializer(many=True) 560 | 561 | class Meta: 562 | primary = "products" 563 | prefetches = { 564 | "categories": "category", 565 | "suppliers": ["supplier"], 566 | "partners": 123, 567 | } 568 | 569 | ProductViewSet.sideloading_serializer_class = TempProductSideloadableSerializer 570 | 571 | def test_correct_exception_raised(self): 572 | expected_error_message = "prefetch with type '' is not implemented" 573 | with self.assertRaisesMessage(NotImplementedError, expected_error_message): 574 | self.client.get( 575 | path=reverse("product-list"), 576 | data={"sideload": "categories,suppliers,partners"}, 577 | **self.DEFAULT_HEADERS, 578 | ) 579 | 580 | 581 | class TestDrfSideloadingValidPrefetches(BaseTestCase): 582 | """ 583 | Run tests with prefetch is user defined and another prefetch for the same relation is also created. 584 | Preftch.to_attr is not set. This should be automatically be set by our code. 
585 | """ 586 | 587 | @classmethod 588 | def setUpClass(cls): 589 | super(TestDrfSideloadingValidPrefetches, cls).setUpClass() 590 | 591 | class TempProductSideloadableSerializer(SideLoadableSerializer): 592 | products = ProductSerializer(many=True) 593 | categories = CategorySerializer(source="category", many=True) 594 | suppliers = SupplierSerializer(source="supplier", many=True) 595 | filtered_suppliers = SupplierSerializer(many=True) 596 | filtered_suppliers2 = SupplierSerializer(many=True) 597 | partners = PartnerSerializer(many=True) 598 | filtered_partners = PartnerSerializer(many=True) 599 | combined_suppliers = SupplierSerializer(many=True) 600 | 601 | class Meta: 602 | primary = "products" 603 | prefetches = { 604 | "categories": "category", 605 | "suppliers": ["supplier"], 606 | "partners": ["partners"], 607 | "filtered_suppliers": Prefetch( 608 | lookup="supplier", 609 | queryset=Supplier.objects.filter(name__in=["Supplier2", "Supplier4"]), 610 | to_attr="filtered_suppliers", 611 | ), 612 | "filtered_suppliers2": Prefetch( 613 | lookup="supplier", 614 | queryset=Supplier.objects.filter(name__in=["Supplier3"]), 615 | to_attr="filtered_suppliers2", 616 | ), 617 | "filtered_partners": Prefetch( 618 | lookup="partners", 619 | queryset=Partner.objects.filter(name__in=["Partner2", "Partner4"]), 620 | to_attr="filtered_partners", 621 | ), 622 | "combined_suppliers": { 623 | "filtered_suppliers": Prefetch( 624 | lookup="supplier", 625 | queryset=Supplier.objects.filter(name__in=["Supplier2", "Supplier4"]), 626 | to_attr="filtered_suppliers", 627 | ), 628 | "filtered_suppliers2": Prefetch( 629 | lookup="supplier", 630 | queryset=Supplier.objects.filter(name__in=["Supplier3"]), 631 | to_attr="filtered_suppliers2", 632 | ), 633 | "filtered_partners": Prefetch( # NOTE(review): a Partner prefetch inside combined_suppliers (serialized with SupplierSerializer) looks unintended - confirm 634 | lookup="partners", 635 | queryset=Partner.objects.filter(name__in=["Partner2", "Partner4"]), 636 | to_attr="filtered_partners", 637 | ), 638 | }, 639 | } 640 | 641 |
ProductViewSet.sideloading_serializer_class = TempProductSideloadableSerializer 642 | 643 | def test_sideloading_with_dual_usage_prefetches(self): 644 | response_2 = self.client.get( 645 | path=reverse("product-list"), 646 | data={"sideload": "categories,suppliers,partners"}, 647 | **self.DEFAULT_HEADERS, 648 | ) 649 | self.assertEqual(response_2.status_code, status.HTTP_200_OK, response_2.data) 650 | self.assertIsInstance(response_2.json(), dict) 651 | self.assertListEqual(["products", "categories", "suppliers", "partners"], list(response_2.json().keys())) 652 | # check suppliers and partners are the same as from previous query! 653 | supplier_names = {supplier["name"] for supplier in response_2.json()["suppliers"]} 654 | self.assertSetEqual({"Supplier1", "Supplier2", "Supplier3", "Supplier4"}, supplier_names) 655 | partner_names = {partner["name"] for partner in response_2.json()["partners"]} 656 | self.assertSetEqual({"Partner1", "Partner2", "Partner3", "Partner4"}, partner_names) 657 | 658 | def test_sideloading_normally(self): 659 | response_1 = self.client.get( 660 | path=reverse("product-list"), 661 | data={"sideload": "categories,suppliers,filtered_suppliers,partners,filtered_partners"}, 662 | **self.DEFAULT_HEADERS, 663 | ) 664 | self.assertEqual(response_1.status_code, status.HTTP_200_OK, response_1.json()) 665 | self.assertIsInstance(response_1.json(), dict) 666 | self.assertListEqual( 667 | ["products", "categories", "suppliers", "filtered_suppliers", "partners", "filtered_partners"], 668 | list(response_1.json().keys()), 669 | ) 670 | # check filtered_suppliers and filtered_partners are different from suppliers and partners! 
671 | supplier_names = {partner["name"] for partner in response_1.json()["suppliers"]} 672 | self.assertSetEqual({"Supplier1", "Supplier2", "Supplier3", "Supplier4"}, supplier_names) 673 | 674 | filtered_supplier_names = {partner["name"] for partner in response_1.json()["filtered_suppliers"]} 675 | self.assertSetEqual({"Supplier2", "Supplier4"}, filtered_supplier_names) 676 | 677 | partner_names = {partner["name"] for partner in response_1.json()["partners"]} 678 | self.assertSetEqual({"Partner1", "Partner2", "Partner3", "Partner4"}, partner_names) 679 | 680 | filtered_partner_names = {partner["name"] for partner in response_1.json()["filtered_partners"]} 681 | self.assertSetEqual({"Partner2", "Partner4"}, filtered_partner_names) 682 | 683 | def test_sideloading_with_filtered_prefetch(self): 684 | response_1 = self.client.get( 685 | path=reverse("product-list"), 686 | data={"sideload": "filtered_suppliers,filtered_partners"}, 687 | **self.DEFAULT_HEADERS, 688 | ) 689 | self.assertEqual(response_1.status_code, status.HTTP_200_OK, response_1.json()) 690 | self.assertIsInstance(response_1.json(), dict) 691 | self.assertListEqual( 692 | ["products", "filtered_suppliers", "filtered_partners"], 693 | list(response_1.json().keys()), 694 | ) 695 | filtered_supplier_names = {partner["name"] for partner in response_1.json()["filtered_suppliers"]} 696 | self.assertSetEqual({"Supplier2", "Supplier4"}, filtered_supplier_names) 697 | 698 | filtered_partner_names = {partner["name"] for partner in response_1.json()["filtered_partners"]} 699 | self.assertSetEqual({"Partner2", "Partner4"}, filtered_partner_names) 700 | 701 | def test_sideloading_with_filtered_prefetches(self): 702 | response_1 = self.client.get( 703 | path=reverse("product-list"), 704 | data={"sideload": "categories,filtered_suppliers,filtered_suppliers2,filtered_partners"}, 705 | **self.DEFAULT_HEADERS, 706 | ) 707 | self.assertEqual(response_1.status_code, status.HTTP_200_OK, response_1.json()) 708 | 
self.assertIsInstance(response_1.json(), dict) 709 | self.assertListEqual( 710 | ["products", "categories", "filtered_suppliers", "filtered_suppliers2", "filtered_partners"], 711 | list(response_1.json().keys()), 712 | ) 713 | # check filtered_suppliers and filtered_partners are different from suppliers and partners! 714 | 715 | filtered_supplier_names = {supplier["name"] for supplier in response_1.json()["filtered_suppliers"]} 716 | self.assertSetEqual({"Supplier2", "Supplier4"}, filtered_supplier_names) 717 | 718 | filtered_supplier_names = {supplier["name"] for supplier in response_1.json()["filtered_suppliers2"]} 719 | self.assertSetEqual({"Supplier3"}, filtered_supplier_names) 720 | 721 | filtered_partner_names = {partner["name"] for partner in response_1.json()["filtered_partners"]} 722 | self.assertSetEqual({"Partner2", "Partner4"}, filtered_partner_names) 723 | 724 | def test_sideloading_combined_suppliers_with_filtered_prefetches(self): 725 | response_1 = self.client.get( 726 | path=reverse("product-list"), 727 | data={"sideload": "categories,filtered_suppliers,filtered_partners"}, 728 | **self.DEFAULT_HEADERS, 729 | ) 730 | self.assertEqual(response_1.status_code, status.HTTP_200_OK, response_1.json()) 731 | self.assertIsInstance(response_1.json(), dict) 732 | self.assertListEqual( 733 | ["products", "categories", "filtered_suppliers", "filtered_partners"], 734 | list(response_1.json().keys()), 735 | ) 736 | # check filtered_suppliers and filtered_partners are different from suppliers and partners! 
737 | 738 | filtered_supplier_names = {partner["name"] for partner in response_1.json()["filtered_suppliers"]} 739 | self.assertSetEqual({"Supplier2", "Supplier4"}, filtered_supplier_names) 740 | 741 | filtered_partner_names = {partner["name"] for partner in response_1.json()["filtered_partners"]} 742 | self.assertSetEqual({"Partner2", "Partner4"}, filtered_partner_names) 743 | 744 | 745 | class TestDrfSideloadingInvalidPrefetchSource(BaseTestCase): 746 | """ 747 | Run tests with prefetch is user defined and another prefetch for the same relation is also created. 748 | Preftch.to_attr is not set. This should be automatically be set by our code. 749 | """ 750 | 751 | @classmethod 752 | def setUpClass(cls): 753 | super(TestDrfSideloadingInvalidPrefetchSource, cls).setUpClass() 754 | 755 | class TempProductSideloadableSerializer(SideLoadableSerializer): 756 | products = ProductSerializer(many=True) 757 | categories = CategorySerializer(source="category", many=True) 758 | suppliers = SupplierSerializer(source="supplier", many=True) 759 | filtered_suppliers = SupplierSerializer(source="supplier", many=True, read_only=True) 760 | filtered_suppliers2 = SupplierSerializer(source="supplier", many=True, read_only=True) 761 | partners = PartnerSerializer(many=True) 762 | filtered_partners = PartnerSerializer(source="partners", many=True, read_only=True) 763 | combined_suppliers = SupplierSerializer(many=True) 764 | 765 | class Meta: 766 | primary = "products" 767 | prefetches = { 768 | "categories": "category", 769 | "suppliers": ["supplier"], 770 | "filtered_suppliers": Prefetch( 771 | lookup="supplier", 772 | queryset=Supplier.objects.filter(name__in=["Supplier2", "Supplier4"]), 773 | to_attr="filtered_suppliers", 774 | ), 775 | "filtered_suppliers2": Prefetch( 776 | lookup="supplier", 777 | queryset=Supplier.objects.filter(name__in=["Supplier3"]), 778 | to_attr="filtered_suppliers2", 779 | ), 780 | "filtered_partners": Prefetch( 781 | lookup="partners", 782 | 
queryset=Partner.objects.filter(name__in=["Partner2", "Partner4"]), 783 | to_attr="filtered_partners", 784 | ), 785 | "combined_suppliers": { 786 | "filtered_suppliers": Prefetch( 787 | lookup="supplier", 788 | queryset=Supplier.objects.filter(name__in=["Supplier2", "Supplier4"]), 789 | to_attr="filtered_suppliers", 790 | ), 791 | "filtered_suppliers2": Prefetch( 792 | lookup="supplier", 793 | queryset=Supplier.objects.filter(name__in=["Supplier3"]), 794 | to_attr="filtered_suppliers2", 795 | ), 796 | "filtered_partners": Prefetch( 797 | lookup="partners", 798 | queryset=Partner.objects.filter(name__in=["Partner2", "Partner4"]), 799 | to_attr="filtered_partners", 800 | ), 801 | }, 802 | } 803 | 804 | ProductViewSet.sideloading_serializer_class = TempProductSideloadableSerializer 805 | 806 | def test_sideloading_combined_suppliers_with_mismatching_to_attr_and_source(self): 807 | msg = ( 808 | "ideloadable field 'filtered_suppliers' Prefetch 'to_attr' can't be different from source defined. " 809 | "Tip: Remove source from field serializer." 810 | ) 811 | with self.assertRaisesMessage(ValueError, msg): 812 | self.client.get( 813 | path=reverse("product-list"), 814 | data={"sideload": "filtered_suppliers"}, 815 | **self.DEFAULT_HEADERS, 816 | ) 817 | 818 | 819 | class TestDrfSideloadingValidPrefetchObjectsImplicit(BaseTestCase): 820 | """ 821 | Run tests with prefetch is user defined and another prefetch for the same relation is also created. 822 | Preftch.to_attr is not set. add field name as default. 
823 | """ 824 | 825 | @classmethod 826 | def setUpClass(cls): 827 | super(TestDrfSideloadingValidPrefetchObjectsImplicit, cls).setUpClass() 828 | 829 | class TempProductSideloadableSerializer(SideLoadableSerializer): 830 | products = ProductSerializer(many=True) 831 | categories = CategorySerializer(source="category", many=True) 832 | suppliers = SupplierSerializer(source="supplier", many=True) 833 | filtered_suppliers = SupplierSerializer(source="supplier", many=True) 834 | partners = PartnerSerializer(many=True) 835 | 836 | class Meta: 837 | primary = "products" 838 | prefetches = { 839 | "categories": "category", 840 | "suppliers": ["supplier"], 841 | # "partners": ["products__partners"], 842 | "filtered_suppliers": Prefetch( 843 | lookup="supplier", 844 | queryset=Supplier.objects.filter(name__in=["Supplier2", "Supplier4"]), 845 | # to_attr="filtered_suppliers", 846 | ), 847 | } 848 | 849 | ProductViewSet.sideloading_serializer_class = TempProductSideloadableSerializer 850 | 851 | def test_sideloading_with_prefetch_object_without_to_attr(self): 852 | msg = ( 853 | "Can't add filtered Prefetch 'supplier'. Existing prefetch does not have filters. " 854 | "APIView might have an unfiltered prefetch_related that sideloading is trying to filter." 
855 | ) 856 | with self.assertRaisesMessage(ValueError, msg): 857 | self.client.get( 858 | path=reverse("product-list"), 859 | data={"sideload": "categories,suppliers,filtered_suppliers,partners"}, 860 | **self.DEFAULT_HEADERS, 861 | ) 862 | 863 | 864 | class TestDrfSideloadingPrefetchObjectsMatchingLookup(BaseTestCase): 865 | """ 866 | Use Prefetch object on a field where lookup matches field name 867 | """ 868 | 869 | @classmethod 870 | def setUpClass(cls): 871 | super(TestDrfSideloadingPrefetchObjectsMatchingLookup, cls).setUpClass() 872 | 873 | class TempProductSideloadableSerializer(SideLoadableSerializer): 874 | products = ProductSerializer(many=True) 875 | categories = CategorySerializer(source="category", many=True) 876 | suppliers = SupplierSerializer(source="supplier", many=True) 877 | partners = PartnerSerializer(many=True) 878 | 879 | class Meta: 880 | primary = "products" 881 | prefetches = { 882 | "categories": "category", 883 | "suppliers": ["supplier"], 884 | "partners": Prefetch( 885 | lookup="partners", 886 | queryset=Partner.objects.filter(name__in=["Partner2", "Partner4"]), 887 | # to_attr="partners", # we are testing a case where this is not set. 888 | ), 889 | } 890 | 891 | ProductViewSet.sideloading_serializer_class = TempProductSideloadableSerializer 892 | 893 | def test_sideloading_with_prefetch_object_without_to_attr_but_lookup_matching_field(self): 894 | response_1 = self.client.get( 895 | path=reverse("product-list"), 896 | data={"sideload": "categories,partners"}, 897 | **self.DEFAULT_HEADERS, 898 | ) 899 | self.assertEqual(response_1.status_code, status.HTTP_200_OK, response_1.data) 900 | self.assertIsInstance(response_1.json(), dict) 901 | self.assertListEqual( 902 | ["products", "categories", "partners"], 903 | list(response_1.json().keys()), 904 | ) 905 | # check filtered_suppliers and filtered_partners are different from suppliers and partners! 
906 | partner_names = {partner["name"] for partner in response_1.json()["partners"]} 907 | self.assertSetEqual({"Partner2", "Partner4"}, partner_names) 908 | 909 | 910 | class TestDrfSideloadingInvalidPrefetchObject(BaseTestCase): 911 | """ 912 | Run tests with prefetch is user defined and another prefetch for the same relation is also created. 913 | Preftch.to_attr is invlalid 914 | """ 915 | 916 | @classmethod 917 | def setUpClass(cls): 918 | super(TestDrfSideloadingInvalidPrefetchObject, cls).setUpClass() 919 | 920 | class TempProductSideloadableSerializer(SideLoadableSerializer): 921 | products = ProductSerializer(many=True) 922 | categories = CategorySerializer(source="category", many=True) 923 | suppliers = SupplierSerializer(source="supplier", many=True) 924 | filtered_suppliers = SupplierSerializer(source="supplier", many=True) 925 | partners = PartnerSerializer(many=True) 926 | filtered_partners = PartnerSerializer(source="partners", many=True) 927 | 928 | class Meta: 929 | primary = "products" 930 | prefetches = { 931 | "categories": "category", 932 | "suppliers": ["supplier"], 933 | "filtered_suppliers": Prefetch( 934 | lookup="supplier", 935 | queryset=Supplier.objects.filter(name__in=["Supplier2", "Supplier4"]), 936 | to_attr="wrong_field", 937 | ), 938 | # "partners": None, 939 | "filtered_partners": Prefetch( 940 | lookup="partners", 941 | queryset=Partner.objects.filter(name__in=["Partner2", "Partner4"]), 942 | to_attr="wrong_field_2", 943 | ), 944 | } 945 | 946 | ProductViewSet.sideloading_serializer_class = TempProductSideloadableSerializer 947 | 948 | def test_sideloading_with_prefetches(self): 949 | # cases where field source is present but Prefetch.to_attr mismatches 950 | msg = ( 951 | "Sideloadable field 'filtered_suppliers' Prefetch 'to_attr' can't be different from source defined. " 952 | "Tip: Remove source from field serializer." 
953 | ) 954 | with self.assertRaisesMessage(ValueError, msg): 955 | self.client.get( 956 | path=reverse("product-list"), 957 | data={"sideload": "categories,suppliers,filtered_suppliers,partners"}, 958 | **self.DEFAULT_HEADERS, 959 | ) 960 | 961 | 962 | class TestDrfSideloadingWithoutListModelMixin(BaseTestCase): 963 | def test_list(self): 964 | response = self.client.get( 965 | path=reverse("productretreiveonly-list"), 966 | data={"sideload": "categories"}, 967 | **self.DEFAULT_HEADERS, 968 | ) 969 | self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) 970 | 971 | def test_detail(self): 972 | response = self.client.get( 973 | path=reverse("productretreiveonly-detail", args=[self.product1.id]), 974 | data={"sideload": "categories"}, 975 | **self.DEFAULT_HEADERS, 976 | ) 977 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 978 | 979 | 980 | class TestDrfSideloadingWithoutRetreiveModelMixin(BaseTestCase): 981 | def test_list(self): 982 | response = self.client.get( 983 | path=reverse("productlistonly-list"), 984 | data={"sideload": "categories"}, 985 | **self.DEFAULT_HEADERS, 986 | ) 987 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 988 | 989 | def test_detail(self): 990 | response = self.client.get( 991 | path=reverse("productlistonly-detail", args=[self.product1.id]), 992 | data={"sideload": "categories"}, 993 | **self.DEFAULT_HEADERS, 994 | ) 995 | self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) 996 | 997 | 998 | class TestDrfSideloadingListModelMixinAfterSideloading(BaseTestCase): 999 | # in case the DRF views are added after sideloading, 1000 | # the list and retreive methods will be overwritten and no sideloading happens. 
1001 | 1002 | def test_list(self): 1003 | response = self.client.get( 1004 | path=reverse("productwrongmixinorder-list"), 1005 | data={"sideload": "categories,suppliers,filtered_suppliers,partners"}, 1006 | **self.DEFAULT_HEADERS, 1007 | ) 1008 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 1009 | self.assertIsInstance(response.json(), list) 1010 | 1011 | def test_detail(self): 1012 | response = self.client.get( 1013 | path=reverse("productwrongmixinorder-detail", args=[self.product1.id]), 1014 | data={"sideload": "categories,suppliers,filtered_suppliers,partners"}, 1015 | **self.DEFAULT_HEADERS, 1016 | ) 1017 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 1018 | self.assertIsInstance(response.json(), dict) 1019 | self.assertListEqual(["name", "category", "supplier", "partners", "metadata"], list(response.json().keys())) 1020 | 1021 | 1022 | class TestDrfSideloadingBrowsableApiPermissions(BaseTestCase): 1023 | """Run tests while including mixin but not defining sideloading""" 1024 | 1025 | @classmethod 1026 | def setUpClass(cls): 1027 | super(TestDrfSideloadingBrowsableApiPermissions, cls).setUpClass() 1028 | 1029 | class ProductPermission(BasePermission): 1030 | def has_permission(self, request, view): 1031 | """ 1032 | Return `True` if permission is granted, `False` otherwise. 
1033 | """ 1034 | return True 1035 | 1036 | def has_object_permission(self, request, view, obj): 1037 | raise ValueError("This must not be called, when sideloadading is used!") 1038 | 1039 | class TempProductSideloadableSerializer(SideLoadableSerializer): 1040 | products = ProductSerializer(many=True) 1041 | categories = CategorySerializer(source="category", many=True) 1042 | suppliers = SupplierSerializer(source="supplier", many=True) 1043 | partners = PartnerSerializer(many=True) 1044 | 1045 | class Meta: 1046 | primary = "products" 1047 | prefetches = { 1048 | "categories": "category", 1049 | "suppliers": ["supplier"], 1050 | # "partners": None, 1051 | } 1052 | 1053 | ProductViewSet.renderer_classes = (BrowsableAPIRenderer, JSONRenderer) 1054 | ProductViewSet.permission_classes = (ProductPermission,) 1055 | ProductViewSet.sideloading_serializer_class = TempProductSideloadableSerializer 1056 | 1057 | @classmethod 1058 | def tearDownClass(cls): 1059 | ProductViewSet.renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES 1060 | ProductViewSet.permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES 1061 | super(TestDrfSideloadingBrowsableApiPermissions, cls).tearDownClass() 1062 | 1063 | def test_sideloading_does_not_render_forms_and_check_object_permissions(self): 1064 | response = self.client.get( 1065 | path=reverse("product-list"), data={"sideload": "categories,suppliers,partners"}, **self.DEFAULT_HEADERS 1066 | ) 1067 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 1068 | self.assertIsInstance(response.json(), dict) 1069 | self.assertListEqual(["products", "categories", "suppliers", "partners"], list(response.json().keys())) 1070 | 1071 | def test_sideloading_allow_post_without_sideloading(self): 1072 | category = Category.objects.create(name="Category") 1073 | supplier = Supplier.objects.create(name="Supplier") 1074 | 1075 | headers = {"HTTP_ACCEPT": "application/json"} 1076 | response = self.client.post( 1077 | 
path=reverse("product-list"), 1078 | data={ 1079 | "name": "Random product", 1080 | "category": category.id, 1081 | "supplier": supplier.id, 1082 | "partners": [], 1083 | }, 1084 | **headers, 1085 | ) 1086 | self.assertEqual(response.status_code, status.HTTP_201_CREATED) 1087 | self.assertTrue(isinstance(response.json(), dict)) 1088 | self.assertListEqual(["name", "category", "supplier", "partners", "metadata"], list(response.json().keys())) 1089 | 1090 | def test_sideloading_allow_post_with_sideloading(self): 1091 | # TODO: check response with new detail view sideloading logic! 1092 | category = Category.objects.create(name="Category") 1093 | supplier = Supplier.objects.create(name="Supplier") 1094 | 1095 | headers = {"HTTP_ACCEPT": "application/json"} 1096 | response = self.client.post( 1097 | path="{}{}".format(reverse("product-list"), "?sideload=categories,suppliers,partners"), 1098 | data={ 1099 | "name": "Random product", 1100 | "category": category.id, 1101 | "supplier": supplier.id, 1102 | "partners": [], 1103 | }, 1104 | **headers, 1105 | ) 1106 | self.assertEqual(response.status_code, status.HTTP_201_CREATED) 1107 | self.assertTrue(isinstance(response.json(), dict)) 1108 | self.assertListEqual(["name", "category", "supplier", "partners", "metadata"], list(response.json().keys())) 1109 | 1110 | 1111 | class ProductSideloadSameSourceDuplicationTestCase(BaseTestCase): 1112 | @classmethod 1113 | def setUpClass(cls): 1114 | super(ProductSideloadSameSourceDuplicationTestCase, cls).setUpClass() 1115 | 1116 | class OldCategorySerializer(serializers.ModelSerializer): 1117 | old_name = serializers.CharField(source="name") 1118 | 1119 | class Meta: 1120 | model = Category 1121 | fields = ["old_name"] 1122 | 1123 | class TempProductSideloadableSerializer(SideLoadableSerializer): 1124 | products = ProductSerializer(many=True) 1125 | categories = CategorySerializer(source="category", many=True) 1126 | old_categories = OldCategorySerializer(source="category", many=True) 
1127 | suppliers = SupplierSerializer(source="supplier", many=True) 1128 | partners = PartnerSerializer(many=True) 1129 | 1130 | class Meta: 1131 | primary = "products" 1132 | prefetches = {"category": "category", "old_categories": "category"} 1133 | 1134 | ProductViewSet.sideloading_serializer_class = TempProductSideloadableSerializer 1135 | 1136 | def test_list_sideload_categories(self): 1137 | response = self.client.get( 1138 | path=reverse("product-list"), data={"sideload": "categories"}, **self.DEFAULT_HEADERS 1139 | ) 1140 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 1141 | self.assertIsInstance(response.json(), dict) 1142 | self.assertListEqual(["products", "category"], list(response.data.serializer.instance.keys())) 1143 | self.assertListEqual(["products", "categories"], list(response.json().keys())) 1144 | 1145 | def test_list_sideload_old_categories(self): 1146 | response = self.client.get( 1147 | path=reverse("product-list"), data={"sideload": "old_categories"}, **self.DEFAULT_HEADERS 1148 | ) 1149 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 1150 | self.assertIsInstance(response.json(), dict) 1151 | self.assertListEqual(["products", "category"], list(response.data.serializer.instance.keys())) 1152 | self.assertListEqual(["products", "old_categories"], list(response.json().keys())) 1153 | 1154 | def test_list_sideload_new_categories_and_old_categories(self): 1155 | response = self.client.get( 1156 | path=reverse("product-list"), data={"sideload": "categories,old_categories"}, **self.DEFAULT_HEADERS 1157 | ) 1158 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 1159 | self.assertIsInstance(response.json(), dict) 1160 | self.assertListEqual(["products", "category"], list(response.data.serializer.instance.keys())) 1161 | self.assertListEqual(["products", "categories", "old_categories"], list(response.json().keys())) 1162 | 1163 | 1164 | class 
VersionesSideloadableSerializerTestCase(BaseTestCase): 1165 | def test_list_sideload_categories(self): 1166 | # old as example 1167 | response = self.client.get( 1168 | path=reverse("product-list"), 1169 | data={"sideload": "categories"}, 1170 | # headers: 1171 | HTTP_ACCEPT="application/json; version=1.0", 1172 | ) 1173 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 1174 | self.assertIsInstance(response.json(), dict) 1175 | self.assertListEqual(["products", "categories"], list(response.json().keys())) 1176 | 1177 | # new version can't sideload this value 1178 | response = self.client.get( 1179 | path=reverse("product-list"), 1180 | data={"sideload": "categories"}, 1181 | # headers: 1182 | HTTP_ACCEPT="application/json; version=2.0.0", 1183 | ) 1184 | self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) 1185 | self.assertDictEqual({"sideload": ["'categories' is not one of the available choices."]}, response.json()) 1186 | 1187 | # new version called correctly 1188 | response = self.client.get( 1189 | path=reverse("product-list"), 1190 | data={"sideload": "new_categories"}, 1191 | # headers: 1192 | HTTP_ACCEPT="application/json; version=2.0.0", 1193 | ) 1194 | self.assertEqual(response.status_code, status.HTTP_200_OK, response.json()) 1195 | self.assertIsInstance(response.json(), dict) 1196 | self.assertListEqual(["products", "new_categories"], list(response.json().keys())) 1197 | --------------------------------------------------------------------------------