├── .coveragerc ├── docs ├── history.rst ├── reqs.txt ├── api.rst ├── api │ ├── modules.rst │ ├── rest_witchcraft.utils.rst │ ├── rest_witchcraft.fields.rst │ ├── rest_witchcraft.filters.rst │ ├── rest_witchcraft.mixins.rst │ ├── rest_witchcraft.routers.rst │ ├── rest_witchcraft.generics.rst │ ├── rest_witchcraft.viewsets.rst │ ├── rest_witchcraft.serializers.rst │ ├── rest_witchcraft.field_mapping.rst │ └── rest_witchcraft.rst ├── index.rst ├── Makefile └── conf.py ├── .sourcery.yaml ├── rest_witchcraft ├── __init__.py ├── __version__.py ├── utils.py ├── viewsets.py ├── routers.py ├── generics.py ├── fields.py ├── field_mapping.py ├── filters.py ├── mixins.py └── serializers.py ├── .dockerignore ├── tests ├── __init__.py ├── urls │ ├── testhyperlinkedidentityfield.py │ └── testhyperlinkedidentityfieldcomposite.py ├── test_utils.py ├── models_composite.py ├── settings.py ├── test_generics.py ├── test_fields.py ├── test_field_mapping.py ├── models.py ├── test_mixins.py ├── test_filters.py ├── test_routers.py └── test_serializers.py ├── .editorconfig ├── requirements.txt ├── pyproject.toml ├── .importanizerc ├── MANIFEST.in ├── setup.cfg ├── docker-compose.yml ├── .gitignore ├── LICENSE ├── Dockerfile ├── .gitchangelog.rc ├── .pre-commit-config.yaml ├── setup.py ├── .github └── workflows │ └── build.yml ├── Makefile ├── tox.ini ├── README.rst └── HISTORY.rst /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | relative_files = True 3 | -------------------------------------------------------------------------------- /docs/history.rst: -------------------------------------------------------------------------------- 1 | .. 
include:: ../HISTORY.rst 2 | -------------------------------------------------------------------------------- /.sourcery.yaml: -------------------------------------------------------------------------------- 1 | refactor: 2 | python_version: '3.6' 3 | -------------------------------------------------------------------------------- /docs/reqs.txt: -------------------------------------------------------------------------------- 1 | Django 2 | SQLAlchemy 3 | djangorestframework 4 | six 5 | -------------------------------------------------------------------------------- /rest_witchcraft/__init__.py: -------------------------------------------------------------------------------- 1 | from .__version__ import __version__ # noqa 2 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | API Documentation 2 | ================= 3 | 4 | .. toctree:: 5 | :maxdepth: 10 6 | 7 | api/modules 8 | -------------------------------------------------------------------------------- /docs/api/modules.rst: -------------------------------------------------------------------------------- 1 | rest_witchcraft 2 | =============== 3 | 4 | .. 
toctree:: 5 | :maxdepth: 4 6 | 7 | rest_witchcraft 8 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .env 2 | .git 3 | .idea 4 | **/*.pyc 5 | **/*.pyo 6 | **/__pycache__ 7 | **/*.db 8 | **/*.egg-info 9 | **/.tox 10 | **/htmlcov 11 | .pytest_cache 12 | .mypy_cache 13 | .tox 14 | -------------------------------------------------------------------------------- /rest_witchcraft/__version__.py: -------------------------------------------------------------------------------- 1 | __author__ = "Serkan Hosca" 2 | __author_email__ = "serkan@hosca.com" 3 | __version__ = "0.12.1" 4 | __description__ = "Django REST Framework and SQLAlchemy integration" 5 | -------------------------------------------------------------------------------- /docs/api/rest_witchcraft.utils.rst: -------------------------------------------------------------------------------- 1 | rest\_witchcraft.utils module 2 | ============================= 3 | 4 | .. automodule:: rest_witchcraft.utils 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/api/rest_witchcraft.fields.rst: -------------------------------------------------------------------------------- 1 | rest\_witchcraft.fields module 2 | ============================== 3 | 4 | .. automodule:: rest_witchcraft.fields 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/api/rest_witchcraft.filters.rst: -------------------------------------------------------------------------------- 1 | rest\_witchcraft.filters module 2 | =============================== 3 | 4 | .. 
automodule:: rest_witchcraft.filters 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/api/rest_witchcraft.mixins.rst: -------------------------------------------------------------------------------- 1 | rest\_witchcraft.mixins module 2 | ============================== 3 | 4 | .. automodule:: rest_witchcraft.mixins 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/api/rest_witchcraft.routers.rst: -------------------------------------------------------------------------------- 1 | rest\_witchcraft.routers module 2 | =============================== 3 | 4 | .. automodule:: rest_witchcraft.routers 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/api/rest_witchcraft.generics.rst: -------------------------------------------------------------------------------- 1 | rest\_witchcraft.generics module 2 | ================================ 3 | 4 | .. automodule:: rest_witchcraft.generics 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/api/rest_witchcraft.viewsets.rst: -------------------------------------------------------------------------------- 1 | rest\_witchcraft.viewsets module 2 | ================================ 3 | 4 | .. automodule:: rest_witchcraft.viewsets 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/api/rest_witchcraft.serializers.rst: -------------------------------------------------------------------------------- 1 | rest\_witchcraft.serializers module 2 | =================================== 3 | 4 | .. 
automodule:: rest_witchcraft.serializers 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/api/rest_witchcraft.field_mapping.rst: -------------------------------------------------------------------------------- 1 | rest\_witchcraft.field\_mapping module 2 | ====================================== 3 | 4 | .. automodule:: rest_witchcraft.field_mapping 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from psycopg2cffi import compat 4 | 5 | import django 6 | import django.test.utils 7 | 8 | 9 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "tests.settings") 10 | 11 | compat.register() 12 | 13 | django.setup() 14 | -------------------------------------------------------------------------------- /tests/urls/testhyperlinkedidentityfield.py: -------------------------------------------------------------------------------- 1 | try: 2 | from django.conf.urls import url as re_path 3 | except ImportError: # pragma: no cover 4 | from django.urls import re_path 5 | 6 | 7 | urlpatterns = [re_path(r"^example/(?P.+)/$", lambda: None, name="owner")] 8 | -------------------------------------------------------------------------------- /tests/urls/testhyperlinkedidentityfieldcomposite.py: -------------------------------------------------------------------------------- 1 | try: 2 | from django.conf.urls import url as re_path 3 | except ImportError: # pragma: no cover 4 | from django.urls import re_path 5 | 6 | 7 | urlpatterns = [re_path(r"^example/(?P.+)/(?P.+)/$", lambda: None, name="owner")] 8 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # 
http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [LICENSE] 14 | insert_final_newline = false 15 | 16 | [Makefile] 17 | indent_style = tab 18 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -e . 2 | Sphinx 3 | coreapi 4 | coreschema 5 | coveralls 6 | flake8 7 | flake8-bugbear 8 | flake8-comprehensions 9 | flake8-django 10 | gitchangelog 11 | pdbpp 12 | pre_commit 13 | psycopg2cffi 14 | pytest 15 | pytest-cov 16 | simplejson 17 | sphinx-autobuild 18 | sphinx_rtd_theme 19 | sqlalchemy_utils 20 | tox==3.28.0 21 | tox-factor 22 | tox-pyenv 23 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | target-version = ['py36', 'py37', 'py38'] 4 | include = '\.pyi?$' 5 | exclude = ''' 6 | /( 7 | \.eggs 8 | | \.git 9 | | \.hg 10 | | \.mypy_cache 11 | | \.tox 12 | | \.venv 13 | | _build 14 | | build 15 | | dist 16 | | env 17 | | htmlcov 18 | | node_modules 19 | | public 20 | | venv 21 | )/ 22 | ''' 23 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. Django REST Witchcraft documentation master file, created by 2 | sphinx-quickstart on Sat Jun 10 09:20:17 2017. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | .. include:: ../README.rst 7 | 8 | .. 
toctree:: 9 | :maxdepth: 2 10 | 11 | history 12 | api 13 | 14 | Indices and tables 15 | ================== 16 | 17 | * :ref:`genindex` 18 | * :ref:`modindex` 19 | * :ref:`search` 20 | -------------------------------------------------------------------------------- /.importanizerc: -------------------------------------------------------------------------------- 1 | { 2 | "length": 120, 3 | "groups": [ 4 | {"type": "stdlib"}, 5 | {"type": "remainder"}, 6 | {"type": "packages", "packages": ["sqlalchemy"]}, 7 | {"type": "packages", "packages": ["django"]}, 8 | {"type": "packages", "packages": ["django_sorcery"]}, 9 | {"type": "packages", "packages": ["rest_framework"]}, 10 | {"type": "packages", "packages": ["rest_enumfield"]}, 11 | {"type": "packages", "packages": ["rest_witchcraft"]}, 12 | {"type": "local"} 13 | ] 14 | } 15 | -------------------------------------------------------------------------------- /docs/api/rest_witchcraft.rst: -------------------------------------------------------------------------------- 1 | rest\_witchcraft package 2 | ======================== 3 | 4 | .. automodule:: rest_witchcraft 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Submodules 10 | ---------- 11 | 12 | .. 
toctree:: 13 | :maxdepth: 4 14 | 15 | rest_witchcraft.field_mapping 16 | rest_witchcraft.fields 17 | rest_witchcraft.filters 18 | rest_witchcraft.generics 19 | rest_witchcraft.mixins 20 | rest_witchcraft.routers 21 | rest_witchcraft.serializers 22 | rest_witchcraft.utils 23 | rest_witchcraft.viewsets 24 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include *.coveragerc 2 | include *.rc 3 | include *.editorconfig 4 | include *.importanizerc 5 | include *.rst 6 | include *.sh 7 | include *.toml 8 | include *.txt 9 | include *.yaml *.yml 10 | exclude .* 11 | include LICENSE 12 | include Makefile 13 | include pytest.ini 14 | include tox.ini 15 | include Dockerfile 16 | exclude docker-compose.override.yml 17 | recursive-exclude * __pycache__ 18 | recursive-exclude * *.py[co] 19 | 20 | recursive-include rest_witchcraft *.py 21 | 22 | recursive-include docs * 23 | recursive-exclude docs/_build * 24 | 25 | recursive-include tests * 26 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from django.core.exceptions import ValidationError 4 | 5 | from rest_witchcraft.utils import _django_to_drf 6 | 7 | 8 | class TestUtils(unittest.TestCase): 9 | def test_django_to_drf(self): 10 | self.assertEqual(_django_to_drf("hello"), "hello") 11 | self.assertEqual(_django_to_drf(["hello"]), ["hello"]) 12 | self.assertEqual(_django_to_drf({"hello": "world"}), {"hello": "world"}) 13 | self.assertEqual(_django_to_drf(ValidationError("hello")), ["hello"]) 14 | self.assertEqual(_django_to_drf(ValidationError({"hello": "world"})), {"hello": ["world"]}) 15 | -------------------------------------------------------------------------------- /setup.cfg: 
-------------------------------------------------------------------------------- 1 | [bdist_wheel] 2 | universal = 1 3 | 4 | [tool:pytest] 5 | addopts= 6 | --tb native 7 | -r sfxX 8 | 9 | [flake8] 10 | show-source = true 11 | enable-extensions = B,G 12 | max-line-length = 120 13 | exclude = .eggs,.tox,.git,*/migrations/*,*/static/CACHE/*,docs,node_modules 14 | select = C,E,F,W,B,B950 15 | ignore = E501,W503 16 | 17 | [importanize] 18 | allow_plugins=True 19 | plugins= 20 | unused_imports 21 | length=120 22 | groups= 23 | stdlib 24 | sitepackages 25 | remainder 26 | packages:sqlalchemy 27 | packages:django 28 | packages:django_sorcery 29 | packages:rest_witchcraft 30 | packages:tests 31 | local 32 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.7' 2 | 3 | services: 4 | pg: 5 | # 12 breaks older versions of sqlalchemy so sticking with 11 for now 6 | # https://github.com/sqlalchemy/sqlalchemy/issues/4463 7 | image: postgres:11 8 | ports: 9 | - "5432:5432" 10 | environment: 11 | - POSTGRES_PASSWORD=postgres 12 | 13 | py: 14 | build: 15 | context: . 
16 | args: 17 | USER_ID: ${USER_ID:-1000} 18 | GROUP_ID: ${GROUP_ID:-1000} 19 | command: sleep infinity 20 | depends_on: 21 | - pg 22 | environment: 23 | - DATABASE_URL=postgresql://postgres:postgres@pg 24 | - TOX_WORK_DIR=/tmp/tox 25 | volumes: 26 | - .:/code 27 | - /tmp:/root/.cache 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.py[cod] 2 | 3 | # C extensions 4 | *.so 5 | 6 | # Packages 7 | *.egg 8 | *.egg-info 9 | dist 10 | build 11 | eggs 12 | parts 13 | bin 14 | var 15 | sdist 16 | develop-eggs 17 | .installed.cfg 18 | lib 19 | lib64 20 | .eggs 21 | .cache 22 | .python-version 23 | 24 | # Installer logs 25 | pip-log.txt 26 | 27 | # Unit test / coverage reports 28 | .coverage 29 | .tox 30 | nosetests.xml 31 | htmlcov 32 | 33 | # Translations 34 | *.mo 35 | 36 | # Mr Developer 37 | .mr.developer.cfg 38 | .project 39 | .pydevproject 40 | 41 | # Complexity 42 | output/*.html 43 | output/*/index.html 44 | 45 | # Sphinx 46 | docs/_build 47 | 48 | # IDEs 49 | /.idea 50 | .venv 51 | .pytest_cache 52 | .mypy_cache 53 | pip-wheel-metadata 54 | .vscode 55 | .python-version 56 | docker-compose.override.yml 57 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -msphinx 7 | SPHINXPROJ = DjangoRESTWitchcraft 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | autodoc: 18 | rm -rf api/*.rst 19 | sphinx-apidoc --module-first --separate --output-dir=api ../rest_witchcraft 20 | 21 | # Catch-all target: route all unknown targets to Sphinx using the new 22 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 23 | %: Makefile autodoc 24 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 25 | -------------------------------------------------------------------------------- /tests/models_composite.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, create_engine, orm, types 2 | from sqlalchemy.ext.declarative import declarative_base 3 | 4 | from django.conf import settings 5 | 6 | 7 | engine = create_engine(settings.DB_URL) 8 | session = orm.scoped_session(orm.sessionmaker(bind=engine)) 9 | Base = declarative_base() 10 | Base.query = session.query_property() 11 | 12 | 13 | class RouterTestModel(Base): 14 | __tablename__ = "routertest" 15 | id = Column(types.Integer(), default=3, primary_key=True) 16 | text = Column(types.String(length=200)) 17 | 18 | 19 | class RouterTestCompositeKeyModel(Base): 20 | __tablename__ = "routertestcomposite" 21 | id = Column(types.Integer(), default=1, primary_key=True) 22 | other_id = Column(types.Integer(), default=3, primary_key=True) 23 | text = Column(types.String(length=200)) 24 | 25 | 26 | Base.metadata.create_all(engine) 27 | -------------------------------------------------------------------------------- /tests/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | INSTALLED_APPS = ["rest_witchcraft", "rest_framework", "django.contrib.staticfiles"] 5 | 6 | LANGUAGE_CODE = "en-us" 7 | 8 | TIME_ZONE = "UTC" 9 | 10 | USE_I18N = True 11 | 12 | USE_L10N = True 13 | 14 | USE_TZ = True 15 | 16 | 
USE_THOUSAND_SEPARATOR = True 17 | 18 | DB_URL = os.environ.get("DATABASE_URL", "postgresql://postgres:postgres@localhost") + "/test" 19 | 20 | SECRET_KEY = "secret" 21 | 22 | ALLOWED_HOSTS = ["*"] 23 | 24 | TEMPLATES = [ 25 | { 26 | "BACKEND": "django.template.backends.django.DjangoTemplates", 27 | "DIRS": [os.path.join(os.path.dirname(__file__), "templates")], 28 | "APP_DIRS": True, 29 | "OPTIONS": { 30 | "context_processors": [ 31 | "django.template.context_processors.debug", 32 | "django.template.context_processors.request", 33 | ] 34 | }, 35 | } 36 | ] 37 | -------------------------------------------------------------------------------- /rest_witchcraft/utils.py: -------------------------------------------------------------------------------- 1 | from django.core.exceptions import NON_FIELD_ERRORS, ValidationError as DjangoValidationError 2 | 3 | from rest_framework.serializers import ValidationError 4 | from rest_framework.settings import api_settings 5 | 6 | 7 | def _django_to_drf(e): 8 | if hasattr(e, "error_dict") or isinstance(e, dict): 9 | return { 10 | k if k != NON_FIELD_ERRORS else api_settings.NON_FIELD_ERRORS_KEY: _django_to_drf(v) 11 | for k, v in getattr(e, "error_dict", e).items() 12 | } 13 | 14 | elif hasattr(e, "error_list"): 15 | return e.messages 16 | elif isinstance(e, list): 17 | errors = [] 18 | for j in e: 19 | if isinstance(j, DjangoValidationError): 20 | errors += _django_to_drf(j) 21 | else: 22 | errors.append(_django_to_drf(j)) 23 | return errors 24 | return e 25 | 26 | 27 | def django_to_drf_validation_error(e): 28 | return ValidationError( 29 | _django_to_drf(e) if hasattr(e, "error_dict") else {api_settings.NON_FIELD_ERRORS_KEY: e.messages} 30 | ) 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Serkan Hosca 4 | 5 | Permission is hereby granted, free of charge, to 
any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM buildpack-deps:buster 2 | 3 | RUN mkdir -p /opt/pyenv &&\ 4 | curl -sL https://github.com/pyenv/pyenv/archive/refs/heads/master.tar.gz | \ 5 | tar -C /opt/pyenv --strip-components 1 -xz && \ 6 | apt-get update && \ 7 | apt-get install -y postgresql-client && \ 8 | apt-get clean 9 | 10 | ARG USER_ID=1000 11 | ARG GROUP_ID=1000 12 | 13 | RUN groupadd -g ${GROUP_ID} sorcerer && \ 14 | useradd -l -u ${USER_ID} -g ${GROUP_ID} -m -d /home/sorcerer sorcerer && \ 15 | mkdir -p /code && chown -R sorcerer:sorcerer /code 16 | 17 | USER sorcerer 18 | ENV PATH=/opt/pyenv/bin:/home/sorcerer/.pyenv/shims:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin 19 | WORKDIR /code 20 | 21 | RUN pyenv install 3.6:latest 22 | RUN pyenv install 3.7:latest 23 | RUN pyenv install 3.8:latest 24 | RUN pyenv install 3.9:latest 25 | RUN pyenv install 3.10:latest 26 | RUN pyenv install 3.11:latest 27 | 28 | RUN pyenv install pypy3.6:latest 29 | RUN pyenv install pypy3.7:latest 30 | RUN pyenv install pypy3.8:latest 31 | 32 | RUN pyenv global $(pyenv versions | grep ' 3.11.') && \ 33 | pip install pre-commit tox==3.28.0 tox-pyenv tox-factor && \ 34 | pyenv global $(pyenv versions | grep -v system | cut -c 3- | cut -d' ' -f1) 35 | -------------------------------------------------------------------------------- /.gitchangelog.rc: -------------------------------------------------------------------------------- 1 | # Ignore commits 2 | ignore_regexps = [ 3 | r'@minor', 4 | r'!minor', 5 | r'@cosmetic', 6 | r'!cosmetic', 7 | r'@refactor', 8 | r'!refactor', 9 | r'@wip', 10 | r'!wip', 11 | r'^Merge commit .* into HEAD', 12 | r'^Version bump*', 13 | r'^Docs: update*', 14 | r'^Bump version*', 15 | r'^Update history*', 16 | r'^Initial commit* into HEAD', 17 | ] 18 | 19 | # Detect sections 20 | section_regexps = [ 
21 | ('New features', [ 22 | r'^[fF]eat.*:\s*?([^\n]*)$', 23 | ]), 24 | ('Fix', [ 25 | r'^[fF]ix:\s*?([^\n]*)$', 26 | ]), 27 | ('Refactor', [ 28 | r'^[rR]efactor:\s*?([^\n]*)$', 29 | ]), 30 | ('Documentation', [ 31 | r'^[dD]oc.*:\s*?([^\n]*)$', 32 | ]), 33 | ('Other', None ## Match all lines 34 | ), 35 | ] 36 | 37 | # Rewrite body 38 | body_process = (ReSub(r'.*', '') | 39 | ReSub(r'^(\n|\r)$', '')) 40 | 41 | # Rewrite subject 42 | subject_process = (strip | 43 | ReSub(r'^(\w+)\s*:\s*([^\n@]*)(@[a-z]+\s+)*$', r'\2') | 44 | ReSub(r'^\*\*: ', '') | 45 | ReSub(r'\) \(', ' ') | 46 | strip | ucfirst | final_dot) 47 | 48 | tag_filter_regexp = r'^[0-9]+\.[0-9]+(\.[0-9]+)?$' 49 | 50 | unreleased_version_label = "Next version (unreleased yet)" 51 | 52 | output_engine = rest_py 53 | include_merges = False 54 | -------------------------------------------------------------------------------- /rest_witchcraft/viewsets.py: -------------------------------------------------------------------------------- 1 | from rest_framework import mixins, viewsets 2 | 3 | from .generics import GenericAPIView 4 | from .mixins import DestroyModelMixin, ExpandableQuerySerializerMixin 5 | 6 | 7 | class GenericViewSet(viewsets.ViewSetMixin, GenericAPIView): 8 | """The GenericViewSet class does not provide any actions by default, but 9 | does include the base set of generic view behavior, such as the 10 | `get_object` and `get_queryset` methods.""" 11 | 12 | 13 | class ReadOnlyViewModelViewSet(mixins.RetrieveModelMixin, mixins.ListModelMixin, GenericViewSet): 14 | """A viewset that provides default `list()` and `retrieve()` actions.""" 15 | 16 | 17 | class ModelViewSet( 18 | mixins.CreateModelMixin, 19 | mixins.RetrieveModelMixin, 20 | mixins.UpdateModelMixin, 21 | DestroyModelMixin, 22 | mixins.ListModelMixin, 23 | GenericViewSet, 24 | ): 25 | """A viewset that provides default `create()`, `retrieve()`, `update()`, 26 | `partial_update()`, `destroy()` and `list()` actions.""" 27 | 28 | 29 | class 
ExpandableModelViewSet(ExpandableQuerySerializerMixin, ModelViewSet): 30 | """A viewset that automatically eagerloads any subfields that are 31 | expanded via querystring. 32 | 33 | For the queryset to be expanded, either :py:class:`rest_witchcraft.serializers.ExpandableModelSerializer` 34 | needs to be used in ``serializer_class`` or ``query_serializer_class`` can be manually provided. 35 | """ 36 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | repos: 3 | 4 | - repo: https://github.com/miki725/importanize 5 | rev: '0.7' 6 | hooks: 7 | - id: importanize 8 | language_version: python3 9 | 10 | - repo: https://github.com/psf/black 11 | rev: 22.12.0 12 | hooks: 13 | - id: black 14 | additional_dependencies: ["click==8.0.4"] 15 | language_version: python3 16 | 17 | - repo: https://github.com/asottile/pyupgrade 18 | rev: v3.3.1 19 | hooks: 20 | - id: pyupgrade 21 | args: [--py3-plus] 22 | 23 | - repo: https://github.com/myint/docformatter 24 | rev: v1.5.1 25 | hooks: 26 | - id: docformatter 27 | 28 | - repo: https://github.com/PyCQA/flake8 29 | rev: 6.0.0 30 | hooks: 31 | - id: flake8 32 | exclude: deployment/roles 33 | additional_dependencies: 34 | - flake8-bugbear 35 | - flake8-comprehensions 36 | - flake8-debugger 37 | 38 | - repo: https://github.com/mgedmin/check-manifest 39 | rev: '0.49' 40 | hooks: 41 | - id: check-manifest 42 | 43 | - repo: https://github.com/pre-commit/pre-commit-hooks 44 | rev: v4.4.0 45 | hooks: 46 | - id: check-added-large-files 47 | - id: check-builtin-literals 48 | - id: check-byte-order-marker 49 | - id: check-case-conflict 50 | - id: check-docstring-first 51 | - id: check-executables-have-shebangs 52 | - id: check-json 53 | - id: check-merge-conflict 54 | - id: check-xml 55 | - id: check-yaml 56 | - id: debug-statements 57 | - id: trailing-whitespace 58 | - id: mixed-line-ending 59 |
args: [--fix=lf] 60 | - id: pretty-format-json 61 | args: [--autofix] 62 | -------------------------------------------------------------------------------- /rest_witchcraft/routers.py: -------------------------------------------------------------------------------- 1 | from django_sorcery.db import meta 2 | 3 | from rest_framework import routers 4 | 5 | 6 | class DefaultRouter(routers.DefaultRouter): 7 | def get_default_base_name(self, viewset): 8 | model = getattr(viewset, "get_model", lambda: None)() 9 | 10 | assert model is not None, ( 11 | "`base_name` argument not specified, and could not automatically determine the name from the viewset, " 12 | "as either queryset is is missing or is not a sqlalchemy query, or the serializer_class is not a " 13 | "sqlalchemy model serializer" 14 | ) 15 | 16 | return model.__name__.lower() 17 | 18 | # for backwards compatibility DRF<3.9 19 | get_default_basename = get_default_base_name 20 | 21 | def get_lookup_regex(self, viewset, lookup_prefix=""): 22 | """Given a viewset, return the portion of the url regex that is used to 23 | match against a single instance. 24 | 25 | Can be overwritten by providing a `lookup_url_regex` on the 26 | viewset. 
27 | """ 28 | 29 | lookup_url_regex = getattr(viewset, "lookup_url_regex", None) 30 | if lookup_url_regex: 31 | return lookup_url_regex 32 | 33 | model = getattr(viewset, "get_model", lambda: None)() 34 | if model: 35 | info = meta.model_info(model) 36 | base_regex = "(?P<{lookup_prefix}{lookup_url_kwarg}>{lookup_value})" 37 | 38 | lookup_keys = [getattr(viewset, "lookup_url_kwarg", None) or getattr(viewset, "lookup_field", None)] 39 | if not lookup_keys[0] or len(info.primary_keys) > 1: 40 | lookup_keys = list(info.primary_keys) 41 | 42 | regexes = [ 43 | base_regex.format( 44 | lookup_prefix=lookup_prefix, 45 | lookup_url_kwarg=key, 46 | lookup_value="[^/.]+", 47 | ) 48 | for key in lookup_keys 49 | ] 50 | 51 | return "/".join(regexes) 52 | 53 | return super().get_lookup_regex(viewset, lookup_prefix) 54 | -------------------------------------------------------------------------------- /rest_witchcraft/generics.py: -------------------------------------------------------------------------------- 1 | from contextlib import suppress 2 | 3 | from sqlalchemy.exc import InvalidRequestError 4 | 5 | from django.http import Http404 6 | 7 | from django_sorcery.db.meta import model_info 8 | 9 | from rest_framework import generics 10 | 11 | 12 | class GenericAPIView(generics.GenericAPIView): 13 | """Base class for sqlalchemy specific views.""" 14 | 15 | @classmethod 16 | def get_model(cls): 17 | """Returns the model class.""" 18 | model = None 19 | 20 | with suppress(AttributeError, InvalidRequestError): 21 | model = cls.queryset._only_full_mapper_zero("get").class_ 22 | 23 | if model: 24 | return model 25 | 26 | with suppress(AttributeError): 27 | model = cls.serializer_class.Meta.model 28 | 29 | assert model is not None, ( 30 | "Couldn't figure out the model for {viewset} attribute, either provide a" 31 | "queryset or a serializer with a Meta.model".format(viewset=cls.__name__) 32 | ) 33 | 34 | return model 35 | 36 | def get_session(self): 37 | """Returns the session.""" 38 
| queryset = self.get_queryset() 39 | return queryset.session 40 | 41 | def get_object(self): 42 | """Returns the object the view is displaying. 43 | 44 | We ignore the `lookup_field` and `lookup_url_kwarg` values only 45 | when tere are multiple primary keys 46 | """ 47 | queryset = self.get_queryset() 48 | model = self.get_model() 49 | info = model_info(model) 50 | kwargs = self.kwargs.copy() 51 | 52 | # we want to honor DRF lookup_field and lookup_url_kwarg API 53 | # but only if they are defined and there is single primary key. 54 | # When there are multiple, all bets are off so we restrict url kwargs 55 | # to model column names 56 | if len(info.primary_keys) == 1: 57 | lookup_field = self.lookup_field 58 | lookup_url_kwarg = self.lookup_url_kwarg or lookup_field 59 | 60 | kwargs[lookup_field] = kwargs.pop(lookup_url_kwarg) 61 | 62 | obj = queryset.get(info.primary_keys_from_dict(kwargs)) 63 | 64 | if not obj: 65 | raise Http404("No %s matches the given query." % model.__name__) 66 | 67 | return obj 68 | -------------------------------------------------------------------------------- /tests/test_generics.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, create_engine, orm, types 2 | from sqlalchemy.ext.declarative import declarative_base 3 | 4 | from django.http import Http404 5 | from django.test import SimpleTestCase 6 | 7 | from rest_framework.test import APIRequestFactory 8 | 9 | from rest_witchcraft import serializers, viewsets 10 | 11 | 12 | factory = APIRequestFactory() 13 | 14 | engine = create_engine("sqlite://") 15 | session = orm.scoped_session(orm.sessionmaker(bind=engine)) 16 | Base = declarative_base() 17 | Base.query = session.query_property() 18 | 19 | 20 | class RouterTestModel(Base): 21 | __tablename__ = "routertest" 22 | id = Column(types.Integer(), default=3, primary_key=True) 23 | text = Column(types.String(length=200)) 24 | 25 | 26 | Base.metadata.create_all(engine) 27 
| 28 | 29 | class RouterTestModelSerializer(serializers.ModelSerializer): 30 | class Meta: 31 | model = RouterTestModel 32 | session = session 33 | fields = "__all__" 34 | 35 | 36 | class TestModelRoutes(SimpleTestCase): 37 | def test_get_model_using_queryset(self): 38 | class RouterTestViewSet(viewsets.ModelViewSet): 39 | queryset = RouterTestModel.query 40 | serializer_class = RouterTestModelSerializer 41 | 42 | model = RouterTestViewSet.get_model() 43 | 44 | self.assertEqual(model, RouterTestModel) 45 | 46 | def test_get_model_using_serializer(self): 47 | class RouterTestViewSet(viewsets.ModelViewSet): 48 | serializer_class = RouterTestModelSerializer 49 | 50 | model = RouterTestViewSet.get_model() 51 | 52 | self.assertEqual(model, RouterTestModel) 53 | 54 | def test_get_model_fails_with_assert_error(self): 55 | class RouterTestViewSet(viewsets.ModelViewSet): 56 | pass 57 | 58 | with self.assertRaises(AssertionError): 59 | RouterTestViewSet.get_model() 60 | 61 | def test_get_object_raises_404(self): 62 | class RouterTestViewSet(viewsets.ModelViewSet): 63 | queryset = RouterTestModel.query 64 | serializer_class = RouterTestModelSerializer 65 | lookup_field = "id" 66 | lookup_url_kwarg = "pk" 67 | 68 | viewset = RouterTestViewSet() 69 | viewset.kwargs = {"pk": 1} 70 | 71 | with self.assertRaises(Http404): 72 | viewset.get_object() 73 | -------------------------------------------------------------------------------- /tests/test_fields.py: -------------------------------------------------------------------------------- 1 | from django.test import SimpleTestCase, override_settings 2 | 3 | from rest_framework.fields import ChoiceField, IntegerField 4 | from rest_framework.serializers import Serializer 5 | 6 | from rest_witchcraft.fields import HyperlinkedIdentityField, ImplicitExpandableListField, SkippableField, UriField 7 | 8 | from .models import Owner 9 | from .models_composite import RouterTestCompositeKeyModel 10 | 11 | 12 | 
@override_settings(ROOT_URLCONF="tests.urls.testhyperlinkedidentityfield") 13 | class TestHyperlinkedIdentityField(SimpleTestCase): 14 | def test_url(self): 15 | field = HyperlinkedIdentityField(view_name="foo", lookup_url_kwarg="id", lookup_field="id") 16 | owner = Owner(id=1, first_name="Jon", last_name="Snow") 17 | 18 | url = field.get_url(owner, "owner", None, None) 19 | self.assertEqual(url, "/example/1/") 20 | 21 | def test_url_not_saved(self): 22 | field = HyperlinkedIdentityField(view_name="foo") 23 | 24 | self.assertIsNone(field.get_url(Owner(first_name="Jon", last_name="Snow"), "owner", None, None)) 25 | 26 | 27 | @override_settings(ROOT_URLCONF="tests.urls.testhyperlinkedidentityfieldcomposite") 28 | class TestHyperlinkedIdentityFieldComposite(SimpleTestCase): 29 | def test_url_composite(self): 30 | field = HyperlinkedIdentityField(view_name="foo") 31 | instance = RouterTestCompositeKeyModel(id=1, other_id=2) 32 | 33 | url = field.get_url(instance, "owner", None, None) 34 | 35 | self.assertEqual(url, "/example/1/2/") 36 | 37 | 38 | @override_settings(ROOT_URLCONF="tests.urls.testhyperlinkedidentityfield") 39 | class TestUriField(SimpleTestCase): 40 | def test_url(self): 41 | field = UriField(view_name="foo", lookup_url_kwarg="id", lookup_field="id") 42 | owner = Owner(id=1, first_name="Jon", last_name="Snow") 43 | 44 | url = field.get_url(owner, "owner", None, None) 45 | self.assertEqual(url, "/example/1/") 46 | 47 | 48 | class TestImplicitExpandableListField(SimpleTestCase): 49 | def test_to_internal_value(self): 50 | f = ImplicitExpandableListField(child=ChoiceField(choices=["foo", "foo__bar", "bar"])) 51 | 52 | self.assertEqual(f.run_validation(["foo"]), ["foo"]) 53 | self.assertEqual(set(f.run_validation(["foo__bar"])), {"foo__bar", "foo"}) 54 | 55 | 56 | class TestSkippableField(SimpleTestCase): 57 | def test_field_always_skippable(self): 58 | class TestSerializer(Serializer): 59 | foo = SkippableField() 60 | bar = IntegerField() 61 | 62 | 
self.assertEqual(TestSerializer(instance={"foo": "foo", "bar": 5}).data, {"bar": 5}) 63 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import shutil 4 | import sys 5 | 6 | from setuptools import find_packages, setup 7 | 8 | 9 | here = os.path.abspath(os.path.dirname(__file__)) 10 | 11 | about = {} 12 | with open(os.path.join(here, "rest_witchcraft", "__version__.py")) as f: 13 | exec(f.read(), about) # yapf: disable 14 | 15 | 16 | def read(fname): 17 | return open(os.path.join(os.path.dirname(__file__), fname), "rb").read().decode("utf-8") 18 | 19 | 20 | if sys.argv[-1] == "publish": 21 | twine = shutil.which("twine") 22 | if twine is None: 23 | print("twine not installed.\nUse `pip install twine`.\nExiting.") 24 | sys.exit() 25 | os.system("python setup.py sdist bdist_wheel") 26 | os.system("twine upload dist/*") 27 | print("You probably want to also tag the version now:") 28 | print(" git tag -a {} -m {}".format(about["__version__"], about["__version__"])) 29 | print(" git push --tags") 30 | os.system("make clean") 31 | sys.exit() 32 | 33 | setup( 34 | author=about["__author__"], 35 | author_email=about["__author_email__"], 36 | description=about["__description__"], 37 | install_requires=["djangorestframework", "six", "django-sorcery>=0.11.1", "django-rest-enumfield"], 38 | license="MIT", 39 | long_description=read("README.rst"), 40 | name="django-rest-witchcraft", 41 | packages=find_packages(exclude=["tests"]), 42 | url="https://github.com/shosca/django-rest-witchcraft", 43 | version=about["__version__"], 44 | keywords="sqlalchemy django rest framework drf rest_framework", 45 | classifiers=[ 46 | "Development Status :: 4 - Beta", 47 | "Environment :: Web Environment", 48 | "Framework :: Django :: 1.11", 49 | "Framework :: Django :: 2.0", 50 | "Framework :: Django :: 2.1", 51 | "Framework :: 
Django :: 2.2", 52 | "Framework :: Django :: 3.0", 53 | "Framework :: Django :: 3.1", 54 | "Framework :: Django :: 3.2", 55 | "Framework :: Django :: 4.0", 56 | "Framework :: Django :: 4.1", 57 | "Framework :: Django", 58 | "Intended Audience :: Developers", 59 | "License :: OSI Approved :: MIT License", 60 | "Natural Language :: English", 61 | "Operating System :: OS Independent", 62 | "Programming Language :: Python :: 3 :: Only", 63 | "Programming Language :: Python :: 3", 64 | "Programming Language :: Python :: 3.10", 65 | "Programming Language :: Python :: 3.11", 66 | "Programming Language :: Python :: 3.5", 67 | "Programming Language :: Python :: 3.6", 68 | "Programming Language :: Python :: 3.7", 69 | "Programming Language :: Python :: 3.8", 70 | "Programming Language :: Python :: 3.9", 71 | "Programming Language :: Python :: Implementation :: CPython", 72 | "Programming Language :: Python", 73 | "Topic :: Internet :: WWW/HTTP", 74 | ], 75 | ) 76 | -------------------------------------------------------------------------------- /rest_witchcraft/fields.py: -------------------------------------------------------------------------------- 1 | """Some SQLAlchemy specific field types.""" 2 | 3 | from django.db.models.constants import LOOKUP_SEP 4 | 5 | from django_sorcery.db import meta 6 | 7 | from rest_framework import fields, relations 8 | 9 | 10 | class HyperlinkedIdentityField(relations.HyperlinkedIdentityField): 11 | def get_url(self, obj, view_name, request, format): 12 | info = meta.model_info(obj.__class__) 13 | 14 | # Unsaved objects will not yet have a valid URL. 
15 | if not all(getattr(obj, i) for i in info.primary_keys): 16 | return None 17 | 18 | if len(info.primary_keys) == 1: 19 | kwargs = {self.lookup_url_kwarg: getattr(obj, self.lookup_field)} 20 | else: 21 | kwargs = {k: getattr(obj, k) for k in info.primary_keys} 22 | 23 | return self.reverse(view_name, kwargs=kwargs, request=request, format=format) 24 | 25 | 26 | class UriField(HyperlinkedIdentityField): 27 | """Represents a uri to the resource.""" 28 | 29 | def get_url(self, obj, view_name, request, format): 30 | """Same as basic HyperlinkedIdentityField except return uri vs full 31 | url.""" 32 | return super().get_url(obj, view_name, None, format) 33 | 34 | 35 | class CharMappingField(fields.DictField): 36 | """Used for Postgresql HSTORE columns for storing key-value pairs.""" 37 | 38 | child = fields.CharField(allow_null=True) 39 | 40 | 41 | class ImplicitExpandableListField(fields.ListField): 42 | """List field which implicitly expands parent field when child field is 43 | expanded assuming parent field is also expandable by being one of the 44 | choices.""" 45 | 46 | def to_internal_value(self, data): 47 | data = super().to_internal_value(data) 48 | for i in data[:]: 49 | parts = i.split(LOOKUP_SEP) 50 | data = list( 51 | ({LOOKUP_SEP.join(parts[:i]) for i in range(1, len(parts))} & set(self.child.choices)) | set(data) 52 | ) 53 | return data 54 | 55 | 56 | class SkippableField(fields.Field): 57 | """Field which is always skipped on to_representation. 58 | 59 | Useful when used together with ``ExpandableModelSerializer`` since it allows 60 | to completely skip expandable field when it is not being expanded. 61 | Especially useful for ``OneToMany`` relations since by default nested 62 | serializer cannot be rendered as none of the PKs of the "many" items are 63 | known unlike ``ManyToOne`` when nested serializer can be rendered with PK. 64 | For example: 65 | 66 | .. 
code:: 67 | 68 | class FooSerializer(ExpandableModelSerializer): 69 | bar = BarSerializer(many=True) 70 | 71 | class Meta: 72 | model = Foo 73 | session = session 74 | fields = "__all__" 75 | expandable_fields = { 76 | "bar": SkippableField() 77 | } 78 | """ 79 | 80 | def get_attribute(self, instance): 81 | raise fields.SkipField 82 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - master 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ubuntu-latest 13 | services: 14 | postgres: 15 | image: postgres 16 | ports: 17 | - "5432:5432" 18 | env: 19 | POSTGRES_USER: postgres 20 | POSTGRES_PASSWORD: postgres 21 | POSTGRES_DB: test 22 | 23 | strategy: 24 | matrix: 25 | container: 26 | - image: python:latest 27 | toxenv: lint 28 | - image: python:3.6 29 | toxenv: py36-sqla12 30 | - image: python:3.6 31 | toxenv: py36-sqla13 32 | - image: python:3.6 33 | toxenv: py36-sqla14 34 | 35 | - image: python:3.7 36 | toxenv: py37-sqla12 37 | - image: python:3.7 38 | toxenv: py37-sqla13 39 | - image: python:3.7 40 | toxenv: py37-sqla14 41 | 42 | - image: python:3.8 43 | toxenv: py38-sqla12 44 | - image: python:3.8 45 | toxenv: py38-sqla13 46 | - image: python:3.8 47 | toxenv: py38-sqla14 48 | 49 | - image: python:3.9 50 | toxenv: py39-sqla12 51 | - image: python:3.9 52 | toxenv: py39-sqla13 53 | - image: python:3.9 54 | toxenv: py39-sqla14 55 | 56 | - image: python:3.10 57 | toxenv: py310-sqla12 58 | - image: python:3.10 59 | toxenv: py310-sqla13 60 | - image: python:3.10 61 | toxenv: py310-sqla14 62 | 63 | - image: python:3.11 64 | toxenv: py311-sqla12 65 | - image: python:3.11 66 | toxenv: py311-sqla13 67 | - image: python:3.11 68 | toxenv: py311-sqla14 69 | 70 | container: 71 | image: ${{ matrix.container.image }} 72 | 73 | steps: 74 | - uses: actions/checkout@v3 75 | - 
name: Install dependencies 76 | run: | 77 | apt-get update 78 | apt-get install -y postgresql-client 79 | pip install pre-commit tox==3.28.0 tox-factor coveralls 80 | 81 | - name: Tox 82 | run: | 83 | tox -f ${{ matrix.container.toxenv }} 84 | env: 85 | DATABASE_URL: postgresql://postgres:postgres@postgres 86 | - name: Upload Coverage 87 | if: "!startsWith('lint', matrix.container.toxenv)" 88 | run: coveralls --service=github 89 | env: 90 | GITHUB_TOKEN: ${{ secrets.github_token }} 91 | COVERALLS_FLAG_NAME: ${{ matrix.container.image }} 92 | COVERALLS_PARALLEL: true 93 | 94 | finish: 95 | needs: build 96 | runs-on: ubuntu-latest 97 | container: python:3-slim 98 | steps: 99 | - name: Coveralls Finished 100 | run: | 101 | pip3 install coveralls 102 | coveralls --finish 103 | env: 104 | GITHUB_TOKEN: ${{ secrets.github_token }} 105 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PACKAGE=rest_witchcraft 2 | FILES=$(shell find $(PACKAGE) -iname '*.py' ! 
-iname '__*') 3 | VERSION=$(shell python setup.py --version) 4 | NEXT=$(shell semver -i $(BUMP) $(VERSION)) 5 | DBS=\ 6 | test 7 | RESETDBS=$(addsuffix -resetdb,$(DBS)) 8 | COVERAGE_FLAGS?=--cov-report term-missing --cov-fail-under=100 9 | 10 | DATABASE_URL?=postgresql://postgres:postgres@localhost 11 | 12 | .PHONY: help list docs $(FILES) 13 | 14 | 15 | help: 16 | @for f in $(MAKEFILE_LIST) ; do \ 17 | echo "$$f:" ; \ 18 | grep -E '^[^[:space:]].*:.*?## .*$$' $$f | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-15s\033[0m %s\n", $$1, $$2}' ; \ 19 | done ; \ 20 | 21 | list: ## list all possible targets 22 | @$(MAKE) -pRrq -f $(lastword $(MAKEFILE_LIST)) : 2>/dev/null | awk -v RS= -F: '/^# File/,/^# Finished Make data base/ {if ($$1 !~ "^[#.]") {print $$1}}' | sort | egrep -v -e '^[^[:alnum:]]' -e '^$@$$' 23 | 24 | clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts 25 | 26 | clean-build: ## remove build artifacts 27 | find -name '*.sqlite3' -delete 28 | rm -rf build dist .eggs .mypy_cache .pytest_cache docs/_build *.egg* 29 | 30 | clean-pyc: ## remove Python file artifacts 31 | find -name '*.pyc' -delete 32 | find -name '*.pyo' -delete 33 | find -name '*~' -delete 34 | find -name '__pycache__' -delete 35 | 36 | clean-test: ## remove test and coverage artifacts 37 | rm -rf .tox .coverage htmlcov 38 | 39 | %-resetdb: 40 | -psql $(DATABASE_URL) -c "drop database $*;" 41 | psql $(DATABASE_URL) -c "create database $*;" 42 | 43 | resetdb: $(RESETDBS) 44 | 45 | lint: ## run pre-commit hooks on all files 46 | pre-commit run --all-files 47 | 48 | coverage: ## check code coverage quickly with the default Python 49 | py.test $(PYTEST_OPTS) \ 50 | --cov=$(PACKAGE) \ 51 | $(COVERAGE_FLAGS) \ 52 | tests 53 | 54 | $(FILES): ## helper target to run coverage tests on a module 55 | py.test $(PYTEST_OPTS) $(COVERAGE_FLAGS) \ 56 | --cov=$(subst /,.,$(firstword $(subst ., ,$@))) $(subst $(PACKAGE),tests,$(dir $@))test_$(notdir $@) 57 | 58 | 
test: ## run tests 59 | py.test $(PYTEST_OPTS) tests $(PACKAGE) 60 | 61 | check: ## run all tests 62 | tox 63 | 64 | history: docs ## generate HISTORY.rst 65 | gitchangelog > HISTORY.rst 66 | 67 | docs: ## generate docs 68 | $(MAKE) -C docs html 69 | 70 | livedocs: ## generate docs live 71 | $(MAKE) -C docs live 72 | 73 | version: # print version 74 | @echo $(VERSION) 75 | 76 | next: # print next version 77 | @echo $(NEXT) 78 | 79 | bump: history 80 | @sed -i 's/$(VERSION)/$(NEXT)/g' $(PACKAGE)/__version__.py 81 | @sed -i 's/Next version (unreleased yet)/$(NEXT) ($(shell date +"%Y-%m-%d"))/g' HISTORY.rst 82 | @git add . 83 | @git commit -am "Bump version: $(VERSION) → $(NEXT)" 84 | 85 | tag: ## tags branch 86 | git tag -a $$(python setup.py --version) -m $$(python setup.py --version) 87 | 88 | release: dist ## package and upload a release 89 | twine upload dist/* 90 | 91 | dist: clean ## builds source and wheel package 92 | python setup.py sdist 93 | python setup.py bdist_wheel 94 | ls -l dist 95 | -------------------------------------------------------------------------------- /rest_witchcraft/field_mapping.py: -------------------------------------------------------------------------------- 1 | """Field mapping from SQLAlchemy type's to DRF fields.""" 2 | import datetime 3 | import decimal 4 | 5 | from sqlalchemy.dialects import postgresql 6 | from sqlalchemy.sql import sqltypes 7 | 8 | from django_sorcery.db import meta 9 | 10 | from rest_framework import fields 11 | 12 | from rest_enumfield import EnumField 13 | 14 | from .fields import CharMappingField 15 | 16 | 17 | def get_detail_view_name(model): 18 | """Given a model class, return the view name to use for URL relationships 19 | that rever to instances of the model.""" 20 | # TODO split camel case 21 | return "{}-detail".format(model.__name__.lower()) 22 | 23 | 24 | def get_url_kwargs(model): 25 | """Gets kwargs for the UriField.""" 26 | info = meta.model_info(model) 27 | lookup_field = 
list(info.primary_keys.keys())[0] 28 | 29 | return { 30 | "read_only": True, 31 | "view_name": get_detail_view_name(model), 32 | "lookup_field": lookup_field, 33 | "lookup_url_kwarg": "pk", 34 | } 35 | 36 | 37 | SERIALIZER_FIELD_MAPPING = { 38 | # sqlalchemy types 39 | postgresql.HSTORE: CharMappingField, 40 | sqltypes.Enum: EnumField, 41 | # python types 42 | datetime.date: fields.DateField, 43 | datetime.datetime: fields.DateTimeField, 44 | datetime.time: fields.TimeField, 45 | datetime.timedelta: fields.DurationField, 46 | decimal.Decimal: fields.DecimalField, 47 | float: fields.FloatField, 48 | int: fields.IntegerField, 49 | str: fields.CharField, 50 | } 51 | 52 | try: # pragma: no cover 53 | from sqlalchemy_utils import types 54 | 55 | SERIALIZER_FIELD_MAPPING[types.IPAddressType] = fields.IPAddressField 56 | SERIALIZER_FIELD_MAPPING[types.UUIDType] = fields.UUIDField 57 | SERIALIZER_FIELD_MAPPING[types.URLType] = fields.URLField 58 | except ImportError: # pragma: no cover 59 | pass 60 | 61 | 62 | def get_field_type(column): 63 | """Returns the field type to be used determined by the sqlalchemy column 64 | type or the column type's python type.""" 65 | if isinstance(column.type, sqltypes.Enum) and not column.type.enum_class: 66 | return fields.ChoiceField 67 | 68 | if isinstance(column.type, postgresql.ARRAY): 69 | child_field = SERIALIZER_FIELD_MAPPING.get(column.type.item_type.__class__) or SERIALIZER_FIELD_MAPPING.get( 70 | column.type.item_type.python_type 71 | ) 72 | 73 | if child_field is None: 74 | raise KeyError("Could not figure out field for ARRAY item type '{}'".format(column.type.__class__)) 75 | 76 | class ArrayField(fields.ListField): 77 | """Nested array field for PostreSQL's ARRAY type.""" 78 | 79 | def __init__(self, *args, **kwargs): 80 | kwargs["child"] = child_field() 81 | super().__init__(*args, **kwargs) 82 | 83 | return ArrayField 84 | 85 | if column.type.__class__ in SERIALIZER_FIELD_MAPPING: 86 | return 
SERIALIZER_FIELD_MAPPING.get(column.type.__class__) 87 | 88 | if issubclass(column.type.python_type, bool): 89 | if hasattr(fields, "NullBooleanField") and column.nullable: 90 | return fields.NullBooleanField # pragma: no cover 91 | 92 | return fields.BooleanField 93 | 94 | for typ in column.type.python_type.mro(): 95 | if typ in SERIALIZER_FIELD_MAPPING: 96 | return SERIALIZER_FIELD_MAPPING.get(typ) 97 | -------------------------------------------------------------------------------- /rest_witchcraft/filters.py: -------------------------------------------------------------------------------- 1 | """Provides generic filtering backends that can be used to filter the results 2 | returned by list views.""" 3 | 4 | from sqlalchemy import func, or_ 5 | from sqlalchemy.sql import operators 6 | 7 | from django.template import loader 8 | from django.utils.encoding import force_str 9 | from django.utils.translation import gettext_lazy 10 | 11 | from rest_framework.compat import coreapi, coreschema 12 | from rest_framework.filters import BaseFilterBackend 13 | from rest_framework.settings import api_settings 14 | 15 | 16 | class SearchFilter(BaseFilterBackend): 17 | search_param = api_settings.SEARCH_PARAM 18 | template = "rest_framework/filters/search.html" 19 | lookup_prefixes = { 20 | "": lambda c, x: operators.ilike_op(c, "%{}%".format(x)), # icontains 21 | "^": lambda c, x: c.ilike(x.replace("%", "%%") + "%"), # istartswith 22 | "=": lambda c, x: func.lower(c) == func.lower(x), # iequals 23 | "@": operators.eq, # equals 24 | } 25 | search_title = gettext_lazy("Search") 26 | search_description = gettext_lazy("A search term.") 27 | 28 | def get_schema_fields(self, view): 29 | assert coreapi is not None, "coreapi must be installed to use `get_schema_fields()`" 30 | assert coreschema is not None, "coreschema must be installed to use `get_schema_fields()`" 31 | return [ 32 | coreapi.Field( 33 | name=self.search_param, 34 | required=False, 35 | location="query", 36 | 
schema=coreschema.String( 37 | title=force_str(self.search_title), description=force_str(self.search_description) 38 | ), 39 | ) 40 | ] 41 | 42 | def get_schema_operation_parameters(self, view): 43 | return [ 44 | { 45 | "name": self.search_param, 46 | "required": False, 47 | "in": "query", 48 | "description": force_str(self.search_description), 49 | "schema": {"type": "string"}, 50 | } 51 | ] 52 | 53 | def get_search_fields(self, view, request): 54 | return getattr(view, "search_fields", None) 55 | 56 | def get_search_terms(self, request): 57 | params = request.query_params.get(self.search_param, "") 58 | params = params.replace("\x00", "") # strip null characters 59 | params = params.replace(",", " ") 60 | return params.split() 61 | 62 | def to_html(self, request, queryset, view): 63 | if not getattr(view, "search_fields", None): 64 | return "" 65 | 66 | term = self.get_search_terms(request) 67 | term = term[0] if term else "" 68 | context = {"param": self.search_param, "term": term} 69 | template = loader.get_template(self.template) 70 | return template.render(context) 71 | 72 | def filter_queryset(self, request, queryset, view): 73 | search_fields = self.get_search_fields(view, request) 74 | search_terms = self.get_search_terms(request) 75 | 76 | if not search_fields or not search_terms: 77 | return queryset 78 | 79 | model = view.get_model() 80 | 81 | expressions = [] 82 | for field in search_fields: 83 | for term in search_terms: 84 | expr = self.get_expression(model, field, term) 85 | if expr is not None: 86 | expressions.append(expr) 87 | 88 | return queryset.filter(or_(*expressions)) 89 | 90 | def get_expression(self, model, field, term): 91 | op = self.lookup_prefixes[""] 92 | if field[0] in self.lookup_prefixes: 93 | op = self.lookup_prefixes[field[0]] 94 | field = field[1:] 95 | 96 | return op(getattr(model, field), term) 97 | -------------------------------------------------------------------------------- /tests/test_field_mapping.py: 
-------------------------------------------------------------------------------- 1 | import sqlalchemy as sqa 2 | from sqlalchemy.dialects import postgresql 3 | 4 | from django.test import SimpleTestCase 5 | 6 | from rest_framework import fields 7 | 8 | from rest_enumfield import EnumField 9 | 10 | from rest_witchcraft import field_mapping 11 | from rest_witchcraft.fields import CharMappingField 12 | 13 | from .models import Owner 14 | 15 | 16 | class TestModelViewName(SimpleTestCase): 17 | def test_get_detail_view_name(self): 18 | name = field_mapping.get_detail_view_name(EnumField) 19 | 20 | self.assertEqual(name, "enumfield-detail") 21 | 22 | 23 | class TestGetUrlKwargs(SimpleTestCase): 24 | def test_get_url_kwargs(self): 25 | kwargs = field_mapping.get_url_kwargs(Owner) 26 | self.assertDictEqual( 27 | {"read_only": True, "view_name": "owner-detail", "lookup_field": "id", "lookup_url_kwarg": "pk"}, kwargs 28 | ) 29 | 30 | 31 | class TestGetFieldType(SimpleTestCase): 32 | def test_get_field_type_can_map_string_column(self): 33 | field = field_mapping.get_field_type(sqa.Column(sqa.String())) 34 | 35 | self.assertTrue(issubclass(field, fields.CharField)) 36 | 37 | def test_get_field_type_can_map_int_column(self): 38 | field = field_mapping.get_field_type(sqa.Column(sqa.BigInteger())) 39 | 40 | self.assertTrue(issubclass(field, fields.IntegerField)) 41 | 42 | def test_get_field_type_can_map_float_column(self): 43 | field = field_mapping.get_field_type(sqa.Column(sqa.Float())) 44 | 45 | self.assertTrue(issubclass(field, fields.FloatField)) 46 | 47 | def test_get_field_type_can_map_decimal_column(self): 48 | column = sqa.Column(sqa.Numeric(asdecimal=True)) 49 | field = field_mapping.get_field_type(column) 50 | 51 | self.assertTrue(issubclass(field, fields.DecimalField)) 52 | 53 | def test_get_field_type_can_map_interval_column(self): 54 | column = sqa.Column(sqa.Interval()) 55 | field = field_mapping.get_field_type(column) 56 | 57 | self.assertTrue(issubclass(field, 
fields.DurationField)) 58 | 59 | def test_get_field_type_can_map_time_column(self): 60 | column = sqa.Column(sqa.Time()) 61 | field = field_mapping.get_field_type(column) 62 | 63 | self.assertTrue(issubclass(field, fields.TimeField)) 64 | 65 | def test_get_field_type_can_map_datetime_column(self): 66 | column = sqa.Column(sqa.DateTime()) 67 | field = field_mapping.get_field_type(column) 68 | 69 | self.assertTrue(issubclass(field, fields.DateTimeField)) 70 | 71 | def test_get_field_type_can_map_date_column(self): 72 | column = sqa.Column(sqa.Date()) 73 | field = field_mapping.get_field_type(column) 74 | 75 | self.assertTrue(issubclass(field, fields.DateField)) 76 | 77 | def test_get_field_type_can_map_bool_column(self): 78 | column = sqa.Column(sqa.Boolean(), nullable=False) 79 | field = field_mapping.get_field_type(column) 80 | 81 | self.assertTrue(issubclass(field, fields.BooleanField)) 82 | 83 | def test_get_field_type_can_map_nullable_bool_column(self): 84 | column = sqa.Column(sqa.Boolean(), nullable=True) 85 | field = field_mapping.get_field_type(column) 86 | 87 | if hasattr(fields, "NullBooleanField"): 88 | self.assertTrue(issubclass(field, fields.NullBooleanField)) 89 | else: 90 | self.assertTrue(issubclass(field, fields.BooleanField)) 91 | 92 | def test_get_field_type_can_map_pg_hstore_column(self): 93 | column = sqa.Column(postgresql.HSTORE()) 94 | field = field_mapping.get_field_type(column) 95 | 96 | self.assertTrue(issubclass(field, CharMappingField)) 97 | 98 | def test_get_field_type_can_map_pg_array_column(self): 99 | column = sqa.Column(postgresql.ARRAY(item_type=sqa.Integer())) 100 | field = field_mapping.get_field_type(column) 101 | 102 | self.assertTrue(issubclass(field, fields.ListField)) 103 | self.assertIsInstance(field().child, fields.IntegerField) 104 | 105 | def test_get_field_type_pg_array_column_raises_when_item_type_not_found(self): 106 | class DummyType: 107 | python_type = None 108 | 109 | column = 
sqa.Column(postgresql.ARRAY(item_type=DummyType)) 110 | 111 | with self.assertRaises(KeyError): 112 | field_mapping.get_field_type(column) 113 | -------------------------------------------------------------------------------- /tests/models.py: -------------------------------------------------------------------------------- 1 | import enum 2 | 3 | from sqlalchemy import Column, ForeignKey, Sequence, orm, types 4 | 5 | from django.conf import settings 6 | from django.core.exceptions import ValidationError 7 | 8 | from django_sorcery.db import SQLAlchemy 9 | from django_sorcery.db.models import autocoerce 10 | 11 | 12 | session = SQLAlchemy(settings.DB_URL) 13 | 14 | Base = session.Model 15 | 16 | COLORS = ["red", "green", "blue", "silver"] 17 | 18 | 19 | class Owner(Base): 20 | __tablename__ = "owners" 21 | 22 | id = Column(types.Integer(), primary_key=True) 23 | first_name = Column(types.Unicode(length=50)) 24 | last_name = Column(types.Unicode(length=50)) 25 | 26 | 27 | class VehicleType(enum.Enum): 28 | bus = "Bus" 29 | car = "Car" 30 | 31 | 32 | class Engine: 33 | def __init__(self, cylinders, displacement, type_, fuel_type): 34 | self.cylinders = cylinders 35 | self.displacement = displacement 36 | self.type_ = type_ 37 | self.fuel_type = fuel_type 38 | 39 | def __composite_values__(self): 40 | return self.cylinders, self.displacement, self.type_, self.fuel_type 41 | 42 | def __repr__(self): 43 | return 'Engine(cylinder={},displacement={},type="{}",fuel_type="{}")'.format(*self.__composite_values__()) 44 | 45 | def __eq__(self, other): 46 | return isinstance(other, Engine) and other.__composite_values__() == self.__composite_values__() 47 | 48 | 49 | @autocoerce 50 | class Vehicle(Base): 51 | __tablename__ = "vehicles" 52 | 53 | id = Column(types.Integer(), Sequence("seq_id"), primary_key=True, doc="The primary key") 54 | name = Column(types.String(length=50), doc="The name of the vehicle") 55 | type = Column(types.Enum(VehicleType, name="vehicle_type"), 
nullable=False) 56 | created_at = Column(types.DateTime()) 57 | paint = Column(types.Enum(*COLORS, name="colors")) 58 | is_used = Column(types.Boolean) 59 | 60 | @property 61 | def lower_name(self): 62 | return self.name.lower() 63 | 64 | _engine_cylinders = Column("engine_cylinders", types.BigInteger()) 65 | _engine_displacement = Column("engine_displacement", types.Numeric(asdecimal=True, precision=10, scale=2)) 66 | _engine_type = Column("engine_type", types.String(length=25)) 67 | _engine_fuel_type = Column("engine_fuel_type", types.String(length=10)) 68 | engine = orm.composite(Engine, _engine_cylinders, _engine_displacement, _engine_type, _engine_fuel_type) 69 | 70 | _owner_id = Column("owner_id", types.Integer(), ForeignKey(Owner.id)) 71 | owner = orm.relationship(Owner, backref="vehicles") 72 | 73 | def clean_name(self): 74 | if self.name == "invalid": 75 | raise ValidationError("invalid vehicle name") 76 | 77 | 78 | class VehicleOther(Base): 79 | __tablename__ = "vehicle_other" 80 | 81 | id = Column(types.Integer(), primary_key=True, doc="The primary key") 82 | 83 | advertising_cost = Column(types.BigInteger()) 84 | base_invoice = Column(types.BigInteger()) 85 | base_msrp = Column(types.BigInteger()) 86 | destination_charge = Column(types.BigInteger()) 87 | gas_guzzler_tax = Column(types.BigInteger()) 88 | list_price = Column(types.BigInteger()) 89 | misc_cost = Column(types.BigInteger()) 90 | options_invoice = Column(types.BigInteger()) 91 | options_msrp = Column(types.BigInteger()) 92 | package_discount = Column(types.BigInteger()) 93 | prep_cost = Column(types.BigInteger()) 94 | total_msrp = Column(types.BigInteger()) 95 | vehicle_invoice = Column(types.BigInteger()) 96 | vehicle_msrp = Column(types.BigInteger()) 97 | 98 | _vehicle_id = Column(types.Integer(), ForeignKey(Vehicle.id)) 99 | vehicle = orm.relationship(Vehicle, backref=orm.backref("other", uselist=False), uselist=False) 100 | 101 | 102 | class Option(Base): 103 | __tablename__ = "options" 
104 | id = Column(types.Integer(), primary_key=True) 105 | name = Column(types.String(length=50)) 106 | 107 | _vehicle_id = Column(types.Integer(), ForeignKey(Vehicle.id)) 108 | vehicle = orm.relationship(Vehicle, backref="options") 109 | 110 | 111 | class ModelWithJson(Base): 112 | __tablename__ = "model_with_json" 113 | 114 | id = Column(types.Integer(), Sequence("seq_id"), primary_key=True) 115 | 116 | 117 | session.create_all() 118 | 119 | # getting around sqlite not supporting json column 120 | ModelWithJson.js = Column(types.JSON()) 121 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | toxworkdir={env:TOX_WORK_DIR:.tox} 3 | skipsdist = true 4 | envlist = 5 | lint 6 | pypy36-{drf38,drf39,drf310}-{sqla12,sqla13,sqla14}-{dj11,dj20,dj21,dj22}- 7 | 8 | py36-{sqla12,sqla13,sqla14}-{ drf38}-{dj11,dj20,dj21,dj22 } 9 | py36-{sqla12,sqla13,sqla14}-{ drf39}-{dj11,dj20,dj21,dj22 } 10 | py36-{sqla12,sqla13,sqla14}-{drf310}-{dj11,dj20,dj21,dj22 } 11 | 12 | py37-{sqla12,sqla13,sqla14}-{ drf38}-{ dj20,dj21,dj22 } 13 | py37-{sqla12,sqla13,sqla14}-{ drf39}-{ dj20,dj21,dj22 } 14 | py37-{sqla12,sqla13,sqla14}-{drf310}-{ dj20,dj21,dj22,dj30 } 15 | py37-{sqla12,sqla13,sqla14}-{drf311}-{ dj20,dj21,dj22,dj30,dj31,dj32 } 16 | py37-{sqla12,sqla13,sqla14}-{drf312}-{ dj22,dj30,dj31,dj32 } 17 | py37-{sqla12,sqla13,sqla14}-{drf313}-{ dj22,dj30,dj31,dj32 } 18 | 19 | py38-{sqla12,sqla13,sqla14}-{ drf38}-{ dj20,dj21,dj22 } 20 | py38-{sqla12,sqla13,sqla14}-{ drf39}-{ dj20,dj21,dj22 } 21 | py38-{sqla12,sqla13,sqla14}-{drf310}-{ dj20,dj21,dj22,dj30 } 22 | py38-{sqla12,sqla13,sqla14}-{drf311}-{ dj20,dj21,dj22,dj30,dj31,dj32 } 23 | py38-{sqla12,sqla13,sqla14}-{drf312}-{ dj22,dj30,dj31,dj32,dj40,dj41} 24 | py38-{sqla12,sqla13,sqla14}-{drf313}-{ dj22,dj30,dj31,dj32,dj40,dj41} 25 | py38-{sqla12,sqla13,sqla14}-{drf314}-{ dj30,dj31,dj32,dj40,dj41} 26 | 27 | 
py39-{sqla12,sqla13,sqla14}-{ drf38}-{ dj20,dj21,dj22 } 28 | py39-{sqla12,sqla13,sqla14}-{ drf39}-{ dj20,dj21,dj22 } 29 | py39-{sqla12,sqla13,sqla14}-{drf310}-{ dj20,dj21,dj22,dj30 } 30 | py39-{sqla12,sqla13,sqla14}-{drf311}-{ dj20,dj21,dj22,dj30,dj31,dj32 } 31 | py39-{sqla12,sqla13,sqla14}-{drf312}-{ dj22,dj30,dj31,dj32,dj40,dj41} 32 | py39-{sqla12,sqla13,sqla14}-{drf313}-{ dj22,dj30,dj31,dj32,dj40,dj41} 33 | py39-{sqla12,sqla13,sqla14}-{drf314}-{ dj30,dj31,dj32,dj40,dj41} 34 | 35 | py310-{sqla12,sqla13,sqla14}-{ drf39}-{ dj21,dj22 } 36 | py310-{sqla12,sqla13,sqla14}-{drf310}-{ dj21,dj22,dj30 } 37 | py310-{sqla12,sqla13,sqla14}-{drf311}-{ dj21,dj22,dj30,dj31,dj32 } 38 | py310-{sqla12,sqla13,sqla14}-{drf312}-{ dj22,dj30,dj31,dj32,dj40,dj41} 39 | py310-{sqla12,sqla13,sqla14}-{drf313}-{ dj22,dj30,dj31,dj32,dj40,dj41} 40 | py310-{sqla12,sqla13,sqla14}-{drf314}-{ dj30,dj31,dj32,dj40,dj41} 41 | 42 | py311-{sqla12,sqla13,sqla14}-{ drf39}-{ dj22 } 43 | py311-{sqla12,sqla13,sqla14}-{drf310}-{ dj22,dj30 } 44 | py311-{sqla12,sqla13,sqla14}-{drf311}-{ dj22,dj30,dj31,dj32 } 45 | py311-{sqla12,sqla13,sqla14}-{drf312}-{ dj22,dj30,dj31,dj32,dj40,dj41} 46 | py311-{sqla12,sqla13,sqla14}-{drf313}-{ dj22,dj30,dj31,dj32,dj40,dj41} 47 | py311-{sqla12,sqla13,sqla14}-{drf314}-{ dj30,dj31,dj32,dj40,dj41} 48 | 49 | [testenv] 50 | passenv = 51 | LC_ALL 52 | LANG 53 | HOME 54 | DJANGO_SETTINGS_MODULE 55 | PATH 56 | LDFLAGS 57 | CPPFLAGS 58 | DATABASE_URL 59 | basepython = 60 | py36: python3.6 61 | py37: python3.7 62 | py38: python3.8 63 | py39: python3.9 64 | py310: python3.10 65 | py311: python3.11 66 | pypy36: pypy3.6 67 | deps = 68 | -rrequirements.txt 69 | sqla12: sqlalchemy==1.2.* 70 | sqla13: sqlalchemy==1.3.* 71 | sqla14: sqlalchemy==1.4.* 72 | dj11: django==1.11.* 73 | dj20: django==2.0.* 74 | dj21: django==2.1.* 75 | dj22: django==2.2.* 76 | dj30: django==3.0.* 77 | dj31: django==3.1.* 78 | dj32: django==3.2.* 79 | dj40: django==4.0.* 80 | dj41: django==4.1.* 81 | drf38: 
djangorestframework==3.8.* 82 | drf39: djangorestframework==3.9.* 83 | drf310: djangorestframework==3.10.* 84 | drf311: djangorestframework==3.11.* 85 | drf312: djangorestframework==3.12.* 86 | drf313: djangorestframework==3.13.* 87 | drf314: djangorestframework==3.14.* 88 | whitelist_externals = 89 | make 90 | commands = 91 | pip freeze 92 | make resetdb 93 | make coverage 94 | 95 | [testenv:lint] 96 | basepython = python3.11 97 | commands = 98 | make lint 99 | -------------------------------------------------------------------------------- /rest_witchcraft/mixins.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from itertools import chain 3 | 4 | import six 5 | 6 | from sqlalchemy import orm 7 | 8 | from django.db.models.constants import LOOKUP_SEP 9 | 10 | from django_sorcery.db import meta 11 | 12 | from rest_framework import mixins 13 | 14 | 15 | ToLoadField = collections.namedtuple("ToLoadField", ["field", "direction"]) 16 | 17 | 18 | class DestroyModelMixin(mixins.DestroyModelMixin): 19 | """Deletes a model instance.""" 20 | 21 | def perform_destroy(self, instance): 22 | session = self.get_session() 23 | session.delete(instance) 24 | 25 | 26 | class QuerySerializerMixin: 27 | """Adds query serializer validation logic to viewset. 28 | 29 | Query will be validated as part of query viewset initialization 30 | therefore query will be validated before any of the viewset actions 31 | are executed. 32 | 33 | In addition query serializer will be included in serializer context 34 | for standard viewset serializers. 
That 35 | """ 36 | 37 | query_serializer_class = None 38 | 39 | @property 40 | def query_serializer(self): 41 | return getattr(self, "_query_serializer", None) 42 | 43 | @query_serializer.setter 44 | def query_serializer(self, value): 45 | self._query_serializer = value 46 | 47 | def get_query_serializer_class(self): 48 | return ( 49 | self.query_serializer_class 50 | or getattr(self.get_serializer_class()(), "get_query_serializer_class", lambda: None)() 51 | ) 52 | 53 | def get_query_serializer_context(self): 54 | return self.get_serializer_context() 55 | 56 | def get_query_serializer(self, *args, **kwargs): 57 | serializer_class = kwargs.pop("serializer_class", None) or self.get_query_serializer_class() 58 | if serializer_class is None: 59 | return 60 | kwargs.setdefault("context", self.get_query_serializer_context()) 61 | kwargs.setdefault("data", dict(self.request.GET.lists())) 62 | self.query_serializer = serializer = serializer_class(*args, **kwargs) 63 | serializer.is_valid() 64 | return serializer 65 | 66 | def check_query(self): 67 | serializer = self.get_query_serializer() 68 | if serializer is not None: 69 | serializer.is_valid(raise_exception=True) 70 | 71 | def initial(self, request, *args, **kwargs): 72 | super().initial(request, *args, **kwargs) 73 | self.check_query() 74 | 75 | 76 | class ExpandableQuerySerializerMixin(QuerySerializerMixin): 77 | """Adds expandable query serializer validation logic to viewset as well as 78 | automatic eager load of expanded fields on the serializer. 79 | 80 | The query serializer is expected to be generated by 81 | :py:meth:`rest_witchcraft.serializers.ExpandableModelSerializer.get_query_serializer_class`. 
82 | """ 83 | 84 | def get_queryset(self): 85 | queryset = super().get_queryset() 86 | 87 | serializer = self.query_serializer 88 | if serializer is None: 89 | return queryset 90 | 91 | return self.expand_queryset(queryset, chain(*serializer.validated_data.values())) 92 | 93 | def expand_queryset(self, queryset, values): 94 | to_expand = [] 95 | 96 | for value in values: 97 | to_load = [] 98 | components = value.split(LOOKUP_SEP) 99 | 100 | model = queryset._only_full_mapper_zero("get").class_ 101 | for c in components: 102 | props = meta.model_info(model).relationships 103 | try: 104 | field = getattr(model, c) 105 | prop = props[c] 106 | model = prop.relationship._dependency_processor.mapper.class_ 107 | except (KeyError, AttributeError): 108 | to_load = [] 109 | break 110 | else: 111 | to_load.append(ToLoadField(field, prop.direction)) 112 | 113 | if to_load: 114 | to_expand.append(to_load) 115 | 116 | if to_expand: 117 | queryset = queryset.options( 118 | *[ 119 | six.moves.reduce( 120 | lambda a, b: ( 121 | a.selectinload(b.field) 122 | if b.direction in {orm.interfaces.ONETOMANY, orm.interfaces.MANYTOMANY} 123 | else a.joinedload(b.field) 124 | ), 125 | expand, 126 | orm, 127 | ) 128 | for expand in to_expand 129 | ] 130 | ) 131 | 132 | return queryset 133 | 134 | def get_serializer_context(self): 135 | context = super().get_serializer_context() 136 | context["query_serializer"] = self.query_serializer 137 | return context 138 | -------------------------------------------------------------------------------- /tests/test_mixins.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy.orm import joinedload 2 | 3 | from django.http import QueryDict 4 | from django.test import SimpleTestCase 5 | 6 | from rest_framework.fields import CharField, IntegerField 7 | from rest_framework.response import Response 8 | from rest_framework.serializers import Serializer 9 | from rest_framework.test import APIRequestFactory 10 | 
from rest_framework.viewsets import GenericViewSet, ModelViewSet 11 | 12 | from rest_witchcraft.fields import SkippableField 13 | from rest_witchcraft.mixins import ExpandableQuerySerializerMixin 14 | from rest_witchcraft.serializers import ExpandableModelSerializer, ModelSerializer 15 | 16 | from .models import Engine, Option, Owner, Vehicle, VehicleType, session 17 | from .test_routers import UnAuthMixin 18 | 19 | 20 | class DummySerializer(ModelSerializer): 21 | class Meta: 22 | session = session 23 | model = Vehicle 24 | fields = "__all__" 25 | 26 | 27 | class VehicleOwnerStubSerializer(Serializer): 28 | id = IntegerField(source="_owner_id") 29 | 30 | 31 | class VehicleSerializer(ExpandableModelSerializer): 32 | name = CharField() 33 | 34 | class Meta: 35 | session = session 36 | model = Vehicle 37 | expandable_fields = { 38 | "owner": VehicleOwnerStubSerializer(source="*", read_only=True), 39 | "options": SkippableField(), 40 | "name": CharField(), 41 | } 42 | fields = "__all__" 43 | 44 | 45 | class DummyViewSet(ExpandableQuerySerializerMixin, GenericViewSet): 46 | serializer_class = DummySerializer 47 | 48 | def list(self, request, *args, **kwargs): 49 | return Response() 50 | 51 | 52 | class ExpandableViewSet(UnAuthMixin, ExpandableQuerySerializerMixin, ModelViewSet): 53 | serializer_class = VehicleSerializer 54 | queryset = Vehicle.objects 55 | 56 | def list(self, request, *args, **kwargs): 57 | r = super().list(request, *args, **kwargs) 58 | r.data = {"query": str(self.get_queryset()), "results": r.data} 59 | return r 60 | 61 | 62 | class TestDummyViewSet(SimpleTestCase): 63 | def test_no_queryset(self): 64 | self.assertIsNone(DummyViewSet().get_query_serializer()) 65 | 66 | 67 | class TestQuerySerializerMixin(SimpleTestCase): 68 | def setUp(self): 69 | super().setUp() 70 | self.rf = APIRequestFactory() 71 | session.add( 72 | Vehicle( 73 | name="Test vehicle", 74 | type=VehicleType.car, 75 | engine=Engine(4, 1234, None, None), 76 | owner=Owner(id=1, 
first_name="Test", last_name="Owner"), 77 | options=[Option(name="Navigation"), Option(name="Rocket Engine")], 78 | ) 79 | ) 80 | session.flush() 81 | self.maxDiff = None 82 | 83 | def tearDown(self): 84 | super().tearDown() 85 | session.rollback() 86 | 87 | def test_invalid_query(self): 88 | view = ExpandableViewSet.as_view(actions={"get": "list"}) 89 | 90 | self.assertEqual(view(self.rf.get("/")).status_code, 200) 91 | self.assertEqual(view(self.rf.get("/", {"expand": "owner"})).status_code, 200) 92 | self.assertEqual(view(self.rf.get("/", {"expand": "name"})).status_code, 200) 93 | self.assertEqual(view(self.rf.get("/", {"expand": "haha"})).status_code, 400) 94 | 95 | def test_no_query_serializer(self): 96 | view = ExpandableViewSet.as_view(actions={"get": "list"}, serializer_class=DummySerializer) 97 | 98 | self.assertEqual(view(self.rf.get("/")).status_code, 200) 99 | 100 | def test_eagerload_sql(self): 101 | view = ExpandableViewSet.as_view(actions={"get": "list"}) 102 | 103 | self.assertNotIn("LEFT OUTER JOIN", view(self.rf.get("/")).data["query"]) 104 | self.assertNotIn("LEFT OUTER JOIN", view(self.rf.get("/", {"expand": "name"})).data["query"]) 105 | 106 | r = view(self.rf.get("/", {"expand": "owner"})) 107 | self.assertIn("LEFT OUTER JOIN", r.data["query"]) 108 | self.assertEqual(r.data["query"].count("LEFT OUTER JOIN"), 1) 109 | self.assertNotIn("options", r.data["results"][0]) 110 | 111 | # one to many should not add any more joins since selectinload is used 112 | r = view(self.rf.get("/", QueryDict("expand=owner&expand=options"))) 113 | self.assertIn("LEFT OUTER JOIN", r.data["query"]) 114 | self.assertEqual(r.data["query"].count("LEFT OUTER JOIN"), 1) 115 | self.assertIn("options", r.data["results"][0]) 116 | 117 | def test_already_eagerload(self): 118 | view = ExpandableViewSet.as_view( 119 | actions={"get": "list"}, queryset=Vehicle.objects.options(joinedload(Vehicle.owner)) 120 | ) 121 | 122 | r = view(self.rf.get("/", {"expand": "owner"})) 123 | 
self.assertIn("LEFT OUTER JOIN", r.data["query"]) 124 | # even if we add more joinedloads sqlalchemy should normalize them 125 | # as that exact path is already joined in base queryset 126 | self.assertEqual(r.data["query"].count("LEFT OUTER JOIN"), 1) 127 | -------------------------------------------------------------------------------- /tests/test_filters.py: -------------------------------------------------------------------------------- 1 | import coreschema 2 | 3 | from django.test import RequestFactory 4 | 5 | from rest_framework.settings import api_settings 6 | from rest_framework.test import APISimpleTestCase 7 | 8 | from rest_witchcraft.filters import SearchFilter 9 | from rest_witchcraft.serializers import ModelSerializer 10 | from rest_witchcraft.viewsets import ModelViewSet 11 | 12 | from .models import Owner, session 13 | 14 | 15 | class OwnerSerializer(ModelSerializer): 16 | class Meta: 17 | model = Owner 18 | session = session 19 | fields = "first_name", "last_name" 20 | 21 | 22 | class OwnerViewSet(ModelViewSet): 23 | serializer_class = OwnerSerializer 24 | session = session 25 | queryset = property(lambda self: session.query(Owner)) 26 | 27 | 28 | class TestSearchFilters(APISimpleTestCase): 29 | factory = RequestFactory() 30 | 31 | def setUp(self): 32 | super().setUp() 33 | self.viewset_class = OwnerViewSet 34 | self.filter = SearchFilter() 35 | 36 | self.owner1 = Owner(id=1, first_name="Joe", last_name="Smith") 37 | self.owner2 = Owner(id=2, first_name="Jon", last_name="Snow") 38 | 39 | self.owners = [self.owner1, self.owner2] 40 | session.add_all(self.owners) 41 | session.flush() 42 | 43 | def tearDown(self): 44 | session.rollback() 45 | super().tearDown() 46 | 47 | def test_search_field(self): 48 | viewset = self.viewset_class() 49 | viewset.action_map = {"get": "list"} 50 | 51 | schema = self.filter.get_schema_fields(viewset) 52 | 53 | self.assertEqual(len(schema), 1) 54 | field = schema[0] 55 | 56 | self.assertEqual(field.name, 
api_settings.SEARCH_PARAM) 57 | self.assertFalse(field.required) 58 | self.assertEqual(field.location, "query") 59 | self.assertIsInstance(field.schema, coreschema.String) 60 | 61 | schema_ops = self.filter.get_schema_operation_parameters(viewset) 62 | self.assertDictEqual( 63 | {"test": schema_ops}, 64 | { 65 | "test": [ 66 | { 67 | "name": "search", 68 | "required": False, 69 | "in": "query", 70 | "description": "A search term.", 71 | "schema": {"type": "string"}, 72 | } 73 | ] 74 | }, 75 | ) 76 | 77 | def test_to_html(self): 78 | viewset = self.viewset_class() 79 | viewset.action_map = {"get": "list"} 80 | request = viewset.initialize_request(self.factory.get("/", {api_settings.SEARCH_PARAM: "jo"})) 81 | 82 | html = self.filter.to_html(request, viewset.get_queryset(), viewset) 83 | 84 | self.assertInHTML("

Search

", html) 85 | 86 | del viewset.__class__.search_fields 87 | html = self.filter.to_html(request, viewset.get_queryset(), viewset) 88 | self.assertEqual(html, "") 89 | 90 | def test_icontains(self): 91 | self.viewset_class.search_fields = ["first_name", "last_name"] 92 | 93 | viewset = self.viewset_class() 94 | viewset.action_map = {"get": "list"} 95 | request = viewset.initialize_request(self.factory.get("/")) 96 | 97 | query = self.filter.filter_queryset(request, viewset.get_queryset(), viewset) 98 | 99 | self.assertEqual(set(query.all()), set(self.owners)) 100 | 101 | request = viewset.initialize_request(self.factory.get("/", {api_settings.SEARCH_PARAM: "jon"})) 102 | query = self.filter.filter_queryset(request, viewset.get_queryset(), viewset) 103 | 104 | self.assertEqual(set(query.all()), {self.owner2}) 105 | 106 | def test_istartswith(self): 107 | self.viewset_class.search_fields = ["^first_name", "^last_name"] 108 | viewset = self.viewset_class() 109 | viewset.action_map = {"get": "list"} 110 | 111 | request = viewset.initialize_request(self.factory.get("/", {api_settings.SEARCH_PARAM: "jo"})) 112 | query = self.filter.filter_queryset(request, viewset.get_queryset(), viewset) 113 | self.assertEqual(set(query.all()), set(self.owners)) 114 | 115 | request = viewset.initialize_request(self.factory.get("/", {api_settings.SEARCH_PARAM: "sno"})) 116 | query = self.filter.filter_queryset(request, viewset.get_queryset(), viewset) 117 | self.assertEqual(set(query.all()), {self.owner2}) 118 | 119 | def test_iexact(self): 120 | self.viewset_class.search_fields = ["=first_name", "=last_name"] 121 | viewset = self.viewset_class() 122 | viewset.action_map = {"get": "list"} 123 | 124 | request = viewset.initialize_request(self.factory.get("/", {api_settings.SEARCH_PARAM: "sno"})) 125 | query = self.filter.filter_queryset(request, viewset.get_queryset(), viewset) 126 | self.assertEqual(set(query.all()), set()) 127 | 128 | request = 
viewset.initialize_request(self.factory.get("/", {api_settings.SEARCH_PARAM: "snow"})) 129 | query = self.filter.filter_queryset(request, viewset.get_queryset(), viewset) 130 | self.assertEqual(set(query.all()), {self.owner2}) 131 | 132 | def test_exact(self): 133 | self.viewset_class.search_fields = ["@first_name", "@last_name"] 134 | viewset = self.viewset_class() 135 | viewset.action_map = {"get": "list"} 136 | 137 | request = viewset.initialize_request(self.factory.get("/", {api_settings.SEARCH_PARAM: "snow"})) 138 | query = self.filter.filter_queryset(request, viewset.get_queryset(), viewset) 139 | self.assertEqual(set(query.all()), set()) 140 | 141 | request = viewset.initialize_request(self.factory.get("/", {api_settings.SEARCH_PARAM: "Snow"})) 142 | query = self.filter.filter_queryset(request, viewset.get_queryset(), viewset) 143 | self.assertEqual(set(query.all()), {self.owner2}) 144 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Django REST Witchcraft documentation build configuration file, created by 4 | # sphinx-quickstart on Sat Jun 10 09:20:17 2017. 5 | # 6 | # This file is execfile()d with the current directory set to its 7 | # containing dir. 8 | # 9 | # Note that not all possible configuration values are present in this 10 | # autogenerated file. 11 | # 12 | # All configuration values have a default; values that are commented out 13 | # serve to show the default. 14 | 15 | import os 16 | import sys 17 | 18 | import sphinx_rtd_theme 19 | 20 | import django 21 | import django.test.utils 22 | from django.conf import settings 23 | 24 | 25 | # If extensions (or modules to document with autodoc) are in another directory, 26 | # add these directories to sys.path here. If the directory is relative to the 27 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
28 | # 29 | 30 | 31 | sys.path.insert(0, os.path.abspath("..")) 32 | sys.path.insert(0, os.path.abspath(".")) 33 | 34 | settings.configure() 35 | 36 | getattr(django, "setup", bool)() 37 | django.test.utils.setup_test_environment() 38 | 39 | here = os.path.abspath(os.path.dirname(__file__)) 40 | 41 | about = {} 42 | with open(os.path.join(here, "..", "rest_witchcraft", "__version__.py")) as f: 43 | exec(f.read(), about) # yapf: disable 44 | 45 | # -- General configuration ------------------------------------------------ 46 | 47 | # If your documentation needs a minimal Sphinx version, state it here. 48 | # 49 | # needs_sphinx = '1.0' 50 | 51 | # Add any Sphinx extension module names here, as strings. They can be 52 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 53 | # ones. 54 | extensions = ["sphinx.ext.autodoc", "sphinx.ext.viewcode", "sphinx.ext.napoleon", "sphinx.ext.coverage"] 55 | 56 | # Add any paths that contain templates here, relative to this directory. 57 | templates_path = ["_templates"] 58 | 59 | # The suffix(es) of source filenames. 60 | # You can specify multiple suffix as a list of string: 61 | # 62 | # source_suffix = ['.rst', '.md'] 63 | source_suffix = ".rst" 64 | 65 | # The master toctree document. 66 | master_doc = "index" 67 | 68 | # General information about the project. 69 | project = "Django REST Witchcraft" 70 | copyright = "2017, Serkan Hosca" 71 | author = "Serkan Hosca" 72 | 73 | # The version info for the project you're documenting, acts as replacement for 74 | # |version| and |release|, also used in various other places throughout the 75 | # built documents. 76 | # 77 | # The short X.Y version. 78 | version = about["__version__"] 79 | # The full version, including alpha/beta/rc tags. 80 | release = about["__version__"] 81 | 82 | # The language for content autogenerated by Sphinx. Refer to documentation 83 | # for a list of supported languages. 
84 | # 85 | # This is also used if you do content translation via gettext catalogs. 86 | # Usually you set "language" from the command line for these cases. 87 | language = "en" 88 | 89 | # List of patterns, relative to source directory, that match files and 90 | # directories to ignore when looking for source files. 91 | # This patterns also effect to html_static_path and html_extra_path 92 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 93 | 94 | # The name of the Pygments (syntax highlighting) style to use. 95 | pygments_style = "sphinx" 96 | 97 | # If true, `todo` and `todoList` produce output, else they produce nothing. 98 | todo_include_todos = False 99 | 100 | # -- Options for HTML output ---------------------------------------------- 101 | 102 | # The theme to use for HTML and HTML Help pages. See the documentation for 103 | # a list of builtin themes. 104 | # 105 | html_theme = "sphinx_rtd_theme" 106 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 107 | 108 | # Theme options are theme-specific and customize the look and feel of a theme 109 | # further. For a list of options available for each theme, see the 110 | # documentation. 111 | # 112 | # html_theme_options = {} 113 | 114 | # Add any paths that contain custom static files (such as style sheets) here, 115 | # relative to this directory. They are copied after the builtin static files, 116 | # so a file named "default.css" will overwrite the builtin "default.css". 117 | html_static_path = ["_static"] 118 | 119 | # -- Options for HTMLHelp output ------------------------------------------ 120 | 121 | # Output file base name for HTML help builder. 122 | htmlhelp_basename = "DjangoRESTWitchcraftdoc" 123 | 124 | # -- Options for LaTeX output --------------------------------------------- 125 | 126 | latex_elements = { 127 | # The paper size ('letterpaper' or 'a4paper'). 128 | # 129 | # 'papersize': 'letterpaper', 130 | # The font size ('10pt', '11pt' or '12pt'). 
131 | # 132 | # 'pointsize': '10pt', 133 | # Additional stuff for the LaTeX preamble. 134 | # 135 | # 'preamble': '', 136 | # Latex figure (float) alignment 137 | # 138 | # 'figure_align': 'htbp', 139 | } 140 | 141 | # Grouping the document tree into LaTeX files. List of tuples 142 | # (source start file, target name, title, 143 | # author, documentclass [howto, manual, or own class]). 144 | latex_documents = [ 145 | (master_doc, "DjangoRESTWitchcraft.tex", "Django REST Witchcraft Documentation", "Serkan Hosca", "manual") 146 | ] 147 | 148 | # -- Options for manual page output --------------------------------------- 149 | 150 | # One entry per manual page. List of tuples 151 | # (source start file, name, description, authors, manual section). 152 | man_pages = [(master_doc, "djangorestwitchcraft", "Django REST Witchcraft Documentation", [author], 1)] 153 | 154 | # -- Options for Texinfo output ------------------------------------------- 155 | 156 | # Grouping the document tree into Texinfo files. List of tuples 157 | # (source start file, target name, title, author, 158 | # dir menu entry, description, category) 159 | texinfo_documents = [ 160 | ( 161 | master_doc, 162 | "DjangoRESTWitchcraft", 163 | "Django REST Witchcraft Documentation", 164 | author, 165 | "DjangoRESTWitchcraft", 166 | "One line description of project.", 167 | "Miscellaneous", 168 | ) 169 | ] 170 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | Django REST Witchcraft 2 | ====================== 3 | 4 | |Build Status| |Read The Docs| |PyPI version| |Coveralls Status| |Black| 5 | 6 | **Django REST Framework integration with SQLAlchemy** 7 | 8 | django-rest-witchcraft is an extension for Django REST Framework that adds support for SQLAlchemy. 
It aims to provide 9 | a development experience similar to building REST APIs with Django REST Framework and the Django ORM, except with 10 | SQLAlchemy. 11 | 12 | Installation 13 | ============ 14 | 15 | :: 16 | 17 | pip install django-rest-witchcraft 18 | 19 | Quick Start 20 | =========== 21 | 22 | First up, let's define some simple models: 23 | 24 | .. code:: python 25 | 26 | import sqlalchemy as sa 27 | import sqlalchemy.orm # noqa 28 | from sqlalchemy.ext.declarative import declarative_base 29 | 30 | engine = sa.create_engine('sqlite:///:memory:', echo=True) 31 | session = sa.orm.scoped_session(sa.orm.sessionmaker(bind=engine)) 32 | 33 | Base = declarative_base() 34 | Base.query = session.query_property() 35 | 36 | 37 | class Group(Base): 38 | __tablename__ = 'groups' 39 | 40 | id = sa.Column(sa.Integer(), primary_key=True, autoincrement=True) 41 | name = sa.Column(sa.String()) 42 | 43 | 44 | class User(Base): 45 | __tablename__ = 'users' 46 | 47 | id = sa.Column(sa.Integer(), primary_key=True, autoincrement=True) 48 | name = sa.Column(sa.String()) 49 | fullname = sa.Column(sa.String()) 50 | password = sa.Column(sa.String()) 51 | 52 | _group_id = sa.Column('group_id', sa.Integer(), sa.ForeignKey('groups.id')) 53 | group = sa.orm.relationship(Group, backref='users') 54 | 55 | 56 | class Address(Base): 57 | __tablename__ = 'addresses' 58 | 59 | id = sa.Column(sa.Integer(), primary_key=True, autoincrement=True) 60 | email_address = sa.Column(sa.String(), nullable=False) 61 | 62 | _user_id = sa.Column(sa.Integer(), sa.ForeignKey('users.id')) 63 | user = sa.orm.relationship(User, backref='addresses') 64 | 65 | Base.metadata.create_all(engine) 66 | 67 | 68 | Nothing fancy here: we have a ``User`` class that can belong to a ``Group`` instance and has many ``Address`` 69 | instances. 70 | 71 | The serializers built below can handle nested create, update or partial update operations. 72 | 73 | Let's define a serializer for ``User`` with all the fields: 74 | 75 | ..
code:: python 76 | 77 | class UserSerializer(serializers.ModelSerializer): 78 | 79 | class Meta: 80 | model = User 81 | session = session 82 | fields = '__all__' 83 | 84 | This will create the following serializer for us: 85 | 86 | :: 87 | 88 | >>> serializer = UserSerializer() 89 | 90 | >>> serializer 91 | UserSerializer(): 92 | id = IntegerField(allow_null=False, help_text=None, label='Id', required=True) 93 | name = CharField(allow_null=True, help_text=None, label='Name', max_length=None, required=False) 94 | fullname = CharField(allow_null=True, help_text=None, label='Fullname', max_length=None, required=False) 95 | password = CharField(allow_null=True, help_text=None, label='Password', max_length=None, required=False) 96 | group = GroupSerializer(allow_null=True, is_nested=True, required=False): 97 | id = IntegerField(allow_null=False, help_text=None, label='Id', required=False) 98 | name = CharField(allow_null=True, help_text=None, label='Name', max_length=None, required=False) 99 | addresses = AddressSerializer(allow_null=True, many=True, required=False): 100 | id = IntegerField(allow_null=False, help_text=None, label='Id', required=False) 101 | email_address = CharField(allow_null=False, help_text=None, label='Email_address', max_length=None, required=True) 102 | url = UriField(read_only=True) 103 | 104 | Let's try to create a ``User`` instance with our brand new serializer: 105 | 106 | .. code:: python 107 | 108 | serializer = UserSerializer(data={ 109 | 'name': 'shosca', 110 | 'password': 'swordfish', 111 | }) 112 | serializer.is_valid() 113 | serializer.save() 114 | 115 | user = serializer.instance 116 | 117 | This will create the following user for us: 118 | 119 | :: 120 | 121 | >>> user 122 | User(_group_id=None, id=1, name='shosca', fullname=None, password='swordfish') 123 | 124 | Let's try to update our ``User`` instance and change its password: 125 | 126 | ..
code:: python 127 | 128 | serializer = UserSerializer(user, data={ 129 | 'name': 'shosca', 130 | 'password': 'password', 131 | }) 132 | serializer.is_valid() 133 | serializer.save() 134 | 135 | user = serializer.instance 136 | 137 | Our user now looks like: 138 | 139 | :: 140 | 141 | >>> user 142 | User(_group_id=None, id=1, name='shosca', fullname=None, password='password') 143 | 144 | Let's try to update our ``User`` instance again, but this time let's change only its password: 145 | 146 | .. code:: python 147 | 148 | serializer = UserSerializer(user, data={ 149 | 'password': 'swordfish', 150 | }, partial=True) 151 | serializer.is_valid() 152 | serializer.save() 153 | 154 | user = serializer.instance 155 | 156 | After the partial update, our user looks like: 157 | 158 | :: 159 | 160 | >>> user 161 | User(_group_id=None, id=1, name='shosca', fullname=None, password='swordfish') 162 | 163 | Our user does not belong to a ``Group``; let's fix that: 164 | 165 | .. code:: python 166 | 167 | group = Group(name='Admin') 168 | session.add(group) 169 | session.flush() 170 | 171 | serializer = UserSerializer(user, data={ 172 | 'group': {'id': group.id} 173 | }) 174 | serializer.is_valid() 175 | serializer.save() 176 | 177 | user = serializer.instance 178 | 179 | Now, our user looks like: 180 | 181 | :: 182 | 183 | >>> user 184 | User(_group_id=1, id=1, name='shosca', fullname=None, password='swordfish') 185 | 186 | >>> user.group 187 | Group(id=1, name='Admin') 188 | 189 | We can also change the name of our user's group through the user using nested updates: 190 | 191 | ..
code:: python 192 | 193 | class UserSerializer(serializers.ModelSerializer): 194 | 195 | class Meta: 196 | model = User 197 | session = session 198 | fields = '__all__' 199 | extra_kwargs = { 200 | 'group': {'allow_nested_updates': True} 201 | } 202 | 203 | serializer = UserSerializer(user, data={ 204 | 'group': {'name': 'Super User'} 205 | }, partial=True) 206 | serializer.is_valid() 207 | 208 | user = serializer.save() 209 | 210 | Now, our user looks like: 211 | 212 | :: 213 | 214 | >>> user 215 | User(_group_id=1, id=1, name='shosca', fullname=None, password='swordfish') 216 | 217 | >>> user.group 218 | Group(id=1, name='Super User') 219 | 220 | We can use this serializer in a viewset like: 221 | 222 | .. code:: python 223 | 224 | from rest_witchcraft import viewsets 225 | 226 | class UserViewSet(viewsets.ModelViewSet): 227 | queryset = User.query 228 | serializer_class = UserSerializer 229 | 230 | And we can register this viewset in our ``urls.py`` like: 231 | 232 | .. code:: python 233 | 234 | from rest_witchcraft import routers 235 | 236 | router = routers.DefaultRouter() 237 | router.register(r'users', UserViewSet) 238 | 239 | urlpatterns = [ 240 | ... 241 | url(r'^', include(router.urls)), 242 | ... 243 | ] 244 | 245 | 246 | .. |Build Status| image:: https://github.com/shosca/django-rest-witchcraft/workflows/Build/badge.svg?branch=master 247 | :target: https://github.com/shosca/django-rest-witchcraft/actions?query=workflow%3ABuild+branch%3Amaster 248 | .. |Read The Docs| image:: https://readthedocs.org/projects/django-rest-witchcraft/badge/?version=latest 249 | :target: http://django-rest-witchcraft.readthedocs.io/en/latest/?badge=latest 250 | .. |PyPI version| image:: https://badge.fury.io/py/django-rest-witchcraft.svg 251 | :target: https://badge.fury.io/py/django-rest-witchcraft 252 | .. 
|Coveralls Status| image:: https://coveralls.io/repos/github/shosca/django-rest-witchcraft/badge.svg?branch=master 253 | :target: https://coveralls.io/github/shosca/django-rest-witchcraft?branch=master 254 | .. |Black| image:: https://img.shields.io/badge/code%20style-black-000000.svg 255 | :target: https://github.com/ambv/black 256 | -------------------------------------------------------------------------------- /tests/test_routers.py: -------------------------------------------------------------------------------- 1 | import simplejson as json 2 | 3 | from django.test import SimpleTestCase, override_settings 4 | 5 | from rest_witchcraft import routers, serializers, viewsets 6 | 7 | from .models_composite import RouterTestCompositeKeyModel, RouterTestModel, session 8 | 9 | 10 | try: 11 | from django.conf.urls import url as re_path 12 | from django.conf.urls import include 13 | except ImportError: # pragma: no cover 14 | from django.urls import re_path, include 15 | 16 | 17 | class RouterTestModelSerializer(serializers.ModelSerializer): 18 | class Meta: 19 | model = RouterTestModel 20 | session = session 21 | fields = "__all__" 22 | 23 | 24 | class RouterTestCompositeKeyModelSerializer(serializers.ModelSerializer): 25 | class Meta: 26 | model = RouterTestCompositeKeyModel 27 | session = session 28 | fields = "__all__" 29 | 30 | 31 | class UnAuthMixin: 32 | def perform_authentication(self, request): 33 | return None 34 | 35 | 36 | class RouterTestViewSet(UnAuthMixin, viewsets.ModelViewSet): 37 | queryset = RouterTestModel.query 38 | serializer_class = RouterTestModelSerializer 39 | lookup_field = "id" 40 | lookup_url_kwarg = "pk" 41 | 42 | 43 | class RouterTestCompositeViewSet(UnAuthMixin, viewsets.ModelViewSet): 44 | queryset = RouterTestCompositeKeyModel.query 45 | serializer_class = RouterTestCompositeKeyModelSerializer 46 | 47 | 48 | class RouterTestCompositeCustomRegexViewSet(UnAuthMixin, viewsets.ModelViewSet): 49 | queryset = 
RouterTestCompositeKeyModel.query 50 | serializer_class = RouterTestCompositeKeyModelSerializer 51 | lookup_url_regex = "(?P[0-9]+)/other/(?P[0-9]+)" 52 | 53 | 54 | router = routers.DefaultRouter() 55 | router.register(r"test", RouterTestViewSet) 56 | router.register(r"testcomposite", RouterTestCompositeViewSet) 57 | router.register(r"testcompositeregex", RouterTestCompositeCustomRegexViewSet) 58 | 59 | urlpatterns = [re_path(r"^", include(router.urls))] 60 | 61 | 62 | @override_settings(ROOT_URLCONF="tests.test_routers") 63 | class TestDummyDummy(SimpleTestCase): 64 | def test_assert_when_no_model_found(self): 65 | class DummyViewSet(UnAuthMixin, viewsets.ModelViewSet): 66 | pass 67 | 68 | dummy_router = routers.DefaultRouter() 69 | 70 | with self.assertRaises(AssertionError): 71 | dummy_router.register(r"dummy", DummyViewSet) 72 | 73 | def test_get_lookup_regex_without_model(self): 74 | class DummyViewSet(UnAuthMixin, viewsets.ModelViewSet): 75 | @classmethod 76 | def get_model(cls): 77 | return None 78 | 79 | dummy_router = routers.DefaultRouter() 80 | 81 | lookup_regex = dummy_router.get_lookup_regex(DummyViewSet) 82 | self.assertEqual(lookup_regex, "(?P[^/.]+)") 83 | 84 | 85 | @override_settings(ROOT_URLCONF="tests.test_routers") 86 | class TestModelRoutes(SimpleTestCase): 87 | maxDiff = None 88 | 89 | def setUp(self): 90 | session.add_all( 91 | [RouterTestModel(id=1, text="router test model 1"), RouterTestModel(id=2, text="router test model 2")] 92 | ) 93 | 94 | def tearDown(self): 95 | session.rollback() 96 | 97 | def test_list(self): 98 | resp = self.client.get("/test/") 99 | 100 | self.assertEqual( 101 | resp.data, [{"id": 1, "text": "router test model 1"}, {"id": 2, "text": "router test model 2"}] 102 | ) 103 | 104 | def test_retrieve(self): 105 | resp = self.client.get("/test/2/") 106 | 107 | self.assertEqual(resp.data, {"id": 2, "text": "router test model 2"}) 108 | 109 | def test_create(self): 110 | data = json.dumps({"text": "router test model 3"}) 
111 | resp = self.client.post("/test/", data=data, content_type="application/json") 112 | 113 | self.assertEqual(resp.data, {"id": 3, "text": "router test model 3"}) 114 | 115 | def test_update(self): 116 | data = {"text": "router test update 2"} 117 | resp = self.client.put("/test/2/", data=json.dumps(data), content_type="application/json") 118 | self.assertEqual(resp.status_code, 200) 119 | self.assertEqual(resp.data, {"id": 2, "text": "router test update 2"}) 120 | 121 | def test_patch_update(self): 122 | data = {"text": "router test update 2"} 123 | resp = self.client.patch("/test/2/", data=json.dumps(data), content_type="application/json") 124 | self.assertEqual(resp.status_code, 200) 125 | self.assertEqual(resp.data, {"id": 2, "text": "router test update 2"}) 126 | 127 | def test_delete(self): 128 | resp = self.client.delete("/test/2/", content_type="application/json") 129 | self.assertEqual(resp.status_code, 204) 130 | 131 | 132 | @override_settings(ROOT_URLCONF="tests.test_routers") 133 | class TestCompositeRoutes(SimpleTestCase): 134 | def setUp(self): 135 | session.add_all( 136 | [ 137 | RouterTestCompositeKeyModel(id=1, other_id=1, text="router composite model 1"), 138 | RouterTestCompositeKeyModel(id=1, other_id=2, text="router composite model 2"), 139 | ] 140 | ) 141 | 142 | def tearDown(self): 143 | session.rollback() 144 | 145 | def test_list(self): 146 | resp = self.client.get("/testcomposite/") 147 | 148 | self.assertEqual( 149 | resp.data, 150 | [ 151 | {"id": 1, "other_id": 1, "text": "router composite model 1"}, 152 | {"id": 1, "other_id": 2, "text": "router composite model 2"}, 153 | ], 154 | ) 155 | 156 | def test_retrieve(self): 157 | resp = self.client.get("/testcomposite/1/2/") 158 | 159 | self.assertEqual(resp.data, {"id": 1, "other_id": 2, "text": "router composite model 2"}) 160 | 161 | def test_create(self): 162 | data = json.dumps({"text": "composite test model 3"}) 163 | resp = self.client.post("/testcomposite/", data=data, 
content_type="application/json") 164 | 165 | self.assertEqual(resp.data, {"id": 1, "other_id": 3, "text": "composite test model 3"}) 166 | 167 | def test_update(self): 168 | data = json.dumps({"text": "router test update 2"}) 169 | resp = self.client.put("/testcomposite/1/2/", data=data, content_type="application/json") 170 | self.assertEqual(resp.status_code, 200) 171 | self.assertEqual(resp.data, {"id": 1, "other_id": 2, "text": "router test update 2"}) 172 | 173 | def test_patch_update(self): 174 | data = json.dumps({"text": "router test update 2"}) 175 | resp = self.client.patch("/testcomposite/1/2/", data=data, content_type="application/json") 176 | self.assertEqual(resp.status_code, 200) 177 | self.assertEqual(resp.data, {"id": 1, "other_id": 2, "text": "router test update 2"}) 178 | 179 | def test_delete(self): 180 | resp = self.client.delete("/testcomposite/1/2/", content_type="application/json") 181 | self.assertEqual(resp.status_code, 204) 182 | 183 | 184 | @override_settings(ROOT_URLCONF="tests.test_routers") 185 | class TestCompositeRoutesWithCustomRegex(SimpleTestCase): 186 | def setUp(self): 187 | session.add_all( 188 | [ 189 | RouterTestCompositeKeyModel(id=1, other_id=1, text="router composite model 1"), 190 | RouterTestCompositeKeyModel(id=1, other_id=2, text="router composite model 2"), 191 | ] 192 | ) 193 | 194 | def tearDown(self): 195 | session.rollback() 196 | 197 | def test_list(self): 198 | resp = self.client.get("/testcompositeregex/") 199 | 200 | self.assertEqual( 201 | resp.data, 202 | [ 203 | {"id": 1, "other_id": 1, "text": "router composite model 1"}, 204 | {"id": 1, "other_id": 2, "text": "router composite model 2"}, 205 | ], 206 | ) 207 | 208 | def test_retrieve(self): 209 | resp = self.client.get("/testcompositeregex/1/other/2/") 210 | 211 | self.assertEqual(resp.data, {"id": 1, "other_id": 2, "text": "router composite model 2"}) 212 | 213 | def test_create(self): 214 | data = json.dumps({"text": "composite test model 3"}) 215 | 
resp = self.client.post("/testcompositeregex/", data=data, content_type="application/json") 216 | 217 | self.assertEqual(resp.data, {"id": 1, "other_id": 3, "text": "composite test model 3"}) 218 | 219 | def test_update(self): 220 | data = json.dumps({"text": "router test update 2"}) 221 | resp = self.client.put("/testcompositeregex/1/other/2/", data=data, content_type="application/json") 222 | self.assertEqual(resp.status_code, 200) 223 | self.assertEqual(resp.data, {"id": 1, "other_id": 2, "text": "router test update 2"}) 224 | 225 | def test_patch_update(self): 226 | data = json.dumps({"text": "router test update 2"}) 227 | resp = self.client.patch("/testcompositeregex/1/other/2/", data=data, content_type="application/json") 228 | self.assertEqual(resp.status_code, 200) 229 | self.assertEqual(resp.data, {"id": 1, "other_id": 2, "text": "router test update 2"}) 230 | 231 | def test_delete(self): 232 | resp = self.client.delete("/testcompositeregex/1/other/2/", content_type="application/json") 233 | self.assertEqual(resp.status_code, 204) 234 | -------------------------------------------------------------------------------- /HISTORY.rst: -------------------------------------------------------------------------------- 1 | Changelog 2 | ========= 3 | 4 | 5 | 0.12.1 (2022-12-21) 6 | ----------------------------- 7 | - Add support for DRF 3.14 and python 3.11 (#89) [Serkan Hosca] 8 | - [pre-commit.ci] pre-commit autoupdate (#87) [pre-commit-ci[bot], pre- 9 | commit-ci[bot]] 10 | - [pre-commit.ci] pre-commit autoupdate (#86) [pre-commit-ci[bot], pre- 11 | commit-ci[bot]] 12 | - Add dj41 to build matrix (#85) [Serkan Hosca] 13 | - [pre-commit.ci] pre-commit autoupdate (#84) [pre-commit-ci[bot], pre- 14 | commit-ci[bot]] 15 | - [pre-commit.ci] pre-commit autoupdate (#83) [pre-commit-ci[bot], pre- 16 | commit-ci[bot]] 17 | - [pre-commit.ci] pre-commit autoupdate (#82) [pre-commit-ci[bot], pre- 18 | commit-ci[bot]] 19 | 20 | 21 | 22 | - [pre-commit.ci] pre-commit 
autoupdate (#81) [pre-commit-ci[bot], pre- 23 | commit-ci[bot]] 24 | - [pre-commit.ci] pre-commit autoupdate (#80) [Serkan Hosca, pre-commit- 25 | ci[bot], pre-commit-ci[bot]] 26 | - [pre-commit.ci] pre-commit autoupdate (#79) [pre-commit-ci[bot], pre- 27 | commit-ci[bot]] 28 | 29 | 30 | 31 | - [pre-commit.ci] pre-commit autoupdate (#76) [Serkan Hosca, pre-commit- 32 | ci[bot], pre-commit-ci[bot]] 33 | - Multi python dockerfile for local dev (#78) [Serkan Hosca] 34 | 35 | 36 | 0.12.0 (2022-03-26) 37 | ------------------- 38 | - Add django 4 to tox matrix (#77) [Serkan Hosca] 39 | - [pre-commit.ci] pre-commit autoupdate (#75) [pre-commit-ci[bot], pre- 40 | commit-ci[bot]] 41 | - [pre-commit.ci] pre-commit autoupdate (#74) [pre-commit-ci[bot], pre- 42 | commit-ci[bot]] 43 | - [pre-commit.ci] pre-commit autoupdate (#73) [pre-commit-ci[bot], pre- 44 | commit-ci[bot]] 45 | - [pre-commit.ci] pre-commit autoupdate (#72) [pre-commit-ci[bot], pre- 46 | commit-ci[bot]] 47 | - [pre-commit.ci] pre-commit autoupdate (#71) [pre-commit-ci[bot], pre- 48 | commit-ci[bot]] 49 | - Sourcery fixes (#70) [Serkan Hosca] 50 | 51 | 52 | 0.11.1 (2021-05-06) 53 | ------------------- 54 | - Use suppress from contextlib (#69) [Serkan Hosca] 55 | 56 | 57 | 0.11.0 (2021-03-20) 58 | ------------------- 59 | - Add sqlalchemy 1.4 support (#65) [Serkan Hosca] 60 | - Github actions build badge (#64) [Serkan Hosca] 61 | - Adding github actions (#57) [Serkan Hosca] 62 | - Fix build link. 
[Serkan Hosca] 63 | - Add dj3.1 and drf3.11 to matrix (#62) [Serkan Hosca] 64 | - Pre-commit importanize pyupgrade and docformat (#61) [Serkan Hosca] 65 | - Add django3 to build matrix (#58) [Serkan Hosca] 66 | - Add drf 3.10 on build matrix (#56) [Serkan Hosca] 67 | 68 | 69 | 0.10.3 (2019-11-07) 70 | ------------------- 71 | - Checking manifest with pre-commit (#55) [Miroslav Shubernetskiy] 72 | 73 | 74 | 0.10.2 (2019-10-31) 75 | ------------------- 76 | - Accounting for all expandable fields (#54) [Miroslav Shubernetskiy] 77 | 78 | 79 | 0.10.1 (2019-10-30) 80 | ------------------- 81 | - Expandable serializer uses selectinload for *tomany (#53) [Miroslav 82 | Shubernetskiy] 83 | 84 | 85 | 0.10.0 (2019-08-31) 86 | ------------------- 87 | - Drop py2 support (#51) [Serkan Hosca] 88 | - Pytest and black configs (#49) [Serkan Hosca] 89 | - Add SearchFilter (#47) [Serkan Hosca] 90 | - Use python/black (#46) [Serkan Hosca] 91 | 92 | 93 | 0.9.0 (2019-06-28) 94 | ------------------ 95 | - Drop enumfield and update importanize config (#45) [Serkan Hosca] 96 | 97 | 98 | 0.8.3 (2019-06-27) 99 | ------------------ 100 | - Fix module test runner target (#44) [Serkan Hosca] 101 | - Switching to tox-travis and tox matrix (#43) [Miroslav Shubernetskiy] 102 | - Run tests with pg (#42) [Serkan Hosca] 103 | - Update pre-commit (#41) [Serkan Hosca] 104 | 105 | 106 | 0.8.2 (2019-02-11) 107 | ------------------ 108 | - Fix Unicode type column mapping (#40) [Serkan Hosca] 109 | 110 | 111 | 0.8.1 (2019-01-08) 112 | ------------------ 113 | - Allowing to overwrite fields and exclude on serializer init (#38) 114 | [Miroslav Shubernetskiy] 115 | 116 | 117 | 0.8.0 (2019-01-05) 118 | ------------------ 119 | - Grab composite meta info from parent model (#37) [Serkan Hosca] 120 | - Coercion fixes from django-sorcery (#36) [Serkan Hosca] 121 | 122 | 123 | 0.7.20 (2018-12-13) 124 | ------------------- 125 | - Fix enum field custom kwargs (#35) [Serkan Hosca] 126 | 127 | 128 | 0.7.19 (2018-11-28)
129 | ------------------- 130 | - Pop widget from args (#34) [Serkan Hosca] 131 | 132 | 133 | 0.7.18 (2018-11-26) 134 | ------------------- 135 | - Stop using deprecated functions (#33) [Serkan Hosca] 136 | 137 | 138 | 0.7.17 (2018-11-24) 139 | ------------------- 140 | - Fix enum field and make it more generic (#32) [Serkan Hosca] 141 | 142 | 143 | 0.7.16 (2018-11-19) 144 | ------------------- 145 | - Fix composite source (#31) [Serkan Hosca] 146 | - Remove pipenv (#30) [Serkan Hosca] 147 | 148 | 149 | 0.7.15 (2018-11-14) 150 | ------------------- 151 | - Handling ValidationError in update on set attribute (#28) [Miroslav 152 | Shubernetskiy] 153 | 154 | 155 | 156 | 157 | 158 | - Bump pre-commit check versions (#27) [Serkan Hosca] 159 | 160 | 161 | 0.7.14 (2018-11-07) 162 | ------------------- 163 | - Fixing typo referencing session which does not exist (#26) [Miroslav 164 | Shubernetskiy] 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 0.7.13 (2018-11-06) 173 | ------------------- 174 | - Adding query_model hook (#24) [Miroslav Shubernetskiy] 175 | 176 | 177 | 0.7.12 (2018-11-05) 178 | ------------------- 179 | - Remove url default field from ModelSerializer (#25) [Serkan Hosca] 180 | - Update lock. [Serkan Hosca] 181 | 182 | 183 | 0.7.11 (2018-11-01) 184 | ------------------- 185 | - Hook for how model is created (#22) [Miroslav Shubernetskiy] 186 | - Fix serializer tests (#23) [Serkan Hosca] 187 | - Relock (#20) [Serkan Hosca] 188 | - Drop py3.5 build. 
[Serkan Hosca] 189 | 190 | 191 | 0.7.10 (2018-08-13) 192 | ------------------- 193 | - Partial by pk (#19) [Miroslav Shubernetskiy] 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | - Allowing to overwrite extra_kwargs in Serializer.__init__ (#18) 202 | [Miroslav Shubernetskiy] 203 | 204 | 205 | 0.7.9 (2018-08-08) 206 | ------------------ 207 | - ExpandableModelSerializer (#17) [Miroslav Shubernetskiy] 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | - Fixing saving serializer with source=* (#16) [Miroslav Shubernetskiy] 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 0.7.5 (2018-07-24) 240 | ------------------ 241 | - Correctly removing composite when validated data is None (#15) 242 | [Miroslav Shubernetskiy] 243 | 244 | 245 | 0.7.4 (2018-07-20) 246 | ------------------ 247 | - Fixing enum field choices (#14) [Miroslav Shubernetskiy] 248 | 249 | 250 | 0.7.3 (2018-07-16) 251 | ------------------ 252 | - Fixing updating model when field.field_name != field.source (#13) 253 | [Miroslav Shubernetskiy] 254 | 255 | 256 | 257 | 258 | 259 | - Add nested update test (#12) [Serkan Hosca] 260 | 261 | 262 | 0.7.2 (2018-06-28) 263 | ------------------ 264 | - Merge pull request #10 from shosca/composite-labels. [Serkan Hosca] 265 | - Fixing uri field for multiple pk models. fixed tests. [Miroslav 266 | Shubernetskiy] 267 | - Honoring lookup_field in querying model in generics.py when single 268 | pk. [Miroslav Shubernetskiy] 269 | - Normalizing django validation errors in apis. [Miroslav Shubernetskiy] 270 | - Fixing composite serializer field labels to use compose fields vs 271 | column names. [Miroslav Shubernetskiy] 272 | 273 | 274 | 0.7.1 (2018-06-26) 275 | ------------------ 276 | - Merge pull request #11 from shosca/relation-null-set. [Serkan Hosca] 277 | - Fix many-to-one or one-to-one relation null set.
[Serkan Hosca] 278 | 279 | 280 | 0.7.0 (2018-06-10) 281 | ------------------ 282 | - Merge pull request #9 from shosca/use-sorcery. [Serkan Hosca] 283 | - Add sorcery as dependency. [Serkan Hosca] 284 | 285 | 286 | 0.6.2 (2018-02-23) 287 | ------------------ 288 | - Merge pull request #8 from shosca/packaging. [Serkan Hosca] 289 | - Fix packaging. [Serkan Hosca] 290 | 291 | 292 | 0.6.1 (2018-01-08) 293 | ------------------ 294 | 295 | Fix 296 | ~~~ 297 | - Adjust build_nested_field signature. [Serkan Hosca] 298 | 299 | Other 300 | ~~~~~ 301 | - Version 0.6.1. [Serkan Hosca] 302 | - Merge pull request #7 from shosca/relation-info. [Serkan Hosca] 303 | 304 | 305 | 0.6.0 (2018-01-05) 306 | ------------------ 307 | - Version 0.6.0. [Serkan Hosca] 308 | - Merge pull request #5 from shosca/build-field-signature. [Serkan 309 | Hosca] 310 | - Add model_class to build_field. [Serkan Hosca] 311 | 312 | 313 | 0.5.6 (2017-12-21) 314 | ------------------ 315 | - Merge pull request #3 from nickswiss/enum-mapping. [Serkan Hosca] 316 | - Adding enums to field mapping dict. [Nick Arnold] 317 | 318 | 319 | 0.5.5 (2017-11-02) 320 | ------------------ 321 | 322 | Fix 323 | ~~~ 324 | - Declared fields. [Serkan Hosca] 325 | 326 | Other 327 | ~~~~~ 328 | - 0.5.5. [Serkan Hosca] 329 | - Merge pull request #2 from shosca/fix-declared-fields. [Serkan Hosca] 330 | 331 | 332 | 0.5.4 (2017-10-23) 333 | ------------------ 334 | 335 | Fix 336 | ~~~ 337 | - Super for py2. [Serkan Hosca] 338 | 339 | Refactor 340 | ~~~~~~~~ 341 | - Separate out session flush. [Serkan Hosca] 342 | 343 | 344 | 0.5.2 (2017-10-21) 345 | ------------------ 346 | 347 | Fix 348 | ~~~ 349 | - Deepcopy composite and model serializers. [Serkan Hosca] 350 | 351 | 352 | 0.5.1 (2017-10-04) 353 | ------------------ 354 | 355 | Refactor 356 | ~~~~~~~~ 357 | - Handle session passing around. [Serkan Hosca] 358 | 359 | Other 360 | ~~~~~ 361 | - Merge pull request #1 from shosca/session-distribution. 
[Serkan Hosca] 362 | 363 | 364 | 0.5.0 (2017-10-03) 365 | ------------------ 366 | 367 | Refactor 368 | ~~~~~~~~ 369 | - Make enums use values instead of names. [Serkan Hosca] 370 | - Use relationship mapper to get target model class. [Serkan Hosca] 371 | 372 | Other 373 | ~~~~~ 374 | - Add LICENSE. [Serkan Hosca] 375 | - Pipfile lock. [Serkan Hosca] 376 | 377 | 378 | 0.4.3 (2017-07-06) 379 | ------------------ 380 | 381 | Fix 382 | ~~~ 383 | - Allow_null is not allowed in boolean fields. [Serkan Hosca] 384 | 385 | 386 | 0.4.2 (2017-07-02) 387 | ------------------ 388 | 389 | Fix 390 | ~~~ 391 | - Handle composite pks when one pk is None. [Serkan Hosca] 392 | 393 | 394 | 0.4.1 (2017-07-01) 395 | ------------------ 396 | 397 | Fix 398 | ~~~ 399 | - Nested model primary key field generation. [Serkan Hosca] 400 | 401 | Other 402 | ~~~~~ 403 | - Fix readme. [Serkan Hosca] 404 | 405 | 406 | 0.4.0 (2017-06-28) 407 | ------------------ 408 | 409 | Fix 410 | ~~~ 411 | - Field label generation. [Serkan Hosca] 412 | 413 | Refactor 414 | ~~~~~~~~ 415 | - Lots of minor pylint and pycharm linter fixes. [Serkan Hosca] 416 | 417 | Other 418 | ~~~~~ 419 | - Update gitchangelog.rc. [Serkan Hosca] 420 | 421 | 422 | 0.3.5 (2017-06-18) 423 | ------------------ 424 | 425 | Fix 426 | ~~~ 427 | - Increase coverage. [Serkan Hosca] 428 | 429 | Refactor 430 | ~~~~~~~~ 431 | - Dedup update attribute logic. [Serkan Hosca] 432 | - Run pre-commit as part of build. [Serkan Hosca] 433 | 434 | 435 | 0.3.4 (2017-06-14) 436 | ------------------ 437 | 438 | Refactor 439 | ~~~~~~~~ 440 | - Better route name handling and nullable boolean field tests. [Serkan 441 | Hosca] 442 | 443 | Documentation 444 | ~~~~~~~~~~~~~ 445 | - Update gitchangelog config. [Serkan Hosca] 446 | 447 | 448 | 0.3.3 (2017-06-13) 449 | ------------------ 450 | 451 | Fix 452 | ~~~ 453 | - Add pipenv for setup. [Serkan Hosca] 454 | 455 | Documentation 456 | ~~~~~~~~~~~~~ 457 | - Fix versioning. 
[Serkan Hosca] 458 | 459 | 460 | 0.3.2 (2017-06-13) 461 | ------------------ 462 | 463 | Fix 464 | ~~~ 465 | - Stop passing around is_nested and fix autoincrement value check. 466 | [Serkan Hosca] 467 | 468 | 469 | 0.3.1 (2017-06-11) 470 | ------------------ 471 | - Delete tests and coverall config. [Serkan Hosca] 472 | 473 | 474 | 0.3.0 (2017-06-11) 475 | ------------------ 476 | 477 | Fix 478 | ~~~ 479 | - Nested list serializer flags. [Serkan Hosca] 480 | - Generic destroy with sqlalchemy. [Serkan Hosca] 481 | - Handle autoincrement and nested update with existing instance. [Serkan 482 | Hosca] 483 | 484 | Refactor 485 | ~~~~~~~~ 486 | - Model_info changes and added docstrings. [Serkan Hosca] 487 | 488 | Other 489 | ~~~~~ 490 | - Initial doc setup. [Serkan Hosca] 491 | 492 | 493 | 0.2.1 (2017-06-10) 494 | ------------------ 495 | - Initial doc setup. [Serkan Hosca] 496 | 497 | 498 | 0.2.0 (2017-06-10) 499 | ------------------ 500 | - Refactor field mapping and object fetching and more tests. [Serkan 501 | Hosca] 502 | 503 | 504 | 0.1.4 (2017-06-09) 505 | ------------------ 506 | - Respect allow_null. [Serkan Hosca] 507 | 508 | 509 | 0.1.2 (2017-06-08) 510 | ------------------ 511 | - Mark all columns read only when allow_nested_updates is false. [Serkan 512 | Hosca] 513 | 514 | 515 | 0.1.1 (2017-06-07) 516 | ------------------ 517 | - Fix composite serializer. [Serkan Hosca] 518 | 519 | 520 | 0.1.0 (2017-06-06) 521 | ------------------ 522 | - Add more tests and generic api fixes. [Serkan Hosca] 523 | 524 | 525 | 0.0.6 (2017-06-05) 526 | ------------------ 527 | - Add missing dep and add pypi badge. [Serkan Hosca] 528 | - Add more tests for composite routes. [Serkan Hosca] 529 | 530 | 531 | 0.0.5 (2017-06-05) 532 | ------------------ 533 | - Add route tests. [Serkan Hosca] 534 | 535 | 536 | 0.0.4 (2017-06-05) 537 | ------------------ 538 | - Add pre-commit. [Serkan Hosca] 539 | - Move GenericAPIView. [Serkan Hosca] 540 | - Fix Readme. 
[Serkan Hosca] 541 | 542 | 543 | 0.0.2 (2017-06-02) 544 | ------------------ 545 | - Fix setup publish and make clean. [Serkan Hosca] 546 | - Added viewsets and version bump. [Serkan Hosca] 547 | - Update readme. [Serkan Hosca] 548 | 549 | 550 | 0.0.1 (2017-06-02) 551 | ------------------ 552 | - Fix readme. [Serkan Hosca] 553 | - Added initial readme. [Serkan Hosca] 554 | - Add travis. [Serkan Hosca] 555 | - Initial commit. [Serkan Hosca] 556 | 557 | 558 | -------------------------------------------------------------------------------- /rest_witchcraft/serializers.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import re 3 | from collections import OrderedDict, namedtuple 4 | from itertools import groupby 5 | 6 | from sqlalchemy.orm.interfaces import ONETOMANY 7 | 8 | from django.core.exceptions import ImproperlyConfigured, ValidationError as DjangoValidationError 9 | from django.db.models.constants import LOOKUP_SEP 10 | from django.http import QueryDict 11 | from django.utils.text import capfirst 12 | 13 | from django_sorcery.db import meta 14 | 15 | from rest_framework import fields, serializers 16 | from rest_framework.exceptions import ValidationError 17 | from rest_framework.settings import api_settings 18 | 19 | from .field_mapping import get_field_type, get_url_kwargs 20 | from .fields import ImplicitExpandableListField, UriField 21 | from .utils import django_to_drf_validation_error 22 | 23 | 24 | ALL_FIELDS = "__all__" 25 | REGEX_TYPE = type(re.compile("")) 26 | 27 | 28 | class BaseSerializer(serializers.Serializer): 29 | serializer_choice_field = fields.ChoiceField 30 | 31 | @property 32 | def is_nested(self): 33 | if self.parent: 34 | if not hasattr(self.parent, "many"): 35 | return True 36 | 37 | if self.parent.many is True and self.parent.parent: 38 | return True 39 | 40 | return False 41 | 42 | def build_standard_field_kwargs(self, field_name, field_class, column_info): 43 | """Analyze model 
column to generate field kwargs.""" 44 | field_kwargs = column_info.field_kwargs.copy() 45 | field_kwargs["label"] = capfirst(" ".join(field_name.split("_")).strip()) 46 | field_kwargs["allow_null"] = not field_kwargs.get("required", True) 47 | 48 | if "choices" in field_kwargs: 49 | # Fields with choices get coerced into `ChoiceField` 50 | # instead of using their regular typed field. 51 | field_class = self.serializer_choice_field 52 | # Some model fields may introduce kwargs that would not be valid 53 | # for the choice field. We need to strip these out. 54 | # Eg. models.DecimalField(max_digits=3, decimal_places=1, choices=DECIMAL_CHOICES) 55 | valid_kwargs = { 56 | "read_only", 57 | "write_only", 58 | "required", 59 | "default", 60 | "initial", 61 | "source", 62 | "label", 63 | "help_text", 64 | "style", 65 | "error_messages", 66 | "validators", 67 | "allow_null", 68 | "allow_blank", 69 | "choices", 70 | } 71 | for key in list(field_kwargs): 72 | if key not in valid_kwargs: 73 | del field_kwargs[key] # pragma: nocover 74 | 75 | # Include any kwargs defined in `Meta.extra_kwargs` 76 | field_kwargs = self.include_extra_kwargs(field_kwargs, self._extra_kwargs.get(field_name)) 77 | 78 | if not issubclass(field_class, fields.CharField) and not issubclass(field_class, fields.ChoiceField): 79 | # `allow_blank` is only valid for textual fields. 
80 | field_kwargs.pop("allow_blank", None) 81 | 82 | if hasattr(fields, "NullBooleanField") and issubclass(field_class, fields.NullBooleanField): 83 | # 'allow_null' is not valid kwarg for NullBooleanField 84 | field_kwargs.pop("allow_null", None) # pragma: nocover 85 | 86 | field_kwargs.pop("widget", None) 87 | return field_kwargs 88 | 89 | def build_standard_field(self, field_name, column_info): 90 | """Create regular model fields.""" 91 | field_class = self.get_field_type(column_info) 92 | 93 | field_kwargs = self.build_standard_field_kwargs(field_name, field_class, column_info) 94 | 95 | return field_class(**field_kwargs) 96 | 97 | def get_field_type(self, column_info): 98 | """Returns the field type to be used determined by the sqlalchemy 99 | column type or the column type's python type.""" 100 | field_class = get_field_type(column_info.column) 101 | 102 | if not field_class: 103 | raise KeyError( 104 | "Could not figure out type for attribute '{}.{}'".format(self.model.__name__, column_info.property.key) 105 | ) 106 | 107 | return field_class 108 | 109 | def include_extra_kwargs(self, kwargs, extra_kwargs=None): 110 | """Include any 'extra_kwargs' that have been included for this field, 111 | possibly removing any incompatible existing keyword arguments.""" 112 | extra_kwargs = extra_kwargs or {} 113 | if extra_kwargs.get("read_only", False): 114 | for attr in [ 115 | "required", 116 | "default", 117 | "allow_blank", 118 | "allow_null", 119 | "min_length", 120 | "max_length", 121 | "min_value", 122 | "max_value", 123 | "validators", 124 | "queryset", 125 | ]: 126 | kwargs.pop(attr, None) 127 | 128 | if extra_kwargs.get("default") and kwargs.get("required") is False: 129 | kwargs.pop("required") 130 | 131 | if extra_kwargs.get("read_only", kwargs.get("read_only", False)): 132 | # Read only fields should always omit the 'required' argument. 
133 | extra_kwargs.pop("required", None) 134 | 135 | kwargs.update(extra_kwargs) 136 | 137 | return kwargs 138 | 139 | def create(self, validated_data): 140 | raise NotImplementedError() 141 | 142 | def update(self, instance, validated_data): 143 | raise NotImplementedError() 144 | 145 | def update_attribute(self, instance, field, value): 146 | """Performs update on the instance for the given field with value.""" 147 | field_setter = getattr(self, "set_" + field.field_name, None) 148 | if field_setter: 149 | field_setter(instance, field.source, value) 150 | else: 151 | setattr(instance, field.source, value) 152 | 153 | 154 | class CompositeSerializer(BaseSerializer): 155 | """This class is useful for generating a serializer for sqlalchemy's 156 | `composite` model attributes.""" 157 | 158 | def __init__(self, *args, **kwargs): 159 | composite_attr = kwargs.pop("composite", None) or getattr(getattr(self, "Meta", None), "composite", None) 160 | self._info = meta.model_info(composite_attr.prop.parent).composites[composite_attr.prop.key] 161 | 162 | super().__init__(*args, **kwargs) 163 | self.composite_class = self._info.prop.composite_class 164 | self.read_only = False 165 | self.required = False 166 | self.default = None 167 | self.allow_nested_updates = True 168 | self._extra_kwargs = {} 169 | 170 | def get_fields(self): 171 | """Return the dict of field names -> field instances that should be 172 | used for `self.fields` when instantiating the serializer.""" 173 | _fields = OrderedDict() 174 | 175 | for field_name, column_info in self._info.properties.items(): 176 | source = self._extra_kwargs.get(field_name, {}).get("source") or field_name 177 | 178 | _fields[field_name] = self.build_standard_field(source, column_info) 179 | 180 | return _fields 181 | 182 | def get_object(self, validated_data, instance=None): 183 | if validated_data is None: 184 | return 185 | 186 | if instance: 187 | return instance 188 | 189 | validated_data = validated_data or {} 190 | 191 | 
composite_args = [validated_data.get(i) for i in self._info.properties] 192 | return self.composite_class(*composite_args) 193 | 194 | def create(self, validated_data): 195 | instance = self.get_object(validated_data) 196 | return self.update(instance, validated_data) 197 | 198 | def update(self, instance, validated_data): 199 | errors = {} 200 | instance = self.perform_update(instance, validated_data, errors) 201 | 202 | if errors: 203 | raise ValidationError(errors) 204 | 205 | return instance 206 | 207 | def perform_update(self, instance, validated_data, errors): 208 | validated_data = validated_data or {} 209 | 210 | for field in self._writable_fields: 211 | 212 | if field.field_name not in validated_data: 213 | continue 214 | 215 | try: 216 | value = validated_data.get(field.field_name) 217 | 218 | self.update_attribute(instance, field, value) 219 | 220 | except DjangoValidationError as e: 221 | errors.setdefault(self.field_name, {}).update(django_to_drf_validation_error(e).detail) 222 | 223 | except Exception as e: 224 | errors.setdefault(field.field_name, []).append(" ".join(e.args)) 225 | 226 | return instance 227 | 228 | def __deepcopy__(self, memo=None): 229 | """When cloning fields we instantiate using the arguments it was 230 | originally created with, rather than copying the complete state.""" 231 | # Treat regexes, validators and session as immutable. 
232 | args = [copy.deepcopy(item) if not isinstance(item, REGEX_TYPE) else item for item in self._args] 233 | kwargs = { 234 | key: (copy.deepcopy(value) if (key not in ("validators", "regex", "composite")) else value) 235 | for key, value in self._kwargs.items() 236 | } 237 | return self.__class__(*args, **kwargs) 238 | 239 | 240 | class ModelSerializer(BaseSerializer): 241 | """ModelSerializer is basically like a drf model serializer except that it 242 | works with sqlalchemy models: 243 | 244 | * A set of default fields are automatically populated by introspecting a sqlalchemy model 245 | * Default `.create()` and `.update()` implementations provided by mostly reducing the problem 246 | to update. 247 | 248 | The process of automatically determining a set of serializer fields is based on the model's fields, components 249 | and relationships. 250 | 251 | If the `ModelSerializer` does not generate the set of fields that you need, you can explicitly declare them. 252 | """ 253 | 254 | url_field_name = None 255 | serializer_url_field = UriField 256 | 257 | default_error_messages = {"not_found": "No instance found with primary keys"} 258 | 259 | def __init__(self, *args, **kwargs): 260 | """ModelSerializer initializer The main things that we're interested in 261 | is the sqlalchemy session, you can provide it thru `Meta.session`, 262 | `session` kwarg or thru `context` 263 | 264 | `allow_nested_updates` is for controlling nested related model 265 | updates. 
266 | """ 267 | self._session = kwargs.pop("session", None) or getattr(getattr(self, "Meta", None), "session", None) 268 | self.allow_nested_updates = kwargs.pop("allow_nested_updates", False) 269 | self.allow_create = kwargs.pop("allow_create", False) 270 | self.partial_by_pk = kwargs.pop("partial_by_pk", False) 271 | overwrite_fields = kwargs.pop("fields", fields.empty) 272 | overwrite_exclude = kwargs.pop("exclude", fields.empty) 273 | extra_kwargs = kwargs.pop("extra_kwargs", {}) 274 | 275 | super().__init__(*args, **kwargs) 276 | 277 | self._extra_kwargs = self.get_extra_kwargs(**extra_kwargs) 278 | self._overwrite_fields = overwrite_fields 279 | self._overwrite_exclude = overwrite_exclude 280 | 281 | def __deepcopy__(self, memo=None): 282 | """When cloning fields we instantiate using the arguments it was 283 | originally created with, rather than copying the complete state.""" 284 | # Treat regexes, validators and session as immutable. 285 | args = [copy.deepcopy(item) if not isinstance(item, REGEX_TYPE) else item for item in self._args] 286 | kwargs = { 287 | key: (copy.deepcopy(value) if (key not in ("validators", "regex", "session")) else value) 288 | for key, value in self._kwargs.items() 289 | } 290 | return self.__class__(*args, **kwargs) 291 | 292 | @property 293 | def session(self): 294 | if not self._session: 295 | 296 | self._session = self.context.get("session") 297 | 298 | assert self._session is not None, ( 299 | "Creating a {}(ModelSerializer) without the session attribute in Meta, " 300 | "as a keyword argument or without a session in the serializer context" 301 | "".format(self.__class__.__name__) 302 | ) 303 | 304 | return self._session 305 | 306 | @property 307 | def model(self): 308 | assert hasattr(self.Meta, "model"), 'Class {serializer_class} missing "Meta.model" attribute'.format( 309 | serializer_class=self.__class__.__name__ 310 | ) 311 | return self.Meta.model 312 | 313 | @property 314 | def queryset(self): 315 | return 
getattr(self.Meta, "queryset", None) or self.session.query(self.model) 316 | 317 | def get_fields(self): 318 | """Return the dict of field names -> field instances that should be 319 | used for `self.fields` when instantiating the serializer.""" 320 | if self.url_field_name is None: 321 | self.url_field_name = api_settings.URL_FIELD_NAME 322 | 323 | assert hasattr(self, "Meta"), 'Class {serializer_class} missing "Meta" attribute'.format( 324 | serializer_class=self.__class__.__name__ 325 | ) 326 | 327 | declared_fields = copy.deepcopy(self._declared_fields) 328 | info = meta.model_info(self.model) 329 | depth = getattr(self.Meta, "depth", 0) 330 | 331 | if depth is not None: 332 | assert depth >= 0, "'depth' may not be negative." 333 | assert depth <= 5, "'depth' may not be greater than 5." 334 | 335 | field_names = self.get_field_names(declared_fields, info) 336 | 337 | # Determine the fields that should be included on the serializer. 338 | _fields = OrderedDict() 339 | 340 | for field_name in field_names: 341 | # If the field is explicitly declared on the class then use that. 342 | if field_name in declared_fields: 343 | _fields[field_name] = declared_fields[field_name] 344 | continue 345 | 346 | source = self._extra_kwargs.get(field_name, {}).get("source") or field_name 347 | 348 | _fields[field_name] = self.build_field(source, info, self.model, depth) 349 | 350 | return _fields 351 | 352 | def get_field_names(self, declared_fields, info): 353 | """Returns the list of all field names that should be created when 354 | instantiating this serializer class. 355 | 356 | This is based on the default set of fields, but also takes into 357 | account the `Meta.fields` or `Meta.exclude` options if they have 358 | been specified. 
359 | """ 360 | _fields = ( 361 | self._overwrite_fields if self._overwrite_fields is not fields.empty else getattr(self.Meta, "fields", None) 362 | ) 363 | exclude = ( 364 | self._overwrite_exclude 365 | if self._overwrite_exclude is not fields.empty 366 | else getattr(self.Meta, "exclude", None) 367 | ) 368 | 369 | if _fields and _fields != ALL_FIELDS and not isinstance(_fields, (list, tuple)): 370 | raise TypeError( 371 | 'The `fields` option must be a list or tuple or "__all__". ' "Got %s." % type(_fields).__name__ 372 | ) 373 | 374 | if exclude and not isinstance(exclude, (list, tuple)): 375 | raise TypeError("The `exclude` option must be a list or tuple. Got %s." % type(exclude).__name__) 376 | 377 | assert not ( 378 | _fields and exclude 379 | ), "Cannot set both 'fields' and 'exclude' options on " "serializer {serializer_class}.".format( 380 | serializer_class=self.__class__.__name__ 381 | ) 382 | 383 | assert _fields is not None or exclude is not None, ( 384 | "Creating a ModelSerializer without either the 'fields' attribute " 385 | "or the 'exclude' attribute has been deprecated since 3.3.0, " 386 | "and is now disallowed. Add an explicit fields = '__all__' to the " 387 | "{serializer_class} serializer.".format(serializer_class=self.__class__.__name__), 388 | ) 389 | 390 | if _fields == ALL_FIELDS: 391 | _fields = None 392 | 393 | if _fields is not None: 394 | # Ensure that all declared fields have also been included in the 395 | # `Meta.fields` option. 396 | 397 | # Do not require any fields that are declared a parent class, 398 | # in order to allow serializer subclasses to only include 399 | # a subset of fields. 
400 | required_field_names = set(declared_fields) 401 | for cls in self.__class__.__bases__: 402 | required_field_names -= set(getattr(cls, "_declared_fields", [])) 403 | 404 | for field_name in required_field_names: 405 | assert field_name in _fields, ( 406 | "The field '{field_name}' was declared on serializer " 407 | "{serializer_class}, but has not been included in the " 408 | "'fields' option.".format(field_name=field_name, serializer_class=self.__class__.__name__) 409 | ) 410 | return _fields 411 | 412 | # Use the default set of field names if `Meta.fields` is not specified. 413 | _fields = self.get_default_field_names(declared_fields, info) 414 | 415 | if exclude is not None: 416 | # If `Meta.exclude` is included, then remove those fields. 417 | for field_name in exclude: 418 | assert field_name in _fields, ( 419 | "The field '{field_name}' was included on serializer " 420 | "{serializer_class} in the 'exclude' option, but does " 421 | "not match any model field.".format(field_name=field_name, serializer_class=self.__class__.__name__) 422 | ) 423 | _fields.remove(field_name) 424 | 425 | return _fields 426 | 427 | def get_default_field_names(self, declared_fields, info): 428 | """Return the default list of field names that will be used if the 429 | `Meta.fields` option is not specified.""" 430 | return info.field_names + list(declared_fields.keys()) 431 | 432 | def get_extra_kwargs(self, **additional_kwargs): 433 | """Return a dictionary mapping field names to a dictionary of 434 | additional keyword arguments.""" 435 | extra_kwargs = copy.deepcopy(getattr(self.Meta, "extra_kwargs", {})) or {} 436 | 437 | read_only_fields = getattr(self.Meta, "read_only_fields", None) 438 | if read_only_fields is not None: 439 | if not isinstance(read_only_fields, (list, tuple)): 440 | raise TypeError( 441 | "The `read_only_fields` option must be a list or tuple. " 442 | "Got %s." 
% type(read_only_fields).__name__ 443 | ) 444 | 445 | for field_name in read_only_fields: 446 | kwargs = extra_kwargs.get(field_name, {}) 447 | kwargs["read_only"] = True 448 | extra_kwargs[field_name] = kwargs 449 | 450 | extra_kwargs.update(additional_kwargs) 451 | 452 | return extra_kwargs 453 | 454 | def build_field(self, field_name, info, model_class, nested_depth): 455 | """Return a field or a nested serializer for the field name.""" 456 | if field_name in info.primary_keys: 457 | pk = info.primary_keys[field_name] 458 | return self.build_primary_key_field(field_name, pk) 459 | 460 | elif field_name in info.properties: 461 | prop = info.properties[field_name] 462 | return self.build_standard_field(field_name, prop) 463 | 464 | elif field_name in info.relationships: 465 | relation_info = info.relationships[field_name] 466 | return self.build_nested_field(field_name, relation_info, nested_depth) 467 | 468 | elif field_name in info.composites: 469 | composite = info.composites[field_name] 470 | return self.build_composite_field(field_name, getattr(info.model_class, composite.prop.key)) 471 | 472 | elif hasattr(info.model_class, field_name): 473 | return self.build_property_field(field_name, info) 474 | 475 | elif field_name == self.url_field_name: 476 | return self.build_url_field(field_name, info) 477 | 478 | return self.build_unknown_field(field_name, info) 479 | 480 | def build_primary_key_field(self, field_name, column_info): 481 | """Builds a field for the primary key of the model.""" 482 | field_class = self.get_field_type(column_info) 483 | 484 | field_kwargs = self.build_standard_field_kwargs(field_name, field_class, column_info) 485 | 486 | if self.is_nested: 487 | if self.allow_create or self.allow_null: 488 | # since we're allowed to create new instances, pk is not required 489 | field_kwargs["required"] = False 490 | 491 | elif column_info.column.default is not None or column_info.column.autoincrement is True: 492 | # pk has a default value or its an 
autoincremented column so the field should be read only 493 | field_kwargs.pop("required", None) 494 | field_kwargs["read_only"] = True 495 | 496 | return field_class(**field_kwargs) 497 | 498 | def build_composite_field(self, field_name, composite): 499 | """Builds a `CompositeSerializer` to handle composite attribute in 500 | model.""" 501 | field_kwargs = {"composite": composite} 502 | field_kwargs = self.include_extra_kwargs(field_kwargs, self._extra_kwargs.get(field_name)) 503 | return CompositeSerializer(**field_kwargs) 504 | 505 | def build_nested_field(self, field_name, relation_info, nested_depth): 506 | """Builds nested serializer to handle relationshipped model.""" 507 | target_model = relation_info.related_model 508 | nested_fields = self.get_nested_relationship_fields(relation_info, nested_depth) 509 | 510 | field_kwargs = self.get_relationship_kwargs(relation_info, nested_depth) 511 | field_kwargs = self.include_extra_kwargs(field_kwargs, self._extra_kwargs.get(field_name)) 512 | nested_extra_kwargs = {} 513 | 514 | nested_info = meta.model_info(target_model) 515 | if not field_kwargs.get("required", True): 516 | for nested_field in nested_info.primary_keys: 517 | nested_extra_kwargs.setdefault(nested_field, {}).setdefault("required", False) 518 | 519 | if not field_kwargs.get("allow_nested_updates", True): 520 | nested_depth = 0 521 | for nested_field in nested_info.properties: 522 | nested_extra_kwargs.setdefault(nested_field, {}).setdefault("read_only", True) 523 | nested_extra_kwargs.setdefault(nested_field, {}).pop("required", None) 524 | 525 | class NestedSerializer(getattr(self.Meta, "nested_serializer_class", ModelSerializer)): 526 | class Meta: 527 | model = target_model 528 | session = self.session 529 | depth = max(0, nested_depth - 1) 530 | fields = nested_fields 531 | extra_kwargs = nested_extra_kwargs 532 | 533 | return type(str(target_model.__name__ + "Serializer"), (NestedSerializer,), {})(**field_kwargs) 534 | 535 | def 
build_property_field(self, field_name, info): 536 | return fields.ReadOnlyField() 537 | 538 | def build_url_field(self, field_name, info): 539 | """Create a field representing the object's own URL.""" 540 | field_class = self.serializer_url_field 541 | field_kwargs = get_url_kwargs(info.model_class) 542 | field_kwargs.update(self._extra_kwargs.get(self.url_field_name, {})) 543 | 544 | return field_class(**field_kwargs) 545 | 546 | def build_unknown_field(self, field_name, info): 547 | """Raise an error on any unknown fields.""" 548 | raise ImproperlyConfigured( 549 | "Field name `{}` is not valid for model `{}`.".format(field_name, info.model_class.__name__) 550 | ) 551 | 552 | def get_relationship_kwargs(self, relation_info, depth): 553 | """Figure out the arguments to be used in the `NestedSerializer` for 554 | the relationship.""" 555 | kwargs = {} 556 | if relation_info.direction == ONETOMANY: 557 | kwargs["required"] = False 558 | kwargs["allow_null"] = True 559 | elif all(col.nullable for col in relation_info.foreign_keys): 560 | kwargs["required"] = False 561 | kwargs["allow_null"] = True 562 | 563 | if relation_info.uselist: 564 | kwargs["many"] = True 565 | kwargs["required"] = False 566 | 567 | return kwargs 568 | 569 | def get_nested_relationship_fields(self, relation_info, depth): 570 | """Get the field names for the nested serializer.""" 571 | target_model_info = meta.model_info(relation_info.related_model) 572 | 573 | # figure out backrefs 574 | backrefs = {key for key, rel in target_model_info.relationships.items() if rel.related_model == self.model} 575 | 576 | _fields = set(target_model_info.primary_keys.keys()) 577 | _fields.update(target_model_info.properties.keys()) 578 | if depth > 0: 579 | _fields.update(target_model_info.composites.keys()) 580 | _fields.update(target_model_info.relationships.keys()) 581 | 582 | _fields = _fields - backrefs 583 | 584 | return tuple(field for field in _fields if not field.startswith("_")) 585 | 586 | def 
to_internal_value(self, data): 587 | """Same as in DRF but also handle ``partial_by_pk`` by making all non- 588 | pk fields optional. 589 | 590 | Even though flag name implies it will make serializer partial, 591 | that is currently not possible in DRF as partial flag is checked 592 | on root serializer within serializer validation loops. As such, 593 | individual serializers cannot be marked partial. Therefore when 594 | flag is provided and primary key is provided in validated data, 595 | we physically mark all other fields as not required to 596 | effectively make them partial without using ``partial`` flag 597 | itself. To make serializer behave more or less like real partial 598 | serializer, only passed keys in input data are preserved in 599 | validated data. If they are not stripped, it is possible to 600 | remove some existing data. 601 | """ 602 | if not self.partial_by_pk or not self.get_primary_keys(data): 603 | return super().to_internal_value(data) 604 | 605 | info = meta.model_info(self.model) 606 | 607 | for _, field in self.fields.items(): 608 | if field.source not in info.primary_keys: 609 | field.required = False 610 | 611 | passed_keys = set(data) 612 | data = super().to_internal_value(data) 613 | 614 | for k in set(data) - passed_keys: 615 | if k in self.fields and self.fields[k].get_default() == data[k]: 616 | data.pop(k) 617 | 618 | return data 619 | 620 | def get_primary_keys(self, validated_data): 621 | """Returns the primary key values from validated_data.""" 622 | if not validated_data: 623 | return 624 | 625 | info = meta.model_info(self.queryset._only_full_mapper_zero("get").class_) 626 | return info.primary_keys_from_dict( 627 | {getattr(self.fields.get(k), "source", None) or k: v for k, v in validated_data.items()} 628 | ) 629 | 630 | def get_object(self, validated_data, instance=None): 631 | """Returns model object instance using the primary key values in the 632 | `validated_data`. 
633 | 634 | If the instance is not found, depending on serializer's 635 | `allow_create` value, it will create a new model instance or 636 | raise an error. 637 | """ 638 | pks = self.get_primary_keys(validated_data) 639 | if validated_data and pks: 640 | return self.query_model(pks) or self.fail("not_found") 641 | 642 | # if validated data is None, it means it was explicitly set as None 643 | # in self.initial_data hence we normalize to None 644 | # regardless if parent already had this relation set 645 | if validated_data is None: 646 | instance = None 647 | 648 | if instance is not None: 649 | return instance 650 | 651 | elif validated_data is not None and self.allow_create: 652 | return self.model() 653 | 654 | elif self.allow_null: 655 | return 656 | 657 | else: 658 | raise self.fail("required") 659 | 660 | def save(self, **kwargs): 661 | """Save and return a list of object instances.""" 662 | with self.session.no_autoflush: 663 | self.instance = super().save(**kwargs) 664 | 665 | self.perform_flush() 666 | return self.instance 667 | 668 | def perform_flush(self): 669 | """Perform session flush changes.""" 670 | try: 671 | self.session.flush() 672 | except DjangoValidationError as e: 673 | e = django_to_drf_validation_error(e) 674 | self._errors = e.detail 675 | raise e 676 | 677 | def query_model(self, pks): 678 | """Hook to allow to customize how model is queried when serializer is 679 | nested and needs to query the model by its primary keys.""" 680 | return self.queryset.get(pks) 681 | 682 | def create_model(self, validated_data): 683 | """Hook to allow to customize how model is created in create flow.""" 684 | return self.model() 685 | 686 | def create(self, validated_data): 687 | """Creates a model instance using validated_data.""" 688 | instance = self.update(self.create_model(validated_data), validated_data) 689 | self.session.add(instance) 690 | return instance 691 | 692 | def update(self, instance, validated_data): 693 | """Updates an existing model 
instance using validated_data with 694 | suspended autoflush.""" 695 | errors = {} 696 | instance = self.perform_update(instance, validated_data, errors) 697 | 698 | if errors: 699 | raise ValidationError(errors) 700 | 701 | return instance 702 | 703 | def perform_update(self, instance, validated_data, errors): 704 | """The main nested update logic implementation using nested fields and 705 | serializer.""" 706 | for field in self._writable_fields: 707 | try: 708 | if isinstance(field, BaseSerializer): 709 | if field.source == "*": 710 | value = validated_data 711 | child_instance = instance 712 | else: 713 | if field.source not in validated_data: 714 | continue 715 | value = validated_data.get(field.source) 716 | child_instance = getattr(instance, field.source, None) 717 | child_instance = field.get_object(value, child_instance) 718 | 719 | if child_instance and field.allow_nested_updates: 720 | value = field.perform_update(child_instance, value, errors) 721 | else: 722 | value = child_instance 723 | 724 | elif isinstance(field, serializers.ListSerializer) and isinstance(field.child, BaseSerializer): 725 | if field.source not in validated_data: 726 | continue 727 | 728 | value = [] 729 | 730 | for item in validated_data.get(field.source): 731 | child_instance = field.child.get_object(item) 732 | if child_instance and (field.child.allow_create or field.child.allow_nested_updates): 733 | v = field.child.perform_update(child_instance, item, errors) 734 | else: 735 | v = child_instance 736 | 737 | if v: 738 | value.append(v) 739 | 740 | else: 741 | if field.source not in validated_data: 742 | continue 743 | 744 | value = validated_data.get(field.source) 745 | 746 | self.update_attribute(instance, field, value) 747 | 748 | except DjangoValidationError as e: 749 | errors.update(django_to_drf_validation_error(e).detail) 750 | 751 | except Exception as e: 752 | errors.setdefault(field.field_name, []).append(" ".join(map(str, e.args))) 753 | 754 | return instance 755 | 756 
| 757 | class ExpandableModelSerializer(ModelSerializer): 758 | """Same as ``ModelSerializer`` but allows to conditionally recursively 759 | expand specific fields. 760 | 761 | Serializer by default renders with all fields collapsed 762 | however validates data with expanded fields. 763 | 764 | To expand fields, either: 765 | 766 | * ``request.GET`` should request to expand field by ``?expand=``. 767 | Field names can be recursive ``?expand=__``. 768 | * One of expandable fields was updated which will cause 769 | ``to_representation()`` to render expanded field. 770 | 771 | By default serializer should define "expanded" fields. 772 | ``ModelSerializer`` already does it by default for all relations. 773 | This allows introspection of not rendered serializer to pick up all fields. 774 | This is especially useful when generating schema for the serializer 775 | such as for coreapi docs. 776 | Collapsed fields are specified in ``Meta.expandable_fields`` where keys 777 | are field names and values are replacement field instances. 778 | 779 | In addition expandable query key can be specified via ``Meta.expandable_query_key``. 780 | 781 | For example: 782 | 783 | .. code:: 784 | 785 | class BarJustIDSerializer(Serializer): 786 | id = serializers.IntegerField(source="bar_id") 787 | 788 | class Meta: 789 | model = Bar 790 | session = session 791 | fields = ["id"] 792 | 793 | class FooSerializer(ExpandableModelSerializer): 794 | class Meta: 795 | model = Foo 796 | session = session 797 | exclude = ["bar_id"] 798 | expandable_fields = { 799 | "bar": BarJustIDSerializer(source="*", read_only=True) 800 | } 801 | expandable_query_key = "include" 802 | 803 | Additionally, query serializer can be autogenerated to be used to either validate request query 804 | or generate documentation: 805 | 806 | .. 
code:: 807 | 808 | FooSerializer().get_query_serializer_class() 809 | FooSerializer().get_query_serializer_class(exclude=["bar"]) 810 | FooSerializer().get_query_serializer_class(disallow=["bar"]) 811 | 812 | :``exclude``: excludes given expand paths. Useful for generating documentation. 813 | :``disallow``: leaves the expand field in serializer however removes given paths from valid choices. 814 | Useful for validating user input within viewset. 815 | """ 816 | 817 | def update_attribute(self, instance, field, value): 818 | """Mark which attributes are updated so that during representation of 819 | the resource, we can expand those fields even if not explicitly asked 820 | for. 821 | 822 | Fields are marked on root serializer since child serializers 823 | should not contain any state. 824 | """ 825 | try: 826 | self.root._updated_fields.setdefault(id(self), []).append(field.field_name) 827 | except AttributeError: 828 | self.root._updated_fields = {id(self): [field.field_name]} 829 | 830 | return super().update_attribute(instance, field, value) 831 | 832 | def to_representation(self, instance): 833 | """Switch expandable fields to collapsed fields if not explicitly asked 834 | to be expanded or field was updated.""" 835 | expandable_query_key = getattr(self.Meta, "expandable_query_key", "expand") 836 | 837 | for i in self._expandable_fields: 838 | if any( 839 | [ 840 | # if no context provided usually is used by schema generation 841 | self.context is None, 842 | # path explicitly provided in request.GET to be included 843 | i.path in getattr(self.context.get("request"), "GET", QueryDict()).getlist(expandable_query_key), 844 | # path was implicitly added by query serializer 845 | i.path 846 | in (getattr(self.context.get("query_serializer"), "validated_data", None) or {}).get( 847 | expandable_query_key, () 848 | ), 849 | # field was explicitly updated so we leave it in representation 850 | i.name in getattr(self.root, "_updated_fields", {}).get(id(self), []), 851 
| ] 852 | ): 853 | continue 854 | 855 | # no reason to leave full field in representation 856 | self.fields[i.name] = i.replacement 857 | 858 | return super().to_representation(instance) 859 | 860 | @property 861 | def _expandable_fields(self): 862 | """Get all defined expandable fields with their path within 863 | serializers.""" 864 | components = [] 865 | root = self.root 866 | f = self 867 | while f is not root: 868 | if f.field_name: 869 | components.insert(0, f.field_name) 870 | f = f.parent 871 | 872 | nt = namedtuple("ExpandableField", ["name", "parts", "path", "replacement"]) 873 | 874 | for name, replacement in getattr(self.Meta, "expandable_fields", {}).items(): 875 | parts = components + [name] 876 | path = LOOKUP_SEP.join(parts) 877 | yield nt(name, parts, path, copy.deepcopy(replacement)) 878 | 879 | def _get_all_expandable_fields(self, parents, this, exclude): 880 | """Recursively search for all expandable fields on class.""" 881 | nt = namedtuple("ExpandableField", ["query_key", "parts", "path"]) 882 | 883 | query_key = getattr(getattr(this, "Meta", None), "expandable_query_key", "expand") 884 | for name in getattr(getattr(this, "Meta", None), "expandable_fields", {}): 885 | parts = parents + [name] 886 | path = LOOKUP_SEP.join(parts) 887 | if path in exclude: 888 | continue 889 | yield nt(query_key, parts, path) 890 | 891 | for field_name, field in this.fields.items(): 892 | if not isinstance(field, serializers.BaseSerializer): 893 | continue 894 | if isinstance(field, serializers.ListSerializer): 895 | field = field.child 896 | 897 | yield from self._get_all_expandable_fields(parents=parents + [field_name], this=field, exclude=exclude) 898 | 899 | def get_query_serializer_class(self, exclude=(), disallow=(), implicit_expand=True): 900 | """Generate serializer to either validate request querystring or 901 | generate documentation.""" 902 | attrs = { 903 | k: (ImplicitExpandableListField if implicit_expand else fields.ListField)( 904 | 
required=False, 905 | help_text=( 906 | "Query parameter to expand nested fields. " 907 | "Can be provided multiple times to expand multiple fields. " 908 | "Field is automatically expanded whenever it is updated." 909 | ), 910 | child=fields.ChoiceField(required=False, choices=[i.path for i in v if i.path not in disallow]), 911 | ) 912 | for k, v in groupby( 913 | self._get_all_expandable_fields(parents=[], this=self, exclude=exclude), key=lambda i: i.query_key 914 | ) 915 | } 916 | attrs["implicit_expand"] = implicit_expand 917 | return type("ExpandableQuerySerializer", (serializers.Serializer,), attrs) 918 | -------------------------------------------------------------------------------- /tests/test_serializers.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from collections import OrderedDict 3 | from decimal import Decimal 4 | 5 | from django.core.exceptions import ImproperlyConfigured, ValidationError as DjangoValidationError 6 | from django.test import SimpleTestCase 7 | 8 | from django_sorcery.db.meta import model_info 9 | 10 | from rest_framework import fields 11 | from rest_framework.exceptions import ErrorDetail, ValidationError 12 | from rest_framework.serializers import ListSerializer, Serializer 13 | from rest_framework.settings import api_settings 14 | from rest_framework.test import APIRequestFactory 15 | 16 | from rest_witchcraft.fields import HyperlinkedIdentityField 17 | from rest_witchcraft.serializers import BaseSerializer, CompositeSerializer, ExpandableModelSerializer, ModelSerializer 18 | 19 | from .models import COLORS, Engine, ModelWithJson, Option, Owner, Vehicle, VehicleOther, VehicleType, session 20 | 21 | 22 | class VehicleOwnerStubSerializer(Serializer): 23 | id = fields.IntegerField(source="_owner_id") 24 | 25 | 26 | class VehicleSerializer(ExpandableModelSerializer): 27 | class Meta: 28 | model = Vehicle 29 | session = session 30 | expandable_fields = {"owner": 
VehicleOwnerStubSerializer(source="*", read_only=True)} 31 | exclude = ["other"] 32 | extra_kwargs = {"owner": {"allow_nested_updates": True}} 33 | nested_serializer_class = ExpandableModelSerializer 34 | 35 | 36 | class TestModelSerializer(SimpleTestCase): 37 | def setUp(self): 38 | super().setUp() 39 | session.add(Owner(id=1, first_name="Test", last_name="Owner")) 40 | session.add_all( 41 | [ 42 | Option(id=1, name="Option 1"), 43 | Option(id=2, name="Option 2"), 44 | Option(id=3, name="Option 3"), 45 | Option(id=4, name="Option 4"), 46 | ] 47 | ) 48 | session.flush() 49 | self.maxDiff = None 50 | 51 | def tearDown(self): 52 | super().tearDown() 53 | session.rollback() 54 | 55 | def test_cannot_initialize_without_a_meta(self): 56 | class VehicleSerializer(ModelSerializer): 57 | pass 58 | 59 | with self.assertRaises(AttributeError): 60 | VehicleSerializer() 61 | 62 | def test_cannot_initialize_without_a_session(self): 63 | class VehicleSerializer(ModelSerializer): 64 | class Meta: 65 | pass 66 | 67 | with self.assertRaises(AssertionError): 68 | serializer = VehicleSerializer() 69 | serializer.session 70 | 71 | def test_cannot_initialize_without_a_model_with_session_meta(self): 72 | class VehicleSerializer(ModelSerializer): 73 | class Meta: 74 | session = session 75 | 76 | with self.assertRaises(AssertionError): 77 | serializer = VehicleSerializer() 78 | serializer.model 79 | 80 | def test_cannot_initialize_without_a_model_with_session_kwarg(self): 81 | class VehicleSerializer(ModelSerializer): 82 | class Meta: 83 | pass 84 | 85 | with self.assertRaises(AssertionError): 86 | serializer = VehicleSerializer(session=session) 87 | serializer.model 88 | 89 | def test_get_fields_sets_url_field_name_when_missing(self): 90 | class VehicleSerializer(ModelSerializer): 91 | class Meta: 92 | model = Vehicle 93 | session = session 94 | exclude = ("name",) 95 | 96 | serializer = VehicleSerializer() 97 | serializer.get_fields() 98 | 99 | 
self.assertEqual(serializer.url_field_name, api_settings.URL_FIELD_NAME) 100 | 101 | def test_raises_type_error_if_fields_is_not_a_list_or_tuple(self): 102 | class VehicleSerializer(ModelSerializer): 103 | class Meta: 104 | model = Vehicle 105 | session = session 106 | fields = "name" 107 | 108 | serializer = VehicleSerializer() 109 | 110 | with self.assertRaises(TypeError): 111 | serializer.get_fields() 112 | 113 | def test_raises_type_error_if_exclude_is_not_a_list_or_tuple(self): 114 | class VehicleSerializer(ModelSerializer): 115 | class Meta: 116 | model = Vehicle 117 | session = session 118 | exclude = "name" 119 | 120 | serializer = VehicleSerializer() 121 | 122 | with self.assertRaises(TypeError): 123 | serializer.get_fields() 124 | 125 | def test_get_default_field_names_should_get_all_field_names(self): 126 | class VehicleSerializer(ModelSerializer): 127 | class Meta: 128 | model = Vehicle 129 | session = session 130 | fields = ("id", "name") 131 | 132 | serializer = VehicleSerializer() 133 | info = model_info(Vehicle) 134 | field_names = serializer.get_default_field_names({}, info) 135 | self.assertEqual( 136 | set(field_names), 137 | { 138 | Vehicle.created_at.key, 139 | Vehicle.engine.key, 140 | Vehicle.id.key, 141 | Vehicle.name.key, 142 | Vehicle.options.key, 143 | Vehicle.other.key, 144 | Vehicle.owner.key, 145 | Vehicle.paint.key, 146 | Vehicle.type.key, 147 | Vehicle.is_used.key, 148 | }, 149 | ) 150 | 151 | def test_get_field_names_with_include(self): 152 | class VehicleSerializer(ModelSerializer): 153 | class Meta: 154 | model = Vehicle 155 | session = session 156 | fields = ("id", "name") 157 | 158 | serializer = VehicleSerializer() 159 | info = model_info(Vehicle) 160 | field_names = serializer.get_field_names([], info) 161 | self.assertEqual(set(field_names), {Vehicle.id.key, Vehicle.name.key}) 162 | 163 | def test_get_field_names_with_exclude(self): 164 | class VehicleSerializer(ModelSerializer): 165 | class Meta: 166 | model = Vehicle 167 | 
session = session 168 | exclude = ("type", "options") 169 | 170 | serializer = VehicleSerializer() 171 | info = model_info(Vehicle) 172 | field_names = serializer.get_field_names({}, info) 173 | self.assertEqual( 174 | set(field_names), 175 | { 176 | Vehicle.created_at.key, 177 | Vehicle.engine.key, 178 | Vehicle.id.key, 179 | Vehicle.name.key, 180 | Vehicle.other.key, 181 | Vehicle.owner.key, 182 | Vehicle.paint.key, 183 | Vehicle.is_used.key, 184 | }, 185 | ) 186 | 187 | def test_generate_all_fields(self): 188 | class VehicleSerializer(ModelSerializer): 189 | class Meta: 190 | model = Vehicle 191 | session = session 192 | fields = "__all__" 193 | 194 | serializer = VehicleSerializer() 195 | generated_fields = serializer.get_fields() 196 | 197 | self.assertIn(Vehicle.id.key, generated_fields) 198 | self.assertIn(Vehicle.type.key, generated_fields) 199 | self.assertIn(Vehicle.name.key, generated_fields) 200 | self.assertIn(Vehicle.engine.key, generated_fields) 201 | self.assertIn(Vehicle.owner.key, generated_fields) 202 | self.assertIn(Vehicle.options.key, generated_fields) 203 | 204 | self.assertFalse(generated_fields[Vehicle.options.key].read_only) 205 | 206 | def test_overwrite_extra_kwargs(self): 207 | class VehicleSerializer(ModelSerializer): 208 | class Meta: 209 | model = Vehicle 210 | session = session 211 | fields = "__all__" 212 | 213 | serializer = VehicleSerializer(extra_kwargs={Vehicle.options.key: {"read_only": True}}) 214 | generated_fields = serializer.get_fields() 215 | 216 | self.assertIn(Vehicle.id.key, generated_fields) 217 | self.assertIn(Vehicle.type.key, generated_fields) 218 | self.assertIn(Vehicle.name.key, generated_fields) 219 | self.assertIn(Vehicle.engine.key, generated_fields) 220 | self.assertIn(Vehicle.owner.key, generated_fields) 221 | self.assertIn(Vehicle.options.key, generated_fields) 222 | 223 | self.assertTrue(generated_fields[Vehicle.options.key].read_only) 224 | 225 | def test_overwrite_fields_exlude(self): 226 | class 
VehicleSerializer(ModelSerializer): 227 | class Meta: 228 | model = Vehicle 229 | session = session 230 | fields = "__all__" 231 | 232 | serializer = VehicleSerializer(fields=None, exclude=["options"]) 233 | generated_fields = serializer.get_fields() 234 | 235 | self.assertIn(Vehicle.id.key, generated_fields) 236 | self.assertIn(Vehicle.type.key, generated_fields) 237 | self.assertIn(Vehicle.name.key, generated_fields) 238 | self.assertIn(Vehicle.engine.key, generated_fields) 239 | self.assertIn(Vehicle.owner.key, generated_fields) 240 | self.assertNotIn(Vehicle.options.key, generated_fields) 241 | 242 | def test_declared_field(self): 243 | class VehicleSerializer(ModelSerializer): 244 | name = fields.ChoiceField(choices=["a", "b"]) 245 | 246 | class Meta: 247 | model = Vehicle 248 | session = session 249 | fields = "__all__" 250 | 251 | serializer = VehicleSerializer() 252 | generated_fields = serializer.get_fields() 253 | 254 | self.assertIsInstance(generated_fields["name"], fields.ChoiceField) 255 | 256 | def test_get_field_names_includes_all_required_fields(self): 257 | class VehicleSerializer(ModelSerializer): 258 | class Meta: 259 | model = Vehicle 260 | session = session 261 | fields = ("id", "name") 262 | 263 | serializer = VehicleSerializer() 264 | info = model_info(Vehicle) 265 | 266 | with self.assertRaises(AssertionError): 267 | serializer.get_field_names(["type"], info) 268 | 269 | def test_include_extra_kwargs(self): 270 | serializer = BaseSerializer() 271 | 272 | kwargs = {} 273 | extra_kwargs = {} 274 | 275 | kwargs = serializer.include_extra_kwargs(kwargs, extra_kwargs) 276 | 277 | self.assertEqual(kwargs, {}) 278 | 279 | def test_include_extra_kwargs_filter_when_read_only(self): 280 | serializer = BaseSerializer() 281 | 282 | kwargs = { 283 | "allow_blank": True, 284 | "allow_null": True, 285 | "default": True, 286 | "max_length": 255, 287 | "max_value": 255, 288 | "min_length": 0, 289 | "min_value": 0, 290 | "queryset": None, 291 | "required": 
True, 292 | "validators": None, 293 | } 294 | extra_kwargs = {"read_only": True} 295 | 296 | kwargs = serializer.include_extra_kwargs(kwargs, extra_kwargs) 297 | 298 | self.assertEqual(kwargs, {"read_only": True}) 299 | 300 | def test_include_extra_kwargs_filter_required_when_default_provided(self): 301 | serializer = BaseSerializer() 302 | 303 | kwargs = {"required": False} 304 | extra_kwargs = {"default": True} 305 | 306 | kwargs = serializer.include_extra_kwargs(kwargs, extra_kwargs) 307 | 308 | self.assertEqual(kwargs, {"default": True}) 309 | 310 | def test_base_serializer_raises_on_create(self): 311 | serializer = BaseSerializer() 312 | 313 | with self.assertRaises(NotImplementedError): 314 | serializer.create({}) 315 | 316 | def test_base_serializer_raises_on_update(self): 317 | serializer = BaseSerializer() 318 | 319 | with self.assertRaises(NotImplementedError): 320 | serializer.update(None, {}) 321 | 322 | def test_get_extra_kwargs_with_no_extra_kwargs(self): 323 | class VehicleSerializer(ModelSerializer): 324 | class Meta: 325 | model = Vehicle 326 | session = session 327 | fields = ("id", "name") 328 | 329 | serializer = VehicleSerializer() 330 | extra_kwargs = serializer.get_extra_kwargs() 331 | self.assertEqual(extra_kwargs, {}) 332 | 333 | def test_get_extra_kwargs_with_extra_kwargs(self): 334 | class VehicleSerializer(ModelSerializer): 335 | class Meta: 336 | model = Vehicle 337 | session = session 338 | fields = ("id", "name") 339 | extra_kwargs = {"name": {"read_only": True}} 340 | 341 | serializer = VehicleSerializer() 342 | extra_kwargs = serializer.get_extra_kwargs() 343 | self.assertEqual(extra_kwargs, {"name": {"read_only": True}}) 344 | 345 | def test_get_extra_kwargs_with_read_only_fields(self): 346 | class VehicleSerializer(ModelSerializer): 347 | class Meta: 348 | model = Vehicle 349 | session = session 350 | fields = ("id", "name") 351 | read_only_fields = ("id", "name") 352 | 353 | serializer = VehicleSerializer() 354 | extra_kwargs = 
serializer.get_extra_kwargs() 355 | self.assertEqual(extra_kwargs, {"id": {"read_only": True}, "name": {"read_only": True}}) 356 | 357 | def test_get_extra_kwargs_with_read_only_fields_as_string(self): 358 | class VehicleSerializer(ModelSerializer): 359 | class Meta: 360 | model = Vehicle 361 | session = session 362 | fields = ("id", "name") 363 | read_only_fields = "id" 364 | 365 | with self.assertRaises(TypeError): 366 | VehicleSerializer() 367 | 368 | def test_build_standard_integer_field(self): 369 | class VehicleSerializer(ModelSerializer): 370 | class Meta: 371 | model = Vehicle 372 | session = session 373 | fields = "__all__" 374 | 375 | serializer = VehicleSerializer() 376 | info = model_info(Vehicle) 377 | field = serializer.build_field(Vehicle.id.key, info, Vehicle, 0) 378 | 379 | self.assertEqual(field.help_text, Vehicle.id.doc) 380 | self.assertEqual(field.label, "Id") 381 | self.assertFalse(field.allow_null) 382 | self.assertIsInstance(field, fields.IntegerField) 383 | self.assertFalse(field.required) 384 | 385 | def test_build_standard_char_field(self): 386 | class VehicleSerializer(ModelSerializer): 387 | class Meta: 388 | model = Vehicle 389 | session = session 390 | fields = "__all__" 391 | 392 | serializer = VehicleSerializer() 393 | info = model_info(Vehicle) 394 | field = serializer.build_field(Vehicle.name.key, info, Vehicle, 0) 395 | 396 | self.assertEqual(field.help_text, Vehicle.name.doc) 397 | self.assertEqual(field.label, "Name") 398 | self.assertFalse(field.required) 399 | self.assertIsInstance(field, fields.CharField) 400 | self.assertTrue(field.allow_null) 401 | 402 | def test_build_enum_field(self): 403 | class VehicleSerializer(ModelSerializer): 404 | class Meta: 405 | model = Vehicle 406 | session = session 407 | fields = ("type",) 408 | 409 | serializer = VehicleSerializer() 410 | info = model_info(Vehicle) 411 | field = serializer.build_field(Vehicle.type.key, info, Vehicle, 0) 412 | 413 | self.assertEqual(field.help_text, 
Vehicle.type.doc) 414 | self.assertEqual(field.label, "Type") 415 | self.assertTrue(field.required) 416 | self.assertIsInstance(field, fields.ChoiceField) 417 | self.assertFalse(field.allow_null) 418 | 419 | def test_build_choice_field(self): 420 | class VehicleSerializer(ModelSerializer): 421 | class Meta: 422 | model = Vehicle 423 | session = session 424 | fields = ("paint",) 425 | 426 | serializer = VehicleSerializer() 427 | info = model_info(Vehicle) 428 | field = serializer.build_field(Vehicle.paint.key, info, Vehicle, 0) 429 | 430 | self.assertEqual(field.help_text, Vehicle.paint.doc) 431 | self.assertEqual(field.label, "Paint") 432 | self.assertFalse(field.required) 433 | self.assertIsInstance(field, fields.ChoiceField) 434 | self.assertEqual(field.choices, OrderedDict([(color, color) for color in COLORS])) 435 | self.assertTrue(field.allow_null) 436 | 437 | def test_fail_when_a_field_type_not_found(self): 438 | class JSSerializer(ModelSerializer): 439 | class Meta: 440 | model = ModelWithJson 441 | session = session 442 | fields = ("js",) 443 | 444 | serializer = JSSerializer() 445 | with self.assertRaises(KeyError) as e: 446 | serializer.fields 447 | 448 | self.assertEqual(e.exception.args, ("Could not figure out type for attribute 'ModelWithJson.js'",)) 449 | 450 | def test_build_url_field(self): 451 | class VehicleSerializer(ModelSerializer): 452 | class Meta: 453 | model = Vehicle 454 | session = session 455 | exclude = ("name",) 456 | 457 | def get_default_field_names(self, declared_fields, info): 458 | return super().get_default_field_names(declared_fields, info) + ["url"] 459 | 460 | serializer = VehicleSerializer() 461 | fields = serializer.get_fields() 462 | self.assertIn("url", fields) 463 | 464 | url_field = fields.get("url") 465 | 466 | self.assertIsInstance(url_field, HyperlinkedIdentityField) 467 | self.assertEqual(url_field.view_name, "vehicle-detail") 468 | 469 | def test_build_composite_field(self): 470 | class 
VehicleSerializer(ModelSerializer): 471 | class Meta: 472 | model = Vehicle 473 | session = session 474 | fields = "__all__" 475 | 476 | serializer = VehicleSerializer() 477 | info = model_info(Vehicle) 478 | field = serializer.build_field(Vehicle.engine.key, info, Vehicle, 0) 479 | 480 | self.assertIsInstance(field, CompositeSerializer) 481 | self.assertEqual(len(field.fields), 4) 482 | 483 | def test_deepcopy_composite_field(self): 484 | class EngineSerializer(CompositeSerializer): 485 | pass 486 | 487 | serializer = EngineSerializer(composite=Vehicle.engine) 488 | 489 | clone = copy.deepcopy(serializer) 490 | 491 | self.assertNotEqual(id(serializer), id(clone)) 492 | self.assertEqual(serializer._args, clone._args) 493 | self.assertDictEqual(serializer._kwargs, clone._kwargs) 494 | 495 | def test_build_property_field(self): 496 | class VehicleSerializer(ModelSerializer): 497 | class Meta: 498 | model = Vehicle 499 | session = session 500 | fields = "__all__" 501 | 502 | serializer = VehicleSerializer() 503 | info = model_info(Vehicle) 504 | field = serializer.build_field("lower_name", info, Vehicle, 0) 505 | 506 | self.assertIsInstance(field, fields.ReadOnlyField) 507 | 508 | def test_build_unknows_field(self): 509 | class VehicleSerializer(ModelSerializer): 510 | class Meta: 511 | model = Vehicle 512 | session = session 513 | fields = "__all__" 514 | 515 | serializer = VehicleSerializer() 516 | info = model_info(Vehicle) 517 | 518 | with self.assertRaises(ImproperlyConfigured): 519 | serializer.build_field("abcde", info, Vehicle, 0) 520 | 521 | def test_build_one_to_many_relationship_field(self): 522 | class VehicleSerializer(ModelSerializer): 523 | class Meta: 524 | model = Vehicle 525 | session = session 526 | fields = "__all__" 527 | 528 | serializer = VehicleSerializer() 529 | info = model_info(Vehicle) 530 | nested_serializer = serializer.build_field(Vehicle.owner.key, info, Vehicle, 0) 531 | 532 | self.assertIsNotNone(nested_serializer) 533 | 
self.assertIsInstance(nested_serializer, ModelSerializer) 534 | self.assertEqual(len(nested_serializer.fields), 3) 535 | 536 | def test_build_one_to_many_relationship_field_with_nested_updates_disabled(self): 537 | class VehicleSerializer(ModelSerializer): 538 | class Meta: 539 | model = Vehicle 540 | session = session 541 | fields = "__all__" 542 | extra_kwargs = {"owner": {"allow_nested_updates": False}} 543 | 544 | serializer = VehicleSerializer() 545 | info = model_info(Vehicle) 546 | nested_serializer = serializer.build_field(Vehicle.owner.key, info, Vehicle, 0) 547 | 548 | self.assertIsNotNone(nested_serializer) 549 | self.assertIsInstance(nested_serializer, ModelSerializer) 550 | self.assertEqual(len(nested_serializer.fields), 3) 551 | self.assertTrue(nested_serializer.fields["first_name"].read_only) 552 | self.assertTrue(nested_serializer.fields["last_name"].read_only) 553 | 554 | def test_generated_nested_serializer_get_session_from_parent(self): 555 | class VehicleSerializer(ModelSerializer): 556 | class Meta: 557 | model = Vehicle 558 | fields = "__all__" 559 | depth = 3 560 | 561 | serializer = VehicleSerializer(context={"session": session}) 562 | 563 | owner_serializer = serializer.fields["owner"] 564 | 565 | self.assertEqual(owner_serializer.session, session) 566 | 567 | def test_declared_nested_serializer_get_session_from_context(self): 568 | class OwnerSerializer(ModelSerializer): 569 | class Meta: 570 | model = Owner 571 | fields = "__all__" 572 | 573 | class VehicleSerializer(ModelSerializer): 574 | Owner = OwnerSerializer() 575 | 576 | class Meta: 577 | model = Vehicle 578 | fields = "__all__" 579 | 580 | serializer = VehicleSerializer(context={"session": session}) 581 | 582 | owner_serializer = serializer.fields["owner"] 583 | 584 | self.assertEqual(owner_serializer.session, session) 585 | 586 | def test_declared_nested_serializer_get_session_from_root_meta(self): 587 | class OwnerSerializer(ModelSerializer): 588 | class Meta: 589 | model = 
Owner 590 | fields = "__all__" 591 | 592 | class VehicleSerializer(ModelSerializer): 593 | Owner = OwnerSerializer() 594 | 595 | class Meta: 596 | model = Vehicle 597 | session = session 598 | fields = "__all__" 599 | 600 | serializer = VehicleSerializer() 601 | 602 | owner_serializer = serializer.fields["owner"] 603 | 604 | self.assertEqual(owner_serializer.session, session) 605 | 606 | def test_build_serializer_with_depth(self): 607 | class VehicleSerializer(ModelSerializer): 608 | class Meta: 609 | model = Vehicle 610 | session = session 611 | fields = "__all__" 612 | depth = 3 613 | 614 | serializer = VehicleSerializer() 615 | 616 | self.assertEqual(len(serializer.fields), 10) 617 | self.assertEqual( 618 | set(serializer.fields.keys()), 619 | { 620 | Vehicle.created_at.key, 621 | Vehicle.engine.key, 622 | Vehicle.id.key, 623 | Vehicle.name.key, 624 | Vehicle.options.key, 625 | Vehicle.other.key, 626 | Vehicle.owner.key, 627 | Vehicle.paint.key, 628 | Vehicle.type.key, 629 | Vehicle.is_used.key, 630 | }, 631 | ) 632 | 633 | engine_serializer = serializer.fields["engine"] 634 | self.assertEqual(len(engine_serializer.fields), 4) 635 | self.assertEqual(set(engine_serializer.fields.keys()), {"type_", "displacement", "fuel_type", "cylinders"}) 636 | 637 | owner_serializer = serializer.fields["owner"] 638 | self.assertEqual(len(owner_serializer.fields), 3) 639 | self.assertEqual(set(owner_serializer.fields.keys()), {"id", "first_name", "last_name"}) 640 | self.assertEqual({f.label for f in owner_serializer.fields.values()}, {"Id", "First name", "Last name"}) 641 | 642 | options_serializer = serializer.fields["options"] 643 | self.assertTrue(options_serializer.many) 644 | self.assertIsInstance(options_serializer, ListSerializer) 645 | 646 | option_serializer = options_serializer.child 647 | self.assertEqual(len(option_serializer.fields), 2) 648 | self.assertEqual(set(option_serializer.fields.keys()), {"id", "name"}) 649 | 650 | def 
test_serializer_zero_depth_invalid_error_message(self): 651 | class VehicleSerializer(ModelSerializer): 652 | class Meta: 653 | model = Vehicle 654 | session = session 655 | fields = "__all__" 656 | 657 | serializer = VehicleSerializer(data={}) 658 | 659 | self.assertFalse(serializer.is_valid()) 660 | 661 | self.assertDictEqual(dict(serializer.errors), {"type": ["This field is required."]}) 662 | 663 | def test_serializer_zero_depth_post_basic_validation(self): 664 | class VehicleSerializer(ModelSerializer): 665 | class Meta: 666 | model = Vehicle 667 | session = session 668 | fields = "__all__" 669 | extra_kwargs = {"other": {"required": False}} 670 | 671 | data = { 672 | "name": "Test vehicle", 673 | "one": "Two", 674 | "type": "bus", 675 | "engine": {"displacement": 1234, "cylinders": 4}, 676 | "owner": {"id": 1}, 677 | "options": [], 678 | } 679 | serializer = VehicleSerializer(data=data) 680 | 681 | self.assertTrue(serializer.is_valid(), serializer.errors) 682 | 683 | self.assertDictEqual( 684 | dict(serializer.validated_data), 685 | { 686 | "name": "Test vehicle", 687 | "type": VehicleType.bus, 688 | "engine": {"displacement": Decimal("1234.00"), "cylinders": 4}, 689 | "owner": {"id": 1}, 690 | "options": [], 691 | }, 692 | ) 693 | 694 | def test_serializer_create(self): 695 | class VehicleSerializer(ModelSerializer): 696 | class Meta: 697 | model = Vehicle 698 | session = session 699 | fields = "__all__" 700 | extra_kwargs = {"other": {"required": False, "allow_create": False}} 701 | 702 | data = { 703 | "name": "Test vehicle", 704 | "one": "Two", 705 | "type": "bus", 706 | "engine": {"displacement": 1234, "cylinders": 4}, 707 | "owner": {"id": 1}, 708 | "options": [], 709 | } 710 | 711 | serializer = VehicleSerializer(data=data) 712 | 713 | self.assertTrue(serializer.is_valid(), serializer.errors) 714 | 715 | vehicle = serializer.save() 716 | 717 | self.assertEqual(vehicle.name, data["name"]) 718 | self.assertEqual(vehicle.type, VehicleType.bus) 719 | 
self.assertEqual(vehicle.engine.cylinders, data["engine"]["cylinders"]) 720 | self.assertEqual(vehicle.engine.displacement, data["engine"]["displacement"]) 721 | self.assertIsNone(vehicle.engine.fuel_type) 722 | self.assertIsNone(vehicle.engine.type_) 723 | self.assertEqual(vehicle.owner.id, data["owner"]["id"]) 724 | self.assertEqual(vehicle.owner.first_name, "Test") 725 | self.assertEqual(vehicle.owner.last_name, "Owner") 726 | self.assertEqual(vehicle.options, data["options"]) 727 | 728 | def test_serializer_create_diff_field_source(self): 729 | class VehicleSerializer(ModelSerializer): 730 | class Meta: 731 | model = Vehicle 732 | session = session 733 | fields = "__all__" 734 | extra_kwargs = {"other": {"required": False, "allow_create": False}} 735 | 736 | def get_fields(self): 737 | fields = super().get_fields() 738 | fields["vehicle_type"] = fields.pop("type") 739 | fields["vehicle_type"].source = "type" 740 | return fields 741 | 742 | data = { 743 | "name": "Test vehicle", 744 | "one": "Two", 745 | "vehicle_type": "Bus", 746 | "engine": {"displacement": 1234, "cylinders": 4}, 747 | "owner": {"id": 1}, 748 | "options": [], 749 | } 750 | 751 | serializer = VehicleSerializer(data=data) 752 | 753 | self.assertTrue(serializer.is_valid(), serializer.errors) 754 | 755 | vehicle = serializer.save() 756 | 757 | self.assertEqual(vehicle.name, data["name"]) 758 | self.assertEqual(vehicle.type, VehicleType(data["vehicle_type"])) 759 | self.assertEqual(vehicle.engine.cylinders, data["engine"]["cylinders"]) 760 | self.assertEqual(vehicle.engine.displacement, data["engine"]["displacement"]) 761 | self.assertIsNone(vehicle.engine.fuel_type) 762 | self.assertIsNone(vehicle.engine.type_) 763 | self.assertEqual(vehicle.owner.id, data["owner"]["id"]) 764 | self.assertEqual(vehicle.owner.first_name, "Test") 765 | self.assertEqual(vehicle.owner.last_name, "Owner") 766 | self.assertEqual(vehicle.options, data["options"]) 767 | 768 | def 
test_serializer_create_model_validations(self): 769 | class VehicleSerializer(ModelSerializer): 770 | class Meta: 771 | model = Vehicle 772 | session = session 773 | fields = "__all__" 774 | extra_kwargs = {"other": {"required": False, "allow_create": False}} 775 | 776 | data = { 777 | "name": "invalid", 778 | "one": "Two", 779 | "type": "Bus", 780 | "engine": {"displacement": 1234, "cylinders": 4}, 781 | "owner": {"id": 1}, 782 | "options": [], 783 | } 784 | 785 | serializer = VehicleSerializer(data=data) 786 | 787 | self.assertTrue(serializer.is_valid(), serializer.errors) 788 | 789 | with self.assertRaises(ValidationError) as e: 790 | serializer.save() 791 | 792 | self.assertDictEqual(e.exception.detail, {"name": ["invalid vehicle name"]}) 793 | 794 | def test_serializer_create_star_source(self): 795 | class BasicSerializer(ModelSerializer): 796 | class Meta: 797 | model = Vehicle 798 | session = session 799 | fields = ["name", "type"] 800 | 801 | class VehicleSerializer(ModelSerializer): 802 | basic = BasicSerializer(source="*", allow_nested_updates=True) 803 | 804 | class Meta: 805 | model = Vehicle 806 | session = session 807 | exclude = ["name", "type"] 808 | extra_kwargs = {"other": {"required": False, "allow_create": False}} 809 | 810 | data = { 811 | "basic": {"name": "Test vehicle", "type": "Bus"}, 812 | "one": "Two", 813 | "engine": {"displacement": 1234, "cylinders": 4}, 814 | "owner": {"id": 1}, 815 | "options": [], 816 | } 817 | 818 | serializer = VehicleSerializer(data=data) 819 | 820 | self.assertTrue(serializer.is_valid(), serializer.errors) 821 | 822 | vehicle = serializer.save() 823 | 824 | self.assertEqual(vehicle.name, data["basic"]["name"]) 825 | self.assertEqual(vehicle.type, VehicleType(data["basic"]["type"])) 826 | self.assertEqual(vehicle.engine.cylinders, data["engine"]["cylinders"]) 827 | self.assertEqual(vehicle.engine.displacement, data["engine"]["displacement"]) 828 | self.assertIsNone(vehicle.engine.fuel_type) 829 | 
self.assertIsNone(vehicle.engine.type_) 830 | self.assertEqual(vehicle.owner.id, data["owner"]["id"]) 831 | self.assertEqual(vehicle.owner.first_name, "Test") 832 | self.assertEqual(vehicle.owner.last_name, "Owner") 833 | self.assertEqual(vehicle.options, data["options"]) 834 | 835 | def test_post_update(self): 836 | vehicle = Vehicle( 837 | name="Test vehicle", 838 | type=VehicleType.bus, 839 | engine=Engine(4, 1234, None, None), 840 | owner=session.query(Owner).get(1), 841 | ) 842 | 843 | class VehicleSerializer(ModelSerializer): 844 | class Meta: 845 | model = Vehicle 846 | session = session 847 | fields = "__all__" 848 | extra_kwargs = {"other": {"required": False, "allow_create": True, "allow_nested_updates": True}} 849 | 850 | data = { 851 | "name": "Another test vechicle", 852 | "one": "Two", 853 | "type": "Car", 854 | "engine": {"displacement": 4321, "cylinders": 2, "type_": "banana", "fuel_type": "petrol"}, 855 | "owner": {"id": 1}, 856 | "options": [], 857 | "other": {"advertising_cost": 4321}, 858 | } 859 | 860 | serializer = VehicleSerializer(instance=vehicle, data=data) 861 | 862 | self.assertTrue(serializer.is_valid(), serializer.errors) 863 | 864 | vehicle = serializer.save() 865 | 866 | self.assertEqual(vehicle.name, data["name"]) 867 | self.assertEqual(vehicle.type, VehicleType(data["type"])) 868 | self.assertEqual(vehicle.engine.cylinders, data["engine"]["cylinders"]) 869 | self.assertEqual(vehicle.engine.displacement, data["engine"]["displacement"]) 870 | self.assertEqual(vehicle.engine.fuel_type, data["engine"]["fuel_type"]) 871 | self.assertEqual(vehicle.engine.type_, data["engine"]["type_"]) 872 | self.assertEqual(vehicle.owner.id, data["owner"]["id"]) 873 | self.assertEqual(vehicle.owner.first_name, "Test") 874 | self.assertEqual(vehicle.owner.last_name, "Owner") 875 | self.assertEqual(vehicle.options, data["options"]) 876 | self.assertEqual(vehicle.other.advertising_cost, 4321) 877 | 878 | def test_post_update_remove_composite(self): 879 | 
vehicle = Vehicle( 880 | name="Test vehicle", 881 | type=VehicleType.bus, 882 | engine=Engine(4, 1234, None, None), 883 | owner=session.query(Owner).get(1), 884 | ) 885 | 886 | class VehicleSerializer(ModelSerializer): 887 | class Meta: 888 | model = Vehicle 889 | session = session 890 | fields = "__all__" 891 | extra_kwargs = { 892 | "other": {"required": False, "allow_create": True, "allow_nested_updates": True}, 893 | "engine": {"required": False, "allow_null": True}, 894 | } 895 | 896 | data = { 897 | "name": "Another test vechicle", 898 | "one": "Two", 899 | "type": "Car", 900 | "engine": None, 901 | "owner": {"id": 1}, 902 | "options": [], 903 | "other": {"advertising_cost": 4321}, 904 | } 905 | 906 | serializer = VehicleSerializer(instance=vehicle, data=data) 907 | 908 | self.assertTrue(serializer.is_valid(), serializer.errors) 909 | 910 | vehicle = serializer.save() 911 | 912 | self.assertEqual(vehicle.name, data["name"]) 913 | self.assertEqual(vehicle.type, VehicleType(data["type"])) 914 | self.assertEqual(vehicle.owner.id, data["owner"]["id"]) 915 | self.assertEqual(vehicle.owner.first_name, "Test") 916 | self.assertEqual(vehicle.owner.last_name, "Owner") 917 | self.assertEqual(vehicle.options, data["options"]) 918 | self.assertEqual(vehicle.other.advertising_cost, 4321) 919 | self.assertIsNone(vehicle.engine.cylinders) 920 | self.assertIsNone(vehicle.engine.displacement) 921 | self.assertIsNone(vehicle.engine.fuel_type) 922 | self.assertIsNone(vehicle.engine.type_) 923 | 924 | def test_patch_update(self): 925 | vehicle = Vehicle( 926 | name="Test vehicle", 927 | type=VehicleType.bus, 928 | engine=Engine(4, 1234, None, None), 929 | owner=session.query(Owner).get(1), 930 | other=VehicleOther(advertising_cost=4321), 931 | ) 932 | 933 | class VehicleSerializer(ModelSerializer): 934 | class Meta: 935 | model = Vehicle 936 | session = session 937 | fields = "__all__" 938 | extra_kwargs = {"other": {"required": False, "allow_create": True, 
"allow_nested_updates": True}} 939 | 940 | data = {"other": {"advertising_cost": 1234}} 941 | 942 | serializer = VehicleSerializer(instance=vehicle, data=data, partial=True) 943 | 944 | self.assertTrue(serializer.is_valid(), serializer.errors) 945 | 946 | vehicle = serializer.save() 947 | 948 | self.assertEqual(vehicle.other.advertising_cost, data["other"]["advertising_cost"]) 949 | 950 | def test_patch_update_with_nested_id(self): 951 | vehicle = Vehicle( 952 | name="Test vehicle", 953 | type=VehicleType.bus, 954 | engine=Engine(4, 1234, None, None), 955 | owner=session.query(Owner).get(1), 956 | other=VehicleOther(advertising_cost=4321), 957 | ) 958 | session.add(vehicle) 959 | session.flush() 960 | 961 | other = vehicle.other 962 | 963 | class VehicleSerializer(ModelSerializer): 964 | class Meta: 965 | model = Vehicle 966 | session = session 967 | fields = ("other",) 968 | extra_kwargs = {"other": {"allow_nested_updates": True}} 969 | 970 | data = {"other": {"id": vehicle.other.id, "advertising_cost": 1234}} 971 | 972 | serializer = VehicleSerializer(instance=vehicle, data=data, partial=True) 973 | 974 | self.assertTrue(serializer.is_valid(), serializer.errors) 975 | 976 | vehicle = serializer.save() 977 | 978 | self.assertEqual(vehicle.other.advertising_cost, data["other"]["advertising_cost"]) 979 | self.assertEqual(vehicle.other, other) 980 | 981 | def test_patch_update_nested_set_null(self): 982 | vehicle = Vehicle( 983 | name="Test vehicle", 984 | type=VehicleType.bus, 985 | engine=Engine(4, 1234, None, None), 986 | owner=session.query(Owner).get(1), 987 | other=VehicleOther(advertising_cost=4321), 988 | ) 989 | 990 | class VehicleSerializer(ModelSerializer): 991 | class Meta: 992 | model = Vehicle 993 | session = session 994 | fields = ("other",) 995 | extra_kwargs = {"other": {"allow_create": True, "allow_null": True}} 996 | 997 | data = {"other": None} 998 | 999 | serializer = VehicleSerializer(instance=vehicle, data=data, partial=True) 1000 | 1001 | 
self.assertTrue(serializer.is_valid(), serializer.errors) 1002 | 1003 | vehicle = serializer.save() 1004 | 1005 | self.assertIsNone(vehicle.other) 1006 | 1007 | def test_patch_update_nested_set_null_allow_null_false(self): 1008 | vehicle = Vehicle( 1009 | name="Test vehicle", 1010 | type=VehicleType.bus, 1011 | engine=Engine(4, 1234, None, None), 1012 | owner=session.query(Owner).get(1), 1013 | other=VehicleOther(advertising_cost=4321), 1014 | ) 1015 | 1016 | class VehicleSerializer(ModelSerializer): 1017 | class Meta: 1018 | model = Vehicle 1019 | session = session 1020 | fields = ("other",) 1021 | extra_kwargs = {"other": {"allow_create": True, "allow_null": False}} 1022 | 1023 | data = {"other": None} 1024 | 1025 | serializer = VehicleSerializer(instance=vehicle, data=data, partial=True) 1026 | 1027 | self.assertFalse(serializer.is_valid(), serializer.errors) 1028 | 1029 | def test_patch_update_nested_set_null_allow_create_false(self): 1030 | vehicle = Vehicle( 1031 | name="Test vehicle", 1032 | type=VehicleType.bus, 1033 | engine=Engine(4, 1234, None, None), 1034 | owner=session.query(Owner).get(1), 1035 | other=VehicleOther(advertising_cost=4321), 1036 | ) 1037 | 1038 | class VehicleSerializer(ModelSerializer): 1039 | class Meta: 1040 | model = Vehicle 1041 | session = session 1042 | fields = ("other",) 1043 | extra_kwargs = {"other": {"allow_create": False, "allow_null": True}} 1044 | 1045 | data = {"other": None} 1046 | 1047 | serializer = VehicleSerializer(instance=vehicle, data=data, partial=True) 1048 | 1049 | self.assertTrue(serializer.is_valid(), serializer.errors) 1050 | 1051 | vehicle = serializer.save() 1052 | 1053 | self.assertIsNone(vehicle.other) 1054 | 1055 | def test_composite_serializer_can_create(self): 1056 | class EngineSerializer(CompositeSerializer): 1057 | class Meta: 1058 | composite = Vehicle.engine 1059 | 1060 | data = {"cylinders": 2, "displacement": 1234, "fuel_type": "petrol", "type_": "banana"} 1061 | 1062 | serializer = 
EngineSerializer(data=data) 1063 | self.assertTrue(serializer.is_valid(), serializer.errors) 1064 | 1065 | engine = serializer.save() 1066 | 1067 | self.assertIsInstance(engine, Engine) 1068 | self.assertEqual(engine.cylinders, 2) 1069 | self.assertEqual(engine.displacement, 1234) 1070 | self.assertEqual(engine.fuel_type, "petrol") 1071 | self.assertEqual(engine.type_, "banana") 1072 | 1073 | def test_composite_serializer_can_update(self): 1074 | class EngineSerializer(CompositeSerializer): 1075 | class Meta: 1076 | composite = Vehicle.engine 1077 | 1078 | data = {"cylinders": 2, "displacement": 1234, "fuel_type": "diesel", "type_": "banana"} 1079 | engine = Engine(4, 2345, "apple", "petrol") 1080 | 1081 | serializer = EngineSerializer(engine, data=data) 1082 | self.assertTrue(serializer.is_valid(), serializer.errors) 1083 | 1084 | engine = serializer.save() 1085 | 1086 | self.assertIsInstance(engine, Engine) 1087 | self.assertEqual(engine.cylinders, 2) 1088 | self.assertEqual(engine.displacement, 1234) 1089 | self.assertEqual(engine.fuel_type, "diesel") 1090 | self.assertEqual(engine.type_, "banana") 1091 | 1092 | def test_composite_serializer_can_update_patch(self): 1093 | class EngineSerializer(CompositeSerializer): 1094 | class Meta: 1095 | composite = Vehicle.engine 1096 | 1097 | data = {"cylinders": 2} 1098 | engine = Engine(4, 2345, "apple", "petrol") 1099 | 1100 | serializer = EngineSerializer(engine, data=data, partial=True) 1101 | self.assertTrue(serializer.is_valid(), serializer.errors) 1102 | 1103 | engine = serializer.save() 1104 | 1105 | self.assertIsInstance(engine, Engine) 1106 | self.assertEqual(engine.cylinders, 2) 1107 | self.assertEqual(engine.displacement, 2345) 1108 | self.assertEqual(engine.fuel_type, "petrol") 1109 | self.assertEqual(engine.type_, "apple") 1110 | 1111 | def test_composite_serializer_can_use_custom_setter(self): 1112 | class EngineSerializer(CompositeSerializer): 1113 | class Meta: 1114 | composite = Vehicle.engine 1115 | 
1116 | def set_cylinders(self, instance, field, value): 1117 | self.called = True 1118 | instance.cylinders = value 1119 | 1120 | data = {"cylinders": 2} 1121 | engine = Engine(4, 2345, "apple", "petrol") 1122 | 1123 | serializer = EngineSerializer(engine, data=data, partial=True) 1124 | self.assertTrue(serializer.is_valid(), serializer.errors) 1125 | 1126 | serializer.save() 1127 | 1128 | self.assertTrue(serializer.called) 1129 | 1130 | def test_composite_serializer_can_handle_errors_during_update(self): 1131 | class EngineSerializer(CompositeSerializer): 1132 | class Meta: 1133 | composite = Vehicle.engine 1134 | 1135 | def set_cylinders(self, instance, field, value): 1136 | raise AssertionError("Some error") 1137 | 1138 | data = {"cylinders": 2} 1139 | engine = Engine(4, 2345, "apple", "petrol") 1140 | 1141 | serializer = EngineSerializer(engine, data=data, partial=True) 1142 | self.assertTrue(serializer.is_valid(), serializer.errors) 1143 | 1144 | with self.assertRaises(ValidationError): 1145 | serializer.save() 1146 | 1147 | def test_patch_update_to_list_with_empty_list_clears_it(self): 1148 | vehicle = Vehicle( 1149 | name="Test vehicle", 1150 | type=VehicleType.bus, 1151 | engine=Engine(4, 1234, None, None), 1152 | owner=session.query(Owner).get(1), 1153 | other=VehicleOther(advertising_cost=4321), 1154 | options=session.query(Option).all(), 1155 | ) 1156 | 1157 | class VehicleSerializer(ModelSerializer): 1158 | class Meta: 1159 | model = Vehicle 1160 | session = session 1161 | fields = ("options",) 1162 | 1163 | data = {"options": []} 1164 | 1165 | serializer = VehicleSerializer(instance=vehicle, data=data, partial=True) 1166 | 1167 | self.assertTrue(serializer.is_valid(), serializer.errors) 1168 | 1169 | vehicle = serializer.save() 1170 | 1171 | self.assertEqual(len(vehicle.options), 0) 1172 | 1173 | def test_patch_update_to_list_with_new_list(self): 1174 | vehicle = Vehicle( 1175 | name="Test vehicle", 1176 | type=VehicleType.bus, 1177 | engine=Engine(4, 
1234, None, None), 1178 | owner=session.query(Owner).get(1), 1179 | other=VehicleOther(advertising_cost=4321), 1180 | options=session.query(Option).filter(Option.id.in_([1, 2])).all(), 1181 | ) 1182 | 1183 | class VehicleSerializer(ModelSerializer): 1184 | class Meta: 1185 | model = Vehicle 1186 | session = session 1187 | fields = ("options",) 1188 | 1189 | data = {"options": [{"id": 3}, {"id": 4}]} 1190 | 1191 | serializer = VehicleSerializer(instance=vehicle, data=data, partial=True) 1192 | 1193 | self.assertTrue(serializer.is_valid(), serializer.errors) 1194 | 1195 | vehicle = serializer.update(vehicle, serializer.validated_data) 1196 | 1197 | self.assertEqual(len(vehicle.options), 2) 1198 | self.assertEqual({v.id for v in vehicle.options}, {3, 4}) 1199 | 1200 | def test_patch_update_to_list_with_new_list_with_allow_create(self): 1201 | vehicle = Vehicle( 1202 | name="Test vehicle", 1203 | type=VehicleType.bus, 1204 | engine=Engine(4, 1234, None, None), 1205 | owner=session.query(Owner).get(1), 1206 | other=VehicleOther(advertising_cost=4321), 1207 | options=session.query(Option).filter(Option.id.in_([1, 2])).all(), 1208 | ) 1209 | 1210 | class VehicleSerializer(ModelSerializer): 1211 | class Meta: 1212 | model = Vehicle 1213 | session = session 1214 | fields = ("options",) 1215 | extra_kwargs = {"options": {"allow_create": True}} 1216 | 1217 | data = {"options": [{"name": "Test"}, {"name": "Other Test"}]} 1218 | 1219 | serializer = VehicleSerializer(instance=vehicle, data=data, partial=True) 1220 | 1221 | self.assertTrue(serializer.is_valid(), serializer.errors) 1222 | 1223 | vehicle = serializer.update(vehicle, serializer.validated_data) 1224 | 1225 | self.assertEqual(len(vehicle.options), 2) 1226 | self.assertEqual({v.name for v in vehicle.options}, {"Test", "Other Test"}) 1227 | 1228 | def test_patch_update_to_list_with_new_list_with_nested(self): 1229 | vehicle = Vehicle( 1230 | name="Test vehicle", 1231 | type=VehicleType.bus, 1232 | engine=Engine(4, 1234, 
None, None), 1233 | owner=session.query(Owner).get(1), 1234 | other=VehicleOther(advertising_cost=4321), 1235 | options=session.query(Option).filter(Option.id.in_([1, 2])).all(), 1236 | ) 1237 | 1238 | class VehicleSerializer(ModelSerializer): 1239 | class Meta: 1240 | model = Vehicle 1241 | session = session 1242 | fields = ("options",) 1243 | extra_kwargs = {"options": {"allow_nested_updates": True}} 1244 | 1245 | data = {"options": [{"id": 1, "name": "Test 1"}, {"id": 2, "name": "Test 2"}]} 1246 | 1247 | serializer = VehicleSerializer(instance=vehicle, data=data, partial=True) 1248 | 1249 | self.assertTrue(serializer.is_valid(), serializer.errors) 1250 | 1251 | vehicle = serializer.update(vehicle, serializer.validated_data) 1252 | 1253 | self.assertEqual([option.id for option in vehicle.options], [1, 2]) 1254 | self.assertEqual([option.name for option in vehicle.options], ["Test 1", "Test 2"]) 1255 | 1256 | def test_patch_update_to_list_with_new_list_with_nested_raises_for_a_bad_pk(self): 1257 | vehicle = Vehicle( 1258 | name="Test vehicle", 1259 | type=VehicleType.bus, 1260 | engine=Engine(4, 1234, None, None), 1261 | owner=session.query(Owner).get(1), 1262 | other=VehicleOther(advertising_cost=4321), 1263 | options=session.query(Option).filter(Option.id.in_([1, 2])).all(), 1264 | ) 1265 | 1266 | class VehicleSerializer(ModelSerializer): 1267 | class Meta: 1268 | model = Vehicle 1269 | session = session 1270 | fields = ("options",) 1271 | extra_kwargs = {"options": {"allow_null": False}} 1272 | 1273 | data = {"options": [{"id": 1, "name": "Test 1"}, {"id": 5, "name": "Test 5"}]} 1274 | 1275 | serializer = VehicleSerializer(instance=vehicle, data=data, partial=True) 1276 | 1277 | self.assertTrue(serializer.is_valid(), serializer.errors) 1278 | 1279 | with self.assertRaises(ValidationError): 1280 | serializer.update(vehicle, serializer.validated_data) 1281 | 1282 | def test_update_generates_validation_error_when_required_many_to_one_instance_not_found(self): 1283 
| vehicle = Vehicle( 1284 | name="Test vehicle", 1285 | type=VehicleType.bus, 1286 | engine=Engine(4, 1234, None, None), 1287 | owner=session.query(Owner).get(1), 1288 | other=VehicleOther(advertising_cost=4321), 1289 | options=session.query(Option).filter(Option.id.in_([1, 2])).all(), 1290 | ) 1291 | 1292 | class VehicleSerializer(ModelSerializer): 1293 | class Meta: 1294 | model = Vehicle 1295 | session = session 1296 | fields = ("owner",) 1297 | extra_kwargs = {"owner": {"allow_null": False}} 1298 | 1299 | data = {"owner": {"id": 1234}} 1300 | 1301 | serializer = VehicleSerializer(instance=vehicle, data=data, partial=True) 1302 | self.assertTrue(serializer.is_valid(), serializer.errors) 1303 | 1304 | with self.assertRaises(ValidationError): 1305 | serializer.update(vehicle, serializer.validated_data) 1306 | 1307 | def test_update_calls_custom_setter(self): 1308 | class VehicleSerializer(ModelSerializer): 1309 | class Meta: 1310 | model = Vehicle 1311 | session = session 1312 | fields = ("name",) 1313 | 1314 | def set_name(self, instance, field, value): 1315 | self._set_name_called = True 1316 | 1317 | vehicle = Vehicle( 1318 | name="Test vehicle", 1319 | type=VehicleType.bus, 1320 | engine=Engine(4, 1234, None, None), 1321 | owner=session.query(Owner).get(1), 1322 | other=VehicleOther(advertising_cost=4321), 1323 | options=session.query(Option).filter(Option.id.in_([1, 2])).all(), 1324 | ) 1325 | 1326 | data = {"name": "Bob Loblaw"} 1327 | 1328 | serializer = VehicleSerializer(instance=vehicle, data=data, partial=True) 1329 | self.assertTrue(serializer.is_valid(), serializer.errors) 1330 | 1331 | serializer.update(vehicle, serializer.validated_data) 1332 | self.assertTrue(serializer._set_name_called) 1333 | 1334 | def test_update_calls_custom_setter_django_validation_error(self): 1335 | class VehicleSerializer(ModelSerializer): 1336 | class Meta: 1337 | model = Vehicle 1338 | session = session 1339 | fields = ("name",) 1340 | 1341 | def set_name(self, instance, 
field, value): 1342 | raise DjangoValidationError({"name": [DjangoValidationError("error here")]}) 1343 | 1344 | vehicle = Vehicle(name="Test vehicle") 1345 | 1346 | data = {"name": "Bob Loblaw"} 1347 | 1348 | serializer = VehicleSerializer(instance=vehicle, data=data, partial=True) 1349 | self.assertTrue(serializer.is_valid(), serializer.errors) 1350 | 1351 | with self.assertRaises(ValidationError) as e: 1352 | serializer.update(vehicle, serializer.validated_data) 1353 | 1354 | self.assertEqual(e.exception.detail, {"name": ["error here"]}) 1355 | 1356 | def test_update_composite_calls_custom_setter_django_validation_error(self): 1357 | class EngineSerializer(CompositeSerializer): 1358 | class Meta: 1359 | composite = Vehicle.engine 1360 | 1361 | def set_cylinders(self, instance, field, value): 1362 | raise DjangoValidationError({"cylinders": [DjangoValidationError("error here")]}) 1363 | 1364 | class VehicleSerializer(ModelSerializer): 1365 | engine = EngineSerializer() 1366 | 1367 | class Meta: 1368 | model = Vehicle 1369 | session = session 1370 | fields = ("engine",) 1371 | 1372 | def set_name(self, instance, field, value): 1373 | raise DjangoValidationError({"name": [DjangoValidationError("error here")]}) 1374 | 1375 | vehicle = Vehicle() 1376 | 1377 | data = {"engine": {"cylinders": 10}} 1378 | 1379 | serializer = VehicleSerializer(instance=vehicle, data=data, partial=True) 1380 | self.assertTrue(serializer.is_valid(), serializer.errors) 1381 | 1382 | with self.assertRaises(ValidationError) as e: 1383 | serializer.update(vehicle, serializer.validated_data) 1384 | 1385 | self.assertEqual(e.exception.detail, {"engine": {"cylinders": [ErrorDetail("error here", code="invalid")]}}) 1386 | 1387 | def test_get_object_can_get_object(self): 1388 | class OwnerSerializer(ModelSerializer): 1389 | class Meta: 1390 | model = Owner 1391 | session = session 1392 | fields = "__all__" 1393 | 1394 | serializer = OwnerSerializer() 1395 | instance = serializer.get_object({"id": 
1}) 1396 | 1397 | self.assertIsNotNone(instance) 1398 | self.assertIsInstance(instance, Owner) 1399 | 1400 | def test_get_object_raise_when_not_found(self): 1401 | class OwnerSerializer(ModelSerializer): 1402 | class Meta: 1403 | model = Owner 1404 | session = session 1405 | fields = "__all__" 1406 | 1407 | serializer = OwnerSerializer() 1408 | 1409 | with self.assertRaises(ValidationError): 1410 | serializer.get_object({"id": 999}) 1411 | 1412 | def test_get_object_existing_instance(self): 1413 | class OwnerSerializer(ModelSerializer): 1414 | class Meta: 1415 | model = Owner 1416 | session = session 1417 | fields = "__all__" 1418 | 1419 | existing = Owner() 1420 | serializer = OwnerSerializer(allow_null=True) 1421 | instance = serializer.get_object({}, existing) 1422 | 1423 | self.assertIsNotNone(instance) 1424 | self.assertIs(instance, existing) 1425 | 1426 | def test_get_object_allow_null(self): 1427 | class OwnerSerializer(ModelSerializer): 1428 | class Meta: 1429 | model = Owner 1430 | session = session 1431 | fields = "__all__" 1432 | 1433 | serializer = OwnerSerializer(allow_null=True) 1434 | 1435 | self.assertIsNone(serializer.get_object(None, Owner())) 1436 | 1437 | def test_get_object_allows_create(self): 1438 | class OwnerSerializer(ModelSerializer): 1439 | class Meta: 1440 | model = Owner 1441 | session = session 1442 | fields = "__all__" 1443 | 1444 | serializer = OwnerSerializer(allow_create=True) 1445 | instance = serializer.get_object({}) 1446 | 1447 | self.assertIsNotNone(instance) 1448 | self.assertIsInstance(instance, Owner) 1449 | 1450 | def test_get_object_no_object(self): 1451 | class OwnerSerializer(ModelSerializer): 1452 | class Meta: 1453 | model = Owner 1454 | session = session 1455 | fields = "__all__" 1456 | 1457 | serializer = OwnerSerializer() 1458 | 1459 | with self.assertRaises(ValidationError): 1460 | serializer.get_object({}) 1461 | 1462 | def test_to_internal_value_partial_by_pk(self): 1463 | class 
OwnerSerializer(ModelSerializer): 1464 | class Meta: 1465 | model = Owner 1466 | session = session 1467 | fields = "__all__" 1468 | 1469 | serializer = OwnerSerializer(data={"id": 1}, partial_by_pk=True) 1470 | 1471 | self.assertTrue(serializer.fields["id"].required) 1472 | self.assertTrue(serializer.is_valid()) 1473 | self.assertTrue(serializer.fields["id"].required) 1474 | self.assertFalse(serializer.fields["first_name"].required) 1475 | self.assertFalse(serializer.fields["last_name"].required) 1476 | 1477 | def test_to_internal_value_partial_by_pk_remove_extra_fields(self): 1478 | class VehicleSerializer(ModelSerializer): 1479 | class Meta: 1480 | model = Vehicle 1481 | session = session 1482 | fields = "__all__" 1483 | 1484 | class OwnerSerializer(ModelSerializer): 1485 | vehicle = VehicleSerializer(partial_by_pk=True) 1486 | 1487 | class Meta: 1488 | model = Option 1489 | session = session 1490 | fields = "__all__" 1491 | 1492 | serializer = OwnerSerializer(data={"id": 111, "name": "foo", "vehicle": {"id": 1}}) 1493 | 1494 | self.assertTrue(serializer.is_valid(), serializer.errors) 1495 | self.assertEqual(serializer.validated_data["vehicle"], {"id": 1}) 1496 | 1497 | serializer = OwnerSerializer(data={"id": 111, "name": "foo", "vehicle": {}}) 1498 | self.assertFalse(serializer.is_valid(), serializer.errors) 1499 | 1500 | serializer = OwnerSerializer(data={"id": 111, "name": "foo"}) 1501 | self.assertFalse(serializer.is_valid(), serializer.errors) 1502 | 1503 | 1504 | class TestExpandableModelSerializer(SimpleTestCase): 1505 | def setUp(self): 1506 | super().setUp() 1507 | self.vehicle = Vehicle( 1508 | name="Test vehicle", 1509 | type=VehicleType.bus, 1510 | engine=Engine(4, 1234, None, None), 1511 | owner=Owner(first_name="Jon", last_name="Snow"), 1512 | other=VehicleOther(advertising_cost=4321), 1513 | options=[Option(name="GPS")], 1514 | ) 1515 | self.rf = APIRequestFactory() 1516 | self.maxDiff = None 1517 | 1518 | def 
test_to_representation_collapsed(self): 1519 | s = VehicleSerializer(instance=self.vehicle) 1520 | 1521 | self.assertEqual( 1522 | s.data, 1523 | { 1524 | "id": None, 1525 | "name": "Test vehicle", 1526 | "type": "bus", 1527 | "created_at": None, 1528 | "paint": None, 1529 | "is_used": None, 1530 | "engine": {"cylinders": 4, "displacement": "1234.00", "type_": None, "fuel_type": None}, 1531 | "owner": {"id": None}, 1532 | "options": [{"name": "GPS", "id": None}], 1533 | }, 1534 | ) 1535 | 1536 | def test_to_representation_list_collapsed(self): 1537 | s = VehicleSerializer(instance=[self.vehicle], many=True) 1538 | 1539 | self.assertEqual( 1540 | s.data, 1541 | [ 1542 | { 1543 | "id": None, 1544 | "name": "Test vehicle", 1545 | "type": "bus", 1546 | "created_at": None, 1547 | "paint": None, 1548 | "is_used": None, 1549 | "engine": {"cylinders": 4, "displacement": "1234.00", "type_": None, "fuel_type": None}, 1550 | "owner": {"id": None}, 1551 | "options": [{"name": "GPS", "id": None}], 1552 | } 1553 | ], 1554 | ) 1555 | 1556 | def test_to_representation_request_expanded(self): 1557 | s = VehicleSerializer(instance=self.vehicle, context={"request": self.rf.get("/", {"expand": "owner"})}) 1558 | 1559 | self.assertEqual( 1560 | s.data, 1561 | { 1562 | "id": None, 1563 | "name": "Test vehicle", 1564 | "type": "bus", 1565 | "created_at": None, 1566 | "paint": None, 1567 | "is_used": None, 1568 | "engine": {"cylinders": 4, "displacement": "1234.00", "type_": None, "fuel_type": None}, 1569 | "owner": {"id": None, "last_name": "Snow", "first_name": "Jon"}, 1570 | "options": [{"name": "GPS", "id": None}], 1571 | }, 1572 | ) 1573 | 1574 | def test_to_representation_update_expanded(self): 1575 | s = VehicleSerializer( 1576 | instance=self.vehicle, 1577 | partial=True, 1578 | data={"owner": {"first_name": "John", "last_name": "Doe"}}, 1579 | allow_nested_updates=True, 1580 | ) 1581 | 1582 | self.assertTrue(s.is_valid()) 1583 | s.save() 1584 | 1585 | self.assertEqual( 1586 | 
s.data, 1587 | { 1588 | "id": None, 1589 | "name": "Test vehicle", 1590 | "type": "bus", 1591 | "created_at": None, 1592 | "paint": None, 1593 | "is_used": None, 1594 | "engine": {"cylinders": 4, "displacement": "1234.00", "type_": None, "fuel_type": None}, 1595 | "owner": {"id": None, "last_name": "Doe", "first_name": "John"}, 1596 | "options": [{"name": "GPS", "id": None}], 1597 | }, 1598 | ) 1599 | 1600 | def test_query_serializer(self): 1601 | s = VehicleSerializer().get_query_serializer_class()() 1602 | 1603 | self.assertEqual(list(s.fields), ["expand"]) 1604 | self.assertIsInstance(s.fields["expand"], fields.ListField) 1605 | self.assertIsInstance(s.fields["expand"].child, fields.ChoiceField) 1606 | self.assertEqual(list(s.fields["expand"].child.choices), ["owner"]) 1607 | 1608 | def test_query_serializer_exclude(self): 1609 | s = VehicleSerializer().get_query_serializer_class(exclude=["owner"])() 1610 | 1611 | self.assertEqual(list(s.fields), []) 1612 | 1613 | def test_query_serializer_disallow(self): 1614 | s = VehicleSerializer().get_query_serializer_class(disallow=["owner"])() 1615 | 1616 | self.assertEqual(list(s.fields), ["expand"]) 1617 | self.assertIsInstance(s.fields["expand"], fields.ListField) 1618 | self.assertIsInstance(s.fields["expand"].child, fields.ChoiceField) 1619 | self.assertEqual(list(s.fields["expand"].child.choices), []) 1620 | 1621 | def test_query_serializer_nested(self): 1622 | class Serializer(ExpandableModelSerializer): 1623 | vehicles = VehicleSerializer(many=True) 1624 | 1625 | class Meta: 1626 | model = Owner 1627 | session = session 1628 | fields = "__all__" 1629 | 1630 | s = Serializer().get_query_serializer_class()() 1631 | 1632 | self.assertEqual(list(s.fields), ["expand"]) 1633 | self.assertIsInstance(s.fields["expand"], fields.ListField) 1634 | self.assertIsInstance(s.fields["expand"].child, fields.ChoiceField) 1635 | self.assertEqual(set(s.fields["expand"].child.choices), {"vehicles__owner"}) 1636 | 
--------------------------------------------------------------------------------