├── .github └── workflows │ ├── ci.yml │ └── codeql-analysis.yml ├── .gitignore ├── .travis.yml ├── AUTHORS.md ├── GNUmakefile ├── LICENSE ├── MANIFEST.in ├── README.md ├── SECURITY.md ├── docs ├── backdoc.py ├── index.html └── index.md ├── rest_framework_extensions ├── __init__.py ├── bulk_operations │ ├── __init__.py │ └── mixins.py ├── cache │ ├── __init__.py │ ├── decorators.py │ └── mixins.py ├── compat.py ├── decorators.py ├── etag │ ├── __init__.py │ ├── decorators.py │ └── mixins.py ├── exceptions.py ├── fields.py ├── key_constructor │ ├── __init__.py │ ├── bits.py │ └── constructors.py ├── mixins.py ├── permissions.py ├── routers.py ├── serializers.py ├── settings.py ├── test.py └── utils.py ├── setup.cfg ├── setup.py ├── tests_app ├── __init__.py ├── plugins.py ├── requirements.txt ├── settings.py ├── tests │ ├── __init__.py │ ├── functional │ │ ├── __init__.py │ │ ├── _concurrency │ │ │ ├── __init__.py │ │ │ └── conditional_request │ │ │ │ ├── __init__.py │ │ │ │ ├── models.py │ │ │ │ ├── serializers.py │ │ │ │ ├── tests.py │ │ │ │ ├── urls.py │ │ │ │ └── views.py │ │ ├── _examples │ │ │ ├── __init__.py │ │ │ └── etags │ │ │ │ ├── __init__.py │ │ │ │ └── remove_etag_gzip_postfix │ │ │ │ ├── __init__.py │ │ │ │ ├── middleware.py │ │ │ │ ├── tests.py │ │ │ │ ├── urls.py │ │ │ │ └── views.py │ │ ├── cache │ │ │ ├── __init__.py │ │ │ └── decorators │ │ │ │ ├── __init__.py │ │ │ │ ├── tests.py │ │ │ │ ├── urls.py │ │ │ │ └── views.py │ │ ├── key_constructor │ │ │ ├── __init__.py │ │ │ └── bits │ │ │ │ ├── __init__.py │ │ │ │ ├── models.py │ │ │ │ ├── serializers.py │ │ │ │ ├── tests.py │ │ │ │ ├── urls.py │ │ │ │ └── views.py │ │ ├── migrations │ │ │ ├── 0001_initial.py │ │ │ ├── 0002_nestedroutermixinusermodel_code.py │ │ │ └── __init__.py │ │ ├── mixins │ │ │ ├── __init__.py │ │ │ ├── detail_serializer_mixin │ │ │ │ ├── __init__.py │ │ │ │ ├── models.py │ │ │ │ ├── serializers.py │ │ │ │ ├── tests.py │ │ │ │ ├── urls.py │ │ │ │ └── views.py │ │ │ ├── 
list_destroy_model_mixin │ │ │ │ ├── __init__.py │ │ │ │ ├── models.py │ │ │ │ ├── tests.py │ │ │ │ ├── urls.py │ │ │ │ └── views.py │ │ │ ├── list_update_model_mixin │ │ │ │ ├── __init__.py │ │ │ │ ├── models.py │ │ │ │ ├── serializers.py │ │ │ │ ├── tests.py │ │ │ │ ├── urls.py │ │ │ │ └── views.py │ │ │ └── paginate_by_max_mixin │ │ │ │ ├── __init__.py │ │ │ │ ├── models.py │ │ │ │ ├── pagination.py │ │ │ │ ├── serializers.py │ │ │ │ ├── tests.py │ │ │ │ ├── urls.py │ │ │ │ └── views.py │ │ ├── models.py │ │ ├── permissions │ │ │ ├── __init__.py │ │ │ └── extended_django_object_permissions │ │ │ │ ├── __init__.py │ │ │ │ ├── models.py │ │ │ │ ├── tests.py │ │ │ │ ├── urls.py │ │ │ │ └── views.py │ │ └── routers │ │ │ ├── __init__.py │ │ │ ├── extended_default_router │ │ │ ├── __init__.py │ │ │ ├── models.py │ │ │ ├── tests.py │ │ │ ├── urls.py │ │ │ └── views.py │ │ │ ├── models.py │ │ │ ├── nested_router_mixin │ │ │ ├── __init__.py │ │ │ ├── models.py │ │ │ ├── serializers.py │ │ │ ├── tests.py │ │ │ ├── urls.py │ │ │ ├── urls_generic_relations.py │ │ │ ├── urls_parent_viewset_lookup.py │ │ │ └── views.py │ │ │ ├── tests.py │ │ │ └── views.py │ └── unit │ │ ├── __init__.py │ │ ├── _etag │ │ ├── __init__.py │ │ └── decorators │ │ │ ├── __init__.py │ │ │ └── tests.py │ │ ├── cache │ │ ├── __init__.py │ │ └── decorators │ │ │ ├── __init__.py │ │ │ └── tests.py │ │ ├── decorators │ │ ├── __init__.py │ │ └── tests.py │ │ ├── key_constructor │ │ ├── __init__.py │ │ ├── bits │ │ │ ├── __init__.py │ │ │ ├── models.py │ │ │ └── tests.py │ │ └── constructor │ │ ├── migrations │ │ ├── 0001_initial.py │ │ └── __init__.py │ │ ├── models.py │ │ ├── routers │ │ ├── __init__.py │ │ ├── nested_router_mixin │ │ │ ├── __init__.py │ │ │ ├── models.py │ │ │ ├── tests.py │ │ │ └── views.py │ │ └── tests.py │ │ ├── serializers │ │ ├── __init__.py │ │ ├── models.py │ │ ├── serializers.py │ │ └── tests.py │ │ └── utils │ │ ├── __init__.py │ │ └── tests.py ├── testutils.py └── urls.py 
└── tox.ini /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Django CI 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | workflow_dispatch: 9 | 10 | jobs: 11 | build: 12 | 13 | runs-on: ubuntu-latest 14 | strategy: 15 | max-parallel: 4 16 | matrix: 17 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] 18 | django-version: ["2.2","3.2", "4.2", "5.2"] 19 | 20 | steps: 21 | - uses: actions/checkout@v3 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v4 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install Dependencies 27 | run: | 28 | python -m pip install --upgrade pip tox 29 | pip install -r tests_app/requirements.txt 30 | - name: Run Tests 31 | run: | 32 | tox -- tests_app 33 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 
11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ master ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ master ] 20 | 21 | jobs: 22 | analyze: 23 | name: Analyze 24 | runs-on: ubuntu-latest 25 | permissions: 26 | actions: read 27 | contents: read 28 | security-events: write 29 | 30 | strategy: 31 | fail-fast: false 32 | matrix: 33 | language: [ 'python' ] 34 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 35 | # Learn more about CodeQL language support at https://git.io/codeql-language-support 36 | 37 | steps: 38 | - name: Checkout repository 39 | uses: actions/checkout@v2 40 | 41 | # Initializes the CodeQL tools for scanning. 42 | - name: Initialize CodeQL 43 | uses: github/codeql-action/init@v1 44 | with: 45 | languages: ${{ matrix.language }} 46 | # If you wish to specify custom queries, you can do so here or in a config file. 47 | # By default, queries listed here will override any specified in a config file. 48 | # Prefix the list here with "+" to use these queries and those in the config file. 49 | # queries: ./path/to/local/query, your-org/your-repo/queries@main 50 | 51 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 52 | # If this step fails, then you should remove it and run the build manually (see below) 53 | - name: Autobuild 54 | uses: github/codeql-action/autobuild@v1 55 | 56 | # ℹ️ Command-line programs to run using the OS shell. 
57 | # 📚 https://git.io/JvXDl 58 | 59 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines 60 | # and modify them (or add more) to build your code if your project 61 | # uses a compiled language 62 | 63 | #- run: | 64 | # make bootstrap 65 | # make release 66 | 67 | - name: Perform CodeQL Analysis 68 | uses: github/codeql-action/analyze@v1 69 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.pyc 3 | *.egg-info 4 | .tox 5 | *.egg 6 | .idea 7 | env 8 | build 9 | dist 10 | .DS_Store 11 | venv 12 | tests_app/tests/files 13 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | cache: pip 3 | dist: bionic 4 | sudo: false 5 | arch: 6 | - amd64 7 | - ppc64le 8 | python: 9 | - 3.6 10 | - 3.7 11 | - 3.8 12 | 13 | 14 | install: 15 | - pip install tox tox-travis 16 | 17 | script: 18 | - tox -r 19 | -------------------------------------------------------------------------------- /AUTHORS.md: -------------------------------------------------------------------------------- 1 | ## Original Author 2 | --------------- 3 | Gennady Chibisov https://github.com/chibisov 4 | 5 | ## Core maintainer 6 | Asif Saif Uddin https://github.com/auvipy 7 | 8 | 9 | ## Contributors 10 | ------------ 11 | Luke Murphy https://github.com/lwm 12 | -------------------------------------------------------------------------------- /GNUmakefile: -------------------------------------------------------------------------------- 1 | build_docs: 2 | PYTHONIOENCODING=utf-8 python docs/backdoc.py --title "Django Rest Framework extensions documentation" < docs/index.md > docs/index.html 3 | 4 | watch_docs: 5 | make build_docs 6 | watchmedo shell-command -p "*.md" -R -c "make build_docs" 
docs/ 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2013 Gennady Chibisov. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.md 3 | include tox.ini 4 | recursive-include docs *.md *.html *.txt *.py 5 | recursive-include tests_app requirements.txt *.py 6 | recursive-exclude * __pycache__ 7 | global-exclude *pyc 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Django REST Framework extensions 2 | 3 | DRF-extensions is a collection of custom extensions for [Django REST Framework](https://github.com/tomchristie/django-rest-framework) 4 | 5 | Full documentation for project is available at [http://chibisov.github.io/drf-extensions/docs](http://chibisov.github.io/drf-extensions/docs) 6 | 7 | [![Backers on Open Collective](https://opencollective.com/drf-extensions/backers/badge.svg)](#backers) [![Sponsors on Open Collective](https://opencollective.com/drf-extensions/sponsors/badge.svg)](#sponsors) [![PyPI](https://img.shields.io/pypi/v/drf-extensions.svg)](https://pypi.python.org/pypi/drf-extensions) 8 | 9 | ### Sponsor 10 | 11 | [Tidelift gives software development teams a single source for purchasing and maintaining their software, with professional grade assurances from the experts who know it best, while seamlessly integrating with existing tools.](https://tidelift.com/subscription/pkg/pypi-drf-extensions?utm_source=pypi-drf-extensions&utm_medium=referral&utm_campaign=readme) 12 | 13 | 14 | ## Requirements 15 | 16 | * Tested for Python 3.8, 3.9, 3.10, 3.11 and 3.12 17 | * Tested for Django Rest Framework 3.12, 3.13, 3.14 and 3.15 18 | * Tested for Django 2.2 to 5.2 19 | * Tested for django-filter 2.1.0+ 20 | 21 | ## Installation: 22 | 23 | pip3 install drf-extensions 24 | 25 | or from github 26 | 27 | pip3 install 
https://github.com/chibisov/drf-extensions/archive/master.zip 28 | 29 | ## Some features 30 | 31 | * DetailSerializerMixin 32 | * Caching 33 | * Conditional requests 34 | * Customizable key construction for caching and conditional requests 35 | * Nested routes 36 | * Bulk operations 37 | 38 | Read more in the [documentation](http://chibisov.github.io/drf-extensions/docs) 39 | 40 | ## Development 41 | 42 | Running the tests: 43 | 44 | $ pip3 install tox 45 | $ tox -- tests_app 46 | 47 | Run tests for a specific environment: 48 | 49 | $ tox -e py38 -- tests_app 50 | 51 | Recreate envs before running tests: 52 | 53 | $ tox --recreate -- tests_app 54 | 55 | Pass custom arguments: 56 | 57 | $ tox -- tests_app --verbosity=3 58 | 59 | Run with pdb support: 60 | 61 | $ tox -- tests_app --processes=0 --nocapture 62 | 63 | Run a specific TestCase: 64 | 65 | $ tox -- tests_app.tests.unit.mixins.tests:DetailSerializerMixinTest_serializer_detail_class 66 | 67 | Run tests from a specific module: 68 | 69 | $ tox -- tests_app.tests.unit.mixins.tests 70 | 71 | Build docs: 72 | 73 | $ make build_docs 74 | 75 | Automatically build docs by watching changes: 76 | 77 | $ pip install watchdog 78 | $ make watch_docs 79 | 80 | ## Developing new features 81 | 82 | Every new feature should be: 83 | 84 | * Documented 85 | * Tested 86 | * Implemented 87 | * Pushed to the main repository 88 | 89 | ### How to write documentation 90 | 91 | When you start implementing a new feature, place it into the `development version` pull request. Add a `Development version` 92 | section to `Release notes` and describe every new feature in it. Use `#anchors` to facilitate navigation. 93 | 94 | Every feature should have a title and a note that it was implemented in the current development version. 95 | 96 | For example, if we've just implemented `Usage of the specific cache`: 97 | 98 | ... 99 | 100 | #### Usage of the specific cache 101 | 102 | *New in DRF-extensions development version* 103 | 104 | `@cache_response` can also take...
105 | 106 | ... 107 | 108 | ### Release notes 109 | 110 | ... 111 | 112 | #### Development version 113 | 114 | * Added ability to [use a specific cache](#usage-of-the-specific-cache) for `@cache_response` decorator 115 | 116 | ## Publishing new releases 117 | 118 | Increment the version in `rest_framework_extensions/__init__.py`. For example: 119 | 120 | __version__ = '0.2.2' # from 0.2.1 121 | 122 | Move all development-version release notes in the documentation into the new version's section. 123 | 124 | Add the release date to that release notes section. 125 | 126 | In the documentation, replace all `New in DRF-extensions development version` notes with `New in DRF-extensions 0.2.2`. 127 | 128 | Rebuild the documentation. 129 | 130 | Run the tests. 131 | 132 | Commit the changes with the message "Version 0.2.2". 133 | 134 | Add a new version tag for the commit: 135 | 136 | $ git tag 0.2.2 137 | 138 | Push to master with tags: 139 | 140 | $ git push origin master --tags 141 | 142 | Don't forget to merge `master` into the `gh-pages` branch and push it to origin: 143 | 144 | $ git checkout gh-pages 145 | $ git merge --no-ff master 146 | $ git push origin gh-pages 147 | 148 | Publish to PyPI: 149 | 150 | $ python setup.py publish 151 | 152 | ## Contributors 153 | 154 | This project exists thanks to all the people who contribute. 155 | 156 | 157 | ## Backers 158 | 159 | Thank you to all our backers! 🙏 [[Become a backer](https://opencollective.com/drf-extensions#backer)] 160 | 161 | 162 | 163 | 164 | ## Sponsors 165 | 166 | Support this project by becoming a sponsor. Your logo will show up here with a link to your website.
[[Become a sponsor](https://opencollective.com/drf-extensions#sponsor)] 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | ## Supported Versions 4 | 5 | 6 | | Version | Supported | 7 | | ------- | ------------------ | 8 | | 0.7.x | :white_check_mark: | 9 | | 0.6.x | :x: | 10 | | < 0.7 | :x: | 11 | 12 | ## Reporting a Vulnerability 13 | 14 | Please report Vulnerability to auvipy@gmail.com via email. 15 | -------------------------------------------------------------------------------- /rest_framework_extensions/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.8.0' # from 0.7.1 2 | 3 | VERSION = __version__ 4 | -------------------------------------------------------------------------------- /rest_framework_extensions/bulk_operations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/rest_framework_extensions/bulk_operations/__init__.py -------------------------------------------------------------------------------- /rest_framework_extensions/bulk_operations/mixins.py: -------------------------------------------------------------------------------- 1 | from django.utils.encoding import force_str 2 | 3 | from rest_framework import status 4 | from rest_framework.response import Response 5 | from rest_framework_extensions.settings import extensions_api_settings 6 | from rest_framework_extensions import utils 7 | 8 | 9 | class BulkOperationBaseMixin: 10 | def is_object_operation(self): 11 | return bool(self.get_object_lookup_value()) 12 | 13 | def get_object_lookup_value(self): 14 | return self.kwargs.get(getattr(self, 
'lookup_url_kwarg', None) or self.lookup_field, None) 15 | 16 | def is_valid_bulk_operation(self): 17 | if extensions_api_settings.DEFAULT_BULK_OPERATION_HEADER_NAME: 18 | header_name = utils.prepare_header_name( 19 | extensions_api_settings.DEFAULT_BULK_OPERATION_HEADER_NAME) 20 | return bool(self.request.META.get(header_name, None)), { 21 | 'detail': 'Header \'{0}\' should be provided for bulk operation.'.format( 22 | extensions_api_settings.DEFAULT_BULK_OPERATION_HEADER_NAME 23 | ) 24 | } 25 | else: 26 | return True, {} 27 | 28 | 29 | class ListDestroyModelMixin(BulkOperationBaseMixin): 30 | def delete(self, request, *args, **kwargs): 31 | if self.is_object_operation(): 32 | return super().destroy(request, *args, **kwargs) 33 | else: 34 | return self.destroy_bulk(request, *args, **kwargs) 35 | 36 | def destroy_bulk(self, request, *args, **kwargs): 37 | is_valid, errors = self.is_valid_bulk_operation() 38 | if is_valid: 39 | queryset = self.filter_queryset(self.get_queryset()) 40 | self.pre_delete_bulk(queryset) # todo: test and document me 41 | queryset.delete() 42 | self.post_delete_bulk(queryset) # todo: test and document me 43 | return Response(status=status.HTTP_204_NO_CONTENT) 44 | else: 45 | return Response(errors, status=status.HTTP_400_BAD_REQUEST) 46 | 47 | def pre_delete_bulk(self, queryset): 48 | """ 49 | Placeholder method for calling before deleting an queryset. 50 | """ 51 | pass 52 | 53 | def post_delete_bulk(self, queryset): 54 | """ 55 | Placeholder method for calling after deleting an queryset. 
56 | """ 57 | pass 58 | 59 | 60 | class ListUpdateModelMixin(BulkOperationBaseMixin): 61 | def patch(self, request, *args, **kwargs): 62 | if self.is_object_operation(): 63 | return super().partial_update(request, *args, **kwargs) 64 | else: 65 | return self.partial_update_bulk(request, *args, **kwargs) 66 | 67 | def partial_update_bulk(self, request, *args, **kwargs): 68 | is_valid, errors = self.is_valid_bulk_operation() 69 | if is_valid: 70 | queryset = self.filter_queryset(self.get_queryset()) 71 | update_bulk_dict = self.get_update_bulk_dict( 72 | serializer=self.get_serializer_class()(), data=request.data) 73 | # todo: test and document me 74 | self.pre_save_bulk(queryset, update_bulk_dict) 75 | try: 76 | queryset.update(**update_bulk_dict) 77 | except ValueError as e: 78 | errors = { 79 | 'detail': force_str(e) 80 | } 81 | return Response(errors, status=status.HTTP_400_BAD_REQUEST) 82 | # todo: test and document me 83 | self.post_save_bulk(queryset, update_bulk_dict) 84 | return Response(status=status.HTTP_204_NO_CONTENT) 85 | else: 86 | return Response(errors, status=status.HTTP_400_BAD_REQUEST) 87 | 88 | def get_update_bulk_dict(self, serializer, data): 89 | update_bulk_dict = {} 90 | for field_name, field in serializer.fields.items(): 91 | if field_name in data and not field.read_only: 92 | update_bulk_dict[field.source or field_name] = data[field_name] 93 | return update_bulk_dict 94 | 95 | def pre_save_bulk(self, queryset, update_bulk_dict): 96 | """ 97 | Placeholder method for calling before deleting an queryset. 98 | """ 99 | pass 100 | 101 | def post_save_bulk(self, queryset, update_bulk_dict): 102 | """ 103 | Placeholder method for calling after deleting an queryset. 
104 | """ 105 | pass 106 | -------------------------------------------------------------------------------- /rest_framework_extensions/cache/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/rest_framework_extensions/cache/__init__.py -------------------------------------------------------------------------------- /rest_framework_extensions/cache/decorators.py: -------------------------------------------------------------------------------- 1 | from functools import wraps, WRAPPER_ASSIGNMENTS 2 | 3 | from django.http.response import HttpResponse 4 | 5 | 6 | from rest_framework_extensions.settings import extensions_api_settings 7 | 8 | 9 | def get_cache(alias): 10 | from django.core.cache import caches 11 | return caches[alias] 12 | 13 | 14 | class CacheResponse: 15 | """ 16 | Store/Receive and return cached `HttpResponse` based on DRF response. 17 | 18 | 19 | .. note:: 20 | This decorator will render and discard the original DRF response in 21 | favor of Django's `HttpResponse`. The allows the cache to retain a 22 | smaller memory footprint and eliminates the need to re-render 23 | responses on each request. Furthermore it eliminates the risk for users 24 | to unknowingly cache whole Serializers and QuerySets. 
25 | 26 | """ 27 | def __init__(self, 28 | timeout=None, 29 | key_func=None, 30 | cache=None, 31 | cache_errors=None): 32 | if timeout is None: 33 | self.timeout = extensions_api_settings.DEFAULT_CACHE_RESPONSE_TIMEOUT 34 | else: 35 | self.timeout = timeout 36 | 37 | if key_func is None: 38 | self.key_func = extensions_api_settings.DEFAULT_CACHE_KEY_FUNC 39 | else: 40 | self.key_func = key_func 41 | 42 | if cache_errors is None: 43 | self.cache_errors = extensions_api_settings.DEFAULT_CACHE_ERRORS 44 | else: 45 | self.cache_errors = cache_errors 46 | 47 | self.cache = get_cache(cache or extensions_api_settings.DEFAULT_USE_CACHE) 48 | 49 | def __call__(self, func): 50 | this = self 51 | 52 | @wraps(func, assigned=WRAPPER_ASSIGNMENTS) 53 | def inner(self, request, *args, **kwargs): 54 | return this.process_cache_response( 55 | view_instance=self, 56 | view_method=func, 57 | request=request, 58 | args=args, 59 | kwargs=kwargs, 60 | ) 61 | return inner 62 | 63 | def process_cache_response(self, 64 | view_instance, 65 | view_method, 66 | request, 67 | args, 68 | kwargs): 69 | 70 | key = self.calculate_key( 71 | view_instance=view_instance, 72 | view_method=view_method, 73 | request=request, 74 | args=args, 75 | kwargs=kwargs 76 | ) 77 | 78 | timeout = self.calculate_timeout(view_instance=view_instance) 79 | 80 | response_triple = self.cache.get(key) 81 | if not response_triple: 82 | # render response to create and cache the content byte string 83 | response = view_method(view_instance, request, *args, **kwargs) 84 | response = view_instance.finalize_response(request, response, *args, **kwargs) 85 | response.render() 86 | 87 | if not response.status_code >= 400 or self.cache_errors: 88 | # django 3.0 has not .items() method, django 3.2 has not ._headers 89 | if hasattr(response, '_headers'): 90 | headers = response._headers.copy() 91 | else: 92 | headers = {k: (k, v) for k, v in response.items()} 93 | response_triple = ( 94 | response.rendered_content, 95 | 
response.status_code, 96 | headers 97 | ) 98 | self.cache.set(key, response_triple, timeout) 99 | else: 100 | # build smaller Django HttpResponse 101 | content, status, headers = response_triple 102 | response = HttpResponse(content=content, status=status) 103 | for k, v in headers.values(): 104 | response[k] = v 105 | if not hasattr(response, '_closable_objects'): 106 | response._closable_objects = [] 107 | 108 | return response 109 | 110 | def calculate_key(self, 111 | view_instance, 112 | view_method, 113 | request, 114 | args, 115 | kwargs): 116 | if isinstance(self.key_func, str): 117 | key_func = getattr(view_instance, self.key_func) 118 | else: 119 | key_func = self.key_func 120 | return key_func( 121 | view_instance=view_instance, 122 | view_method=view_method, 123 | request=request, 124 | args=args, 125 | kwargs=kwargs, 126 | ) 127 | 128 | def calculate_timeout(self, view_instance, **_): 129 | if isinstance(self.timeout, str): 130 | self.timeout = getattr(view_instance, self.timeout) 131 | return self.timeout 132 | 133 | 134 | cache_response = CacheResponse 135 | -------------------------------------------------------------------------------- /rest_framework_extensions/cache/mixins.py: -------------------------------------------------------------------------------- 1 | from rest_framework_extensions.cache.decorators import cache_response 2 | from rest_framework_extensions.settings import extensions_api_settings 3 | 4 | 5 | class BaseCacheResponseMixin: 6 | # todo: test me. 
Create generic test like 7 | # test_cache_reponse(view_instance, method, should_rebuild_after_method_evaluation) 8 | object_cache_key_func = extensions_api_settings.DEFAULT_OBJECT_CACHE_KEY_FUNC 9 | list_cache_key_func = extensions_api_settings.DEFAULT_LIST_CACHE_KEY_FUNC 10 | object_cache_timeout = extensions_api_settings.DEFAULT_CACHE_RESPONSE_TIMEOUT 11 | list_cache_timeout = extensions_api_settings.DEFAULT_CACHE_RESPONSE_TIMEOUT 12 | 13 | 14 | class ListCacheResponseMixin(BaseCacheResponseMixin): 15 | @cache_response(key_func='list_cache_key_func', timeout='list_cache_timeout') 16 | def list(self, request, *args, **kwargs): 17 | return super().list(request, *args, **kwargs) 18 | 19 | 20 | class RetrieveCacheResponseMixin(BaseCacheResponseMixin): 21 | @cache_response(key_func='object_cache_key_func', timeout='object_cache_timeout') 22 | def retrieve(self, request, *args, **kwargs): 23 | return super().retrieve(request, *args, **kwargs) 24 | 25 | 26 | class CacheResponseMixin(RetrieveCacheResponseMixin, 27 | ListCacheResponseMixin): 28 | pass 29 | -------------------------------------------------------------------------------- /rest_framework_extensions/compat.py: -------------------------------------------------------------------------------- 1 | """ 2 | The `compat` module provides support for backwards compatibility with older 3 | versions of django/python, and compatibility wrappers around optional packages. 
4 | """ 5 | 6 | 7 | # handle different QuerySet representations 8 | def queryset_to_value_list(queryset): 9 | assert isinstance(queryset, str) 10 | 11 | # django 1.10 introduces syntax "" 12 | # we extract only the list of tuples from the string 13 | idx_bracket_open = queryset.find(u'[') 14 | idx_bracket_close = queryset.rfind(u']') 15 | 16 | return queryset[idx_bracket_open:idx_bracket_close + 1] 17 | -------------------------------------------------------------------------------- /rest_framework_extensions/decorators.py: -------------------------------------------------------------------------------- 1 | def paginate(pagination_class=None, **kwargs): 2 | """ 3 | Decorator that adds a pagination_class to GenericViewSet class. 4 | Custom pagination class also available. 5 | 6 | Usage : 7 | from rest_framework.pagination import CursorPagination 8 | 9 | @paginate(pagination_class=CursorPagination, page_size=5, ordering='-created_at') 10 | class FooViewSet(viewsets.GenericViewSet): 11 | ... 
12 | 13 | """ 14 | assert pagination_class is not None, ( 15 | "@paginate missing required argument: 'pagination_class'" 16 | ) 17 | 18 | class _Pagination(pagination_class): 19 | def __init__(self): 20 | self.__dict__.update(kwargs) 21 | 22 | def decorator(_class): 23 | _class.pagination_class = _Pagination 24 | return _class 25 | 26 | return decorator 27 | -------------------------------------------------------------------------------- /rest_framework_extensions/etag/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /rest_framework_extensions/etag/decorators.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from functools import wraps, WRAPPER_ASSIGNMENTS 3 | 4 | from django.utils.http import parse_etags, quote_etag 5 | 6 | from rest_framework import status 7 | from rest_framework.permissions import SAFE_METHODS 8 | from rest_framework.response import Response 9 | from rest_framework_extensions.exceptions import PreconditionRequiredException 10 | 11 | from rest_framework_extensions.utils import prepare_header_name 12 | from rest_framework_extensions.settings import extensions_api_settings 13 | 14 | logger = logging.getLogger('django.request') 15 | 16 | 17 | class ETAGProcessor: 18 | """Based on https://github.com/django/django/blob/master/django/views/decorators/http.py""" 19 | 20 | def __init__(self, etag_func=None, rebuild_after_method_evaluation=False): 21 | if not etag_func: 22 | etag_func = extensions_api_settings.DEFAULT_ETAG_FUNC 23 | self.etag_func = etag_func 24 | self.rebuild_after_method_evaluation = rebuild_after_method_evaluation 25 | 26 | def __call__(self, func): 27 | this = self 28 | 29 | @wraps(func, assigned=WRAPPER_ASSIGNMENTS) 30 | def inner(self, request, *args, **kwargs): 31 | return this.process_conditional_request( 32 | view_instance=self, 33 | 
view_method=func, 34 | request=request, 35 | args=args, 36 | kwargs=kwargs, 37 | ) 38 | 39 | return inner 40 | 41 | def process_conditional_request(self, 42 | view_instance, 43 | view_method, 44 | request, 45 | args, 46 | kwargs): 47 | etags, if_none_match, if_match = self.get_etags_and_matchers(request) 48 | res_etag = self.calculate_etag( 49 | view_instance=view_instance, 50 | view_method=view_method, 51 | request=request, 52 | args=args, 53 | kwargs=kwargs, 54 | ) 55 | 56 | if self.is_if_none_match_failed(res_etag, etags, if_none_match): 57 | if request.method in SAFE_METHODS: 58 | response = Response(status=status.HTTP_304_NOT_MODIFIED) 59 | else: 60 | response = self._get_and_log_precondition_failed_response( 61 | request=request) 62 | elif self.is_if_match_failed(res_etag, etags, if_match): 63 | response = self._get_and_log_precondition_failed_response( 64 | request=request) 65 | else: 66 | response = view_method(view_instance, request, *args, **kwargs) 67 | if self.rebuild_after_method_evaluation: 68 | res_etag = self.calculate_etag( 69 | view_instance=view_instance, 70 | view_method=view_method, 71 | request=request, 72 | args=args, 73 | kwargs=kwargs, 74 | ) 75 | 76 | if res_etag and not response.has_header('ETag'): 77 | response['ETag'] = quote_etag(res_etag) 78 | 79 | return response 80 | 81 | def get_etags_and_matchers(self, request): 82 | etags = None 83 | if_none_match = request.META.get(prepare_header_name("if-none-match")) 84 | if_match = request.META.get(prepare_header_name("if-match")) 85 | if if_none_match or if_match: 86 | # There can be more than one ETag in the request, so we 87 | # consider the list of values. 
88 | try: 89 | value_to_parse = if_none_match or if_match 90 | if value_to_parse: 91 | etag_list = [e.strip() for e in value_to_parse.split(' ') if e.strip()] 92 | etag_list = [e if e.startswith('"') else f'"{e}"' for e in etag_list] 93 | value_to_parse = ', '.join(etag_list) 94 | etags = parse_etags(value_to_parse) 95 | except ValueError: 96 | # In case of invalid etag ignore all ETag headers. 97 | # Apparently Opera sends invalidly quoted headers at times 98 | # (we should be returning a 400 response, but that's a 99 | # little extreme) -- this is Django bug #10681. 100 | if_none_match = None 101 | if_match = None 102 | return etags, if_none_match, if_match 103 | 104 | def calculate_etag(self, 105 | view_instance, 106 | view_method, 107 | request, 108 | args, 109 | kwargs): 110 | if isinstance(self.etag_func, str): 111 | etag_func = getattr(view_instance, self.etag_func) 112 | else: 113 | etag_func = self.etag_func 114 | return etag_func( 115 | view_instance=view_instance, 116 | view_method=view_method, 117 | request=request, 118 | args=args, 119 | kwargs=kwargs, 120 | ) 121 | 122 | def is_if_none_match_failed(self, res_etag, etags, if_none_match): 123 | if res_etag and if_none_match: 124 | etags = [etag.strip('"') for etag in etags] 125 | return res_etag in etags or '*' in etags 126 | else: 127 | return False 128 | 129 | def is_if_match_failed(self, res_etag, etags, if_match): 130 | if res_etag and if_match: 131 | res_etag =res_etag.strip('"') 132 | etags = [etag.strip('"') for etag in etags] 133 | matches = res_etag in etags or '*' in etags 134 | return not matches 135 | else: 136 | return False 137 | 138 | def _get_and_log_precondition_failed_response(self, request): 139 | logger.warning('Precondition Failed: %s', request.path, 140 | extra={ 141 | 'status_code': status.HTTP_412_PRECONDITION_FAILED, 142 | 'request': request 143 | } 144 | ) 145 | return Response(status=status.HTTP_412_PRECONDITION_FAILED) 146 | 147 | 148 | class APIETAGProcessor(ETAGProcessor): 
149 | """ 150 | This class is responsible for calculating the ETag value given (a list of) model instance(s). 151 | 152 | It does not make sense to compute a default ETag here, because the processor would always issue a 304 response, 153 | even if the response was modified meanwhile. 154 | Therefore the `APIETAGProcessor` cannot be used without specifying an `etag_func` as keyword argument. 155 | 156 | According to RFC 6585, conditional headers may be enforced for certain services that support conditional 157 | requests. For optimistic locking, the server should respond status code 428 including a description on how 158 | to resubmit the request successfully, see https://tools.ietf.org/html/rfc6585#section-3. 159 | """ 160 | 161 | # require a pre-conditional header (e.g. If-Match) for unsafe HTTP methods (RFC 6585) 162 | # override this defaults, if required 163 | precondition_map = {'PUT': ['If-Match'], 164 | 'PATCH': ['If-Match'], 165 | 'DELETE': ['If-Match']} 166 | 167 | def __init__(self, etag_func=None, rebuild_after_method_evaluation=False, precondition_map=None): 168 | assert etag_func is not None, ('None-type functions are not allowed for processing API ETags.' 
169 | 'You must specify a proper function to calculate the API ETags ' 170 | 'using the "etag_func" keyword argument.') 171 | 172 | if precondition_map is not None: 173 | self.precondition_map = precondition_map 174 | assert isinstance(self.precondition_map, dict), ('`precondition_map` must be a dict, where ' 175 | 'the key is the HTTP verb, and the value is a list of ' 176 | 'HTTP headers that must all be present for that request.') 177 | 178 | super().__init__(etag_func=etag_func, 179 | rebuild_after_method_evaluation=rebuild_after_method_evaluation) 180 | 181 | def get_etags_and_matchers(self, request): 182 | """Get the etags from the header and perform a validation against the required preconditions.""" 183 | # evaluate the preconditions, raises 428 if condition is not met 184 | self.evaluate_preconditions(request) 185 | # alright, headers are present, extract the values and match the conditions 186 | return super().get_etags_and_matchers(request) 187 | 188 | def evaluate_preconditions(self, request): 189 | """Evaluate whether the precondition for the request is met.""" 190 | if request.method.upper() in self.precondition_map.keys(): 191 | required_headers = self.precondition_map.get( 192 | request.method.upper(), []) 193 | # check the required headers 194 | for header in required_headers: 195 | if not request.META.get(prepare_header_name(header)): 196 | # raise an error for each header that does not match 197 | logger.warning('Precondition required: %s', request.path, 198 | extra={ 199 | 'status_code': status.HTTP_428_PRECONDITION_REQUIRED, 200 | 'request': request 201 | } 202 | ) 203 | # raise an RFC 6585 compliant exception 204 | raise PreconditionRequiredException(detail='Precondition required. This "%s" request ' 205 | 'is required to be conditional. ' 206 | 'Try again using "%s".' 
% ( 207 | request.method, header) 208 | ) 209 | return True 210 | 211 | 212 | etag = ETAGProcessor 213 | api_etag = APIETAGProcessor 214 | -------------------------------------------------------------------------------- /rest_framework_extensions/etag/mixins.py: -------------------------------------------------------------------------------- 1 | from rest_framework_extensions.etag.decorators import etag, api_etag 2 | from rest_framework_extensions.settings import extensions_api_settings 3 | 4 | 5 | class BaseETAGMixin: 6 | # todo: test me. Create generic test like test_etag(view_instance, 7 | # method, should_rebuild_after_method_evaluation) 8 | object_etag_func = extensions_api_settings.DEFAULT_OBJECT_ETAG_FUNC 9 | list_etag_func = extensions_api_settings.DEFAULT_LIST_ETAG_FUNC 10 | 11 | 12 | class ListETAGMixin(BaseETAGMixin): 13 | @etag(etag_func='list_etag_func') 14 | def list(self, request, *args, **kwargs): 15 | return super().list(request, *args, **kwargs) 16 | 17 | 18 | class RetrieveETAGMixin(BaseETAGMixin): 19 | @etag(etag_func='object_etag_func') 20 | def retrieve(self, request, *args, **kwargs): 21 | return super().retrieve(request, *args, **kwargs) 22 | 23 | 24 | class UpdateETAGMixin(BaseETAGMixin): 25 | @etag(etag_func='object_etag_func', rebuild_after_method_evaluation=True) 26 | def update(self, request, *args, **kwargs): 27 | return super().update(request, *args, **kwargs) 28 | 29 | 30 | class DestroyETAGMixin(BaseETAGMixin): 31 | @etag(etag_func='object_etag_func') 32 | def destroy(self, request, *args, **kwargs): 33 | return super().destroy(request, *args, **kwargs) 34 | 35 | 36 | class ReadOnlyETAGMixin(RetrieveETAGMixin, 37 | ListETAGMixin): 38 | pass 39 | 40 | 41 | class ETAGMixin(RetrieveETAGMixin, 42 | UpdateETAGMixin, 43 | DestroyETAGMixin, 44 | ListETAGMixin): 45 | pass 46 | 47 | 48 | class APIBaseETAGMixin: 49 | # todo: test me. 
Create generic test like test_etag(view_instance, 50 | # method, should_rebuild_after_method_evaluation) 51 | api_object_etag_func = extensions_api_settings.DEFAULT_API_OBJECT_ETAG_FUNC 52 | api_list_etag_func = extensions_api_settings.DEFAULT_API_LIST_ETAG_FUNC 53 | 54 | 55 | class APIListETAGMixin(APIBaseETAGMixin): 56 | @api_etag(etag_func='api_list_etag_func') 57 | def list(self, request, *args, **kwargs): 58 | return super().list(request, *args, **kwargs) 59 | 60 | 61 | class APIRetrieveETAGMixin(APIBaseETAGMixin): 62 | @api_etag(etag_func='api_object_etag_func') 63 | def retrieve(self, request, *args, **kwargs): 64 | return super().retrieve(request, *args, **kwargs) 65 | 66 | 67 | class APIUpdateETAGMixin(APIBaseETAGMixin): 68 | @api_etag(etag_func='api_object_etag_func', rebuild_after_method_evaluation=True) 69 | def update(self, request, *args, **kwargs): 70 | return super().update(request, *args, **kwargs) 71 | 72 | 73 | class APIDestroyETAGMixin(APIBaseETAGMixin): 74 | @api_etag(etag_func='api_object_etag_func') 75 | def destroy(self, request, *args, **kwargs): 76 | return super().destroy(request, *args, **kwargs) 77 | 78 | 79 | class APIReadOnlyETAGMixin(APIRetrieveETAGMixin, 80 | APIListETAGMixin): 81 | pass 82 | 83 | 84 | class APIETAGMixin(APIRetrieveETAGMixin, 85 | APIUpdateETAGMixin, 86 | APIDestroyETAGMixin, 87 | APIListETAGMixin): 88 | pass 89 | -------------------------------------------------------------------------------- /rest_framework_extensions/exceptions.py: -------------------------------------------------------------------------------- 1 | from django.utils.translation import gettext_lazy as _ 2 | from rest_framework import status 3 | from rest_framework.exceptions import APIException 4 | 5 | 6 | class PreconditionRequiredException(APIException): 7 | status_code = status.HTTP_428_PRECONDITION_REQUIRED 8 | default_detail = _('This "{method}" request is required to be conditional.') 9 | default_code = 'precondition_required' 10 | 
-------------------------------------------------------------------------------- /rest_framework_extensions/fields.py: -------------------------------------------------------------------------------- 1 | from rest_framework.relations import HyperlinkedRelatedField 2 | 3 | 4 | class ResourceUriField(HyperlinkedRelatedField): 5 | """ 6 | Represents a hyperlinking uri that points to the 7 | detail view for that object. 8 | 9 | Example: 10 | class SurveySerializer(serializers.ModelSerializer): 11 | resource_uri = ResourceUriField(view_name='survey-detail') 12 | 13 | class Meta: 14 | model = Survey 15 | fields = ('id', 'resource_uri') 16 | 17 | ... 18 | { 19 | "id": 1, 20 | "resource_uri": "http://localhost/v1/surveys/1/", 21 | } 22 | """ 23 | # todo: test me 24 | read_only = True 25 | 26 | def __init__(self, *args, **kwargs): 27 | kwargs.setdefault('source', '*') 28 | super().__init__(*args, **kwargs) 29 | -------------------------------------------------------------------------------- /rest_framework_extensions/key_constructor/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/rest_framework_extensions/key_constructor/__init__.py -------------------------------------------------------------------------------- /rest_framework_extensions/key_constructor/bits.py: -------------------------------------------------------------------------------- 1 | from django.utils.translation import get_language 2 | from django.db.models.query import EmptyQuerySet 3 | from django.core.exceptions import EmptyResultSet 4 | 5 | from django.utils.encoding import force_str 6 | 7 | from rest_framework_extensions import compat 8 | 9 | 10 | class AllArgsMixin: 11 | 12 | def __init__(self, params='*'): 13 | super().__init__(params) 14 | 15 | 16 | class KeyBitBase: 17 | def __init__(self, params=None): 18 | self.params = params 19 | 20 | def get_data(self, 
params, view_instance, view_method, request, args, kwargs): 21 | """ 22 | @rtype: dict 23 | """ 24 | raise NotImplementedError() 25 | 26 | 27 | class KeyBitDictBase(KeyBitBase): 28 | """Base class for dict-like source data processing. 29 | 30 | Look at HeadersKeyBit and QueryParamsKeyBit 31 | 32 | """ 33 | 34 | def get_data(self, params, view_instance, view_method, request, args, kwargs): 35 | data = {} 36 | 37 | if params is not None: 38 | source_dict = self.get_source_dict( 39 | params=params, 40 | view_instance=view_instance, 41 | view_method=view_method, 42 | request=request, 43 | args=args, 44 | kwargs=kwargs 45 | ) 46 | 47 | if params == '*': 48 | params = source_dict.keys() 49 | 50 | for key in params: 51 | value = source_dict.get( 52 | self.prepare_key_for_value_retrieving(key)) 53 | if value is not None: 54 | data[self.prepare_key_for_value_assignment( 55 | key)] = force_str(value) 56 | 57 | return data 58 | 59 | def get_source_dict(self, params, view_instance, view_method, request, args, kwargs): 60 | raise NotImplementedError() 61 | 62 | def prepare_key_for_value_retrieving(self, key): 63 | return key 64 | 65 | def prepare_key_for_value_assignment(self, key): 66 | return key 67 | 68 | 69 | class UniqueViewIdKeyBit(KeyBitBase): 70 | def get_data(self, params, view_instance, view_method, request, args, kwargs): 71 | return '.'.join([ 72 | view_instance.__module__, 73 | view_instance.__class__.__name__ 74 | ]) 75 | 76 | 77 | class UniqueMethodIdKeyBit(KeyBitBase): 78 | def get_data(self, params, view_instance, view_method, request, args, kwargs): 79 | return '.'.join([ 80 | view_instance.__module__, 81 | view_instance.__class__.__name__, 82 | view_method.__name__ 83 | ]) 84 | 85 | 86 | class LanguageKeyBit(KeyBitBase): 87 | """ 88 | Return example: 89 | 'en' 90 | 91 | """ 92 | 93 | def get_data(self, params, view_instance, view_method, request, args, kwargs): 94 | return force_str(get_language()) 95 | 96 | 97 | class FormatKeyBit(KeyBitBase): 98 | """ 99 | 
Return example for json: 100 | u'json' 101 | 102 | Return example for html: 103 | u'html' 104 | """ 105 | 106 | def get_data(self, params, view_instance, view_method, request, args, kwargs): 107 | return force_str(request.accepted_renderer.format) 108 | 109 | 110 | class UserKeyBit(KeyBitBase): 111 | """ 112 | Return example for anonymous: 113 | u'anonymous' 114 | 115 | Return example for authenticated (value is user id): 116 | u'10' 117 | """ 118 | 119 | def get_data(self, params, view_instance, view_method, request, args, kwargs): 120 | if hasattr(request, 'user') and request.user and request.user.is_authenticated: 121 | return force_str(self._get_id_from_user(request.user)) 122 | else: 123 | return 'anonymous' 124 | 125 | def _get_id_from_user(self, user): 126 | return user.id 127 | 128 | 129 | class HeadersKeyBit(KeyBitDictBase): 130 | """ 131 | Return example: 132 | {'accept-language': u'ru', 'x-geobase-id': '123'} 133 | 134 | """ 135 | 136 | def get_source_dict(self, params, view_instance, view_method, request, args, kwargs): 137 | return request.META 138 | 139 | def prepare_key_for_value_retrieving(self, key): 140 | from rest_framework_extensions.utils import prepare_header_name 141 | 142 | # Accept-Language => http_accept_language 143 | return prepare_header_name(key.lower()) 144 | 145 | def prepare_key_for_value_assignment(self, key): 146 | return key.lower() # Accept-Language => accept-language 147 | 148 | 149 | class RequestMetaKeyBit(KeyBitDictBase): 150 | """ 151 | Return example: 152 | {'REMOTE_ADDR': u'127.0.0.2', 'REMOTE_HOST': u'yandex.ru'} 153 | 154 | """ 155 | 156 | def get_source_dict(self, params, view_instance, view_method, request, args, kwargs): 157 | return request.META 158 | 159 | 160 | class QueryParamsKeyBit(AllArgsMixin, KeyBitDictBase): 161 | """ 162 | Return example: 163 | {'part': 'Londo', 'callback': 'jquery_callback'} 164 | 165 | """ 166 | 167 | def get_source_dict(self, params, view_instance, view_method, request, args, kwargs): 
168 | return request.GET 169 | 170 | 171 | class PaginationKeyBit(QueryParamsKeyBit): 172 | """ 173 | Return example: 174 | {'page_size': 100, 'page': '1'} 175 | 176 | """ 177 | paginator_attrs = [ 178 | 'page_query_param', 'page_size_query_param', 179 | 'limit_query_param', 'offset_query_param', 180 | 'cursor_query_param', 181 | ] 182 | 183 | def get_data(self, **kwargs): 184 | kwargs['params'] = [] 185 | paginator = getattr(kwargs['view_instance'], 'paginator', None) 186 | 187 | if paginator: 188 | for attr in self.paginator_attrs: 189 | param = getattr(paginator, attr, None) 190 | if param: 191 | kwargs['params'].append(param) 192 | 193 | return super().get_data(**kwargs) 194 | 195 | 196 | class SqlQueryKeyBitBase(KeyBitBase): 197 | def _get_queryset_query_string(self, queryset): 198 | if isinstance(queryset, EmptyQuerySet): 199 | return None 200 | else: 201 | try: 202 | return force_str(queryset.query.__str__()) 203 | except EmptyResultSet: 204 | return None 205 | 206 | 207 | class ModelInstanceKeyBitBase(KeyBitBase): 208 | """ 209 | Return the actual contents of the query set. 210 | This class is similar to the `SqlQueryKeyBitBase`. 
211 | """ 212 | 213 | def _get_queryset_query_values(self, queryset): 214 | if isinstance(queryset, EmptyQuerySet) or queryset.count() == 0: 215 | return None 216 | else: 217 | try: 218 | # run through the instances and collect all values in ordered fashion 219 | return compat.queryset_to_value_list(force_str(queryset.values_list())) 220 | except EmptyResultSet: 221 | return None 222 | 223 | 224 | class ListSqlQueryKeyBit(SqlQueryKeyBitBase): 225 | def get_data(self, params, view_instance, view_method, request, args, kwargs): 226 | queryset = view_instance.filter_queryset(view_instance.get_queryset()) 227 | return self._get_queryset_query_string(queryset) 228 | 229 | 230 | class RetrieveSqlQueryKeyBit(SqlQueryKeyBitBase): 231 | def get_data(self, params, view_instance, view_method, request, args, kwargs): 232 | lookup_value = view_instance.kwargs[ 233 | view_instance.lookup_url_kwarg or view_instance.lookup_field] 234 | try: 235 | queryset = view_instance.filter_queryset(view_instance.get_queryset()).filter( 236 | **{view_instance.lookup_field: lookup_value} 237 | ) 238 | except ValueError: 239 | return None 240 | else: 241 | return self._get_queryset_query_string(queryset) 242 | 243 | 244 | class RetrieveModelKeyBit(ModelInstanceKeyBitBase): 245 | """ 246 | A key bit reflecting the contents of the model instance. 247 | Return example: 248 | u"[(3, False)]" 249 | """ 250 | 251 | def get_data(self, params, view_instance, view_method, request, args, kwargs): 252 | lookup_value = view_instance.kwargs[view_instance.lookup_field] 253 | try: 254 | queryset = view_instance.filter_queryset(view_instance.get_queryset()).filter( 255 | **{view_instance.lookup_field: lookup_value} 256 | ) 257 | except ValueError: 258 | return None 259 | else: 260 | return self._get_queryset_query_values(queryset) 261 | 262 | 263 | class ListModelKeyBit(ModelInstanceKeyBitBase): 264 | """ 265 | A key bit reflecting the contents of a list of model instances. 
266 | Return example: 267 | u"[(1, True), (2, True), (3, False)]" 268 | """ 269 | 270 | def get_data(self, params, view_instance, view_method, request, args, kwargs): 271 | queryset = view_instance.filter_queryset(view_instance.get_queryset()) 272 | return self._get_queryset_query_values(queryset) 273 | 274 | 275 | class ArgsKeyBit(AllArgsMixin, KeyBitBase): 276 | 277 | def get_data(self, params, view_instance, view_method, request, args, kwargs): 278 | if params == '*': 279 | return args 280 | elif params is not None: 281 | return [args[i] for i in params] 282 | else: 283 | return [] 284 | 285 | 286 | class KwargsKeyBit(AllArgsMixin, KeyBitDictBase): 287 | 288 | def get_source_dict(self, params, view_instance, view_method, request, args, kwargs): 289 | return kwargs 290 | -------------------------------------------------------------------------------- /rest_framework_extensions/key_constructor/constructors.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import json 3 | 4 | from rest_framework_extensions.key_constructor import bits 5 | from rest_framework_extensions.settings import extensions_api_settings 6 | 7 | 8 | class KeyConstructor: 9 | def __init__(self, memoize_for_request=None, params=None): 10 | if memoize_for_request is None: 11 | self.memoize_for_request = extensions_api_settings.DEFAULT_KEY_CONSTRUCTOR_MEMOIZE_FOR_REQUEST 12 | else: 13 | self.memoize_for_request = memoize_for_request 14 | if params is None: 15 | self.params = {} 16 | else: 17 | self.params = params 18 | self.bits = self.get_bits() 19 | 20 | def get_bits(self): 21 | _bits = {} 22 | for attr in dir(self.__class__): 23 | attr_value = getattr(self.__class__, attr) 24 | if isinstance(attr_value, bits.KeyBitBase): 25 | _bits[attr] = attr_value 26 | return _bits 27 | 28 | def __call__(self, **kwargs): 29 | return self.get_key(**kwargs) 30 | 31 | def get_key(self, view_instance, view_method, request, args, kwargs): 32 | if 
self.memoize_for_request: 33 | memoization_key = self._get_memoization_key( 34 | view_instance=view_instance, 35 | view_method=view_method, 36 | args=args, 37 | kwargs=kwargs 38 | ) 39 | if not hasattr(request, '_key_constructor_cache'): 40 | request._key_constructor_cache = {} 41 | if self.memoize_for_request and memoization_key in request._key_constructor_cache: 42 | return request._key_constructor_cache.get(memoization_key) 43 | else: 44 | value = self._get_key( 45 | view_instance=view_instance, 46 | view_method=view_method, 47 | request=request, 48 | args=args, 49 | kwargs=kwargs 50 | ) 51 | if self.memoize_for_request: 52 | request._key_constructor_cache[memoization_key] = value 53 | return value 54 | 55 | def _get_memoization_key(self, view_instance, view_method, args, kwargs): 56 | from rest_framework_extensions.utils import get_unique_method_id 57 | return json.dumps({ 58 | 'unique_method_id': get_unique_method_id(view_instance=view_instance, view_method=view_method), 59 | 'args': args, 60 | 'kwargs': kwargs, 61 | 'instance_id': id(self) 62 | }) 63 | 64 | def _get_key(self, view_instance, view_method, request, args, kwargs): 65 | _kwargs = { 66 | 'view_instance': view_instance, 67 | 'view_method': view_method, 68 | 'request': request, 69 | 'args': args, 70 | 'kwargs': kwargs, 71 | } 72 | return self.prepare_key( 73 | self.get_data_from_bits(**_kwargs) 74 | ) 75 | 76 | def prepare_key(self, key_dict): 77 | return hashlib.md5(json.dumps(key_dict, sort_keys=True).encode('utf-8')).hexdigest() 78 | 79 | def get_data_from_bits(self, **kwargs): 80 | result_dict = {} 81 | for bit_name, bit_instance in self.bits.items(): 82 | if bit_name in self.params: 83 | params = self.params[bit_name] 84 | else: 85 | try: 86 | params = bit_instance.params 87 | except AttributeError: 88 | params = None 89 | result_dict[bit_name] = bit_instance.get_data( 90 | params=params, **kwargs) 91 | return result_dict 92 | 93 | 94 | class DefaultKeyConstructor(KeyConstructor): 95 | 
unique_method_id = bits.UniqueMethodIdKeyBit() 96 | format = bits.FormatKeyBit() 97 | language = bits.LanguageKeyBit() 98 | 99 | 100 | class DefaultObjectKeyConstructor(DefaultKeyConstructor): 101 | retrieve_sql_query = bits.RetrieveSqlQueryKeyBit() 102 | 103 | 104 | class DefaultListKeyConstructor(DefaultKeyConstructor): 105 | list_sql_query = bits.ListSqlQueryKeyBit() 106 | pagination = bits.PaginationKeyBit() 107 | 108 | 109 | class DefaultAPIModelInstanceKeyConstructor(KeyConstructor): 110 | """ 111 | Use this constructor when the values of the model instance are required 112 | to identify the resource. 113 | """ 114 | retrieve_model_values = bits.RetrieveModelKeyBit() 115 | 116 | 117 | class DefaultAPIModelListKeyConstructor(KeyConstructor): 118 | """ 119 | Use this constructor when the values of the model instance are required 120 | to identify many resources. 121 | """ 122 | list_model_values = bits.ListModelKeyBit() 123 | -------------------------------------------------------------------------------- /rest_framework_extensions/mixins.py: -------------------------------------------------------------------------------- 1 | from rest_framework_extensions.cache.mixins import CacheResponseMixin 2 | # from rest_framework_extensions.etag.mixins import ReadOnlyETAGMixin, ETAGMixin 3 | from rest_framework_extensions.bulk_operations.mixins import ListUpdateModelMixin, ListDestroyModelMixin 4 | from rest_framework_extensions.settings import extensions_api_settings 5 | from django.core.exceptions import ValidationError 6 | from django.http import Http404 7 | import uuid 8 | 9 | class DetailSerializerMixin: 10 | """ 11 | Add custom serializer for detail view 12 | """ 13 | serializer_detail_class = None 14 | queryset_detail = None 15 | 16 | def get_serializer_class(self): 17 | error_message = "'{0}' should include a 'serializer_detail_class' attribute".format( 18 | self.__class__.__name__) 19 | assert self.serializer_detail_class is not None, error_message 20 | if 
self._is_request_to_detail_endpoint(): 21 | return self.serializer_detail_class 22 | else: 23 | return super().get_serializer_class() 24 | 25 | def get_queryset(self, *args, **kwargs): 26 | if self._is_request_to_detail_endpoint() and self.queryset_detail is not None: 27 | return self.queryset_detail.all() # todo: test all() 28 | else: 29 | return super().get_queryset(*args, **kwargs) 30 | 31 | def _is_request_to_detail_endpoint(self): 32 | if hasattr(self, 'lookup_url_kwarg'): 33 | lookup = self.lookup_url_kwarg or self.lookup_field 34 | return lookup and lookup in self.kwargs 35 | 36 | 37 | class PaginateByMaxMixin: 38 | 39 | def get_page_size(self, request): 40 | if self.page_size_query_param and self.max_page_size and request.query_params.get(self.page_size_query_param) == 'max': 41 | return self.max_page_size 42 | return super().get_page_size(request) 43 | 44 | 45 | # class ReadOnlyCacheResponseAndETAGMixin(ReadOnlyETAGMixin, CacheResponseMixin): 46 | # pass 47 | 48 | 49 | # class CacheResponseAndETAGMixin(ETAGMixin, CacheResponseMixin): 50 | # pass 51 | 52 | 53 | class NestedViewSetMixin: 54 | def get_queryset(self): 55 | return self.filter_queryset_by_parents_lookups( 56 | super().get_queryset() 57 | ) 58 | 59 | def filter_queryset_by_parents_lookups(self, queryset): 60 | parents_query_dict = self.get_parents_query_dict() 61 | if parents_query_dict: 62 | try: 63 | # Try to validate UUID fields before filtering 64 | cleaned_dict = {} 65 | for key, value in parents_query_dict.items(): 66 | if 'uuid' in key.lower() or key.endswith('_code'): 67 | try: 68 | # Try to validate as UUID 69 | cleaned_dict[key] = uuid.UUID(str(value)) 70 | except ValueError: 71 | raise Http404 72 | else: 73 | cleaned_dict[key] = value 74 | return queryset.filter(**cleaned_dict) 75 | except (ValueError, ValidationError): 76 | raise Http404 77 | else: 78 | return queryset 79 | 80 | def get_parents_query_dict(self): 81 | result = {} 82 | for kwarg_name, kwarg_value in self.kwargs.items(): 
83 | if kwarg_name.startswith(extensions_api_settings.DEFAULT_PARENT_LOOKUP_KWARG_NAME_PREFIX): 84 | query_lookup = kwarg_name.replace( 85 | extensions_api_settings.DEFAULT_PARENT_LOOKUP_KWARG_NAME_PREFIX, 86 | '', 87 | 1 88 | ) 89 | query_value = kwarg_value 90 | result[query_lookup] = query_value 91 | return result 92 | -------------------------------------------------------------------------------- /rest_framework_extensions/permissions.py: -------------------------------------------------------------------------------- 1 | from rest_framework.permissions import DjangoObjectPermissions 2 | 3 | 4 | class ExtendedDjangoObjectPermissions(DjangoObjectPermissions): 5 | hide_forbidden_for_read_objects = True 6 | 7 | def has_object_permission(self, request, view, obj): 8 | if self.hide_forbidden_for_read_objects: 9 | return super().has_object_permission(request, view, obj) 10 | else: 11 | model_cls = getattr(view, 'model', None) 12 | queryset = getattr(view, 'queryset', None) 13 | 14 | if model_cls is None and queryset is not None: 15 | model_cls = queryset.model 16 | 17 | perms = self.get_required_object_permissions( 18 | request.method, model_cls) 19 | user = request.user 20 | 21 | return user.has_perms(perms, obj) 22 | -------------------------------------------------------------------------------- /rest_framework_extensions/routers.py: -------------------------------------------------------------------------------- 1 | from rest_framework.routers import DefaultRouter, SimpleRouter 2 | from rest_framework_extensions.utils import compose_parent_pk_kwarg_name 3 | 4 | 5 | class NestedRegistryItem: 6 | def __init__(self, router, parent_prefix, parent_item=None, parent_viewset=None): 7 | self.router = router 8 | self.parent_prefix = parent_prefix 9 | self.parent_item = parent_item 10 | self.parent_viewset = parent_viewset 11 | 12 | def register(self, prefix, viewset, basename, parents_query_lookups): 13 | self.router._register( 14 | prefix=self.get_prefix( 15 | 
current_prefix=prefix, 16 | parents_query_lookups=parents_query_lookups), 17 | viewset=viewset, 18 | basename=basename, 19 | ) 20 | return NestedRegistryItem( 21 | router=self.router, 22 | parent_prefix=prefix, 23 | parent_item=self, 24 | parent_viewset=viewset 25 | ) 26 | 27 | def get_prefix(self, current_prefix, parents_query_lookups): 28 | return '{0}/{1}'.format( 29 | self.get_parent_prefix(parents_query_lookups), 30 | current_prefix 31 | ) 32 | 33 | def get_parent_prefix(self, parents_query_lookups): 34 | prefix = '/' 35 | current_item = self 36 | i = len(parents_query_lookups) - 1 37 | while current_item: 38 | parent_lookup_value_regex = getattr( 39 | current_item.parent_viewset, 'lookup_value_regex', '[^/.]+') 40 | prefix = '{parent_prefix}/(?P<{parent_pk_kwarg_name}>{parent_lookup_value_regex})/{prefix}'.format( 41 | parent_prefix=current_item.parent_prefix, 42 | parent_pk_kwarg_name=compose_parent_pk_kwarg_name( 43 | parents_query_lookups[i]), 44 | parent_lookup_value_regex=parent_lookup_value_regex, 45 | prefix=prefix 46 | ) 47 | i -= 1 48 | current_item = current_item.parent_item 49 | return prefix.strip('/') 50 | 51 | 52 | class NestedRouterMixin: 53 | def _register(self, *args, **kwargs): 54 | return super().register(*args, **kwargs) 55 | 56 | def register(self, *args, **kwargs): 57 | self._register(*args, **kwargs) 58 | return NestedRegistryItem( 59 | router=self, 60 | parent_prefix=self.registry[-1][0], 61 | parent_viewset=self.registry[-1][1] 62 | ) 63 | 64 | 65 | class ExtendedRouterMixin(NestedRouterMixin): 66 | pass 67 | 68 | 69 | class ExtendedSimpleRouter(ExtendedRouterMixin, SimpleRouter): 70 | pass 71 | 72 | 73 | class ExtendedDefaultRouter(ExtendedRouterMixin, DefaultRouter): 74 | pass 75 | -------------------------------------------------------------------------------- /rest_framework_extensions/serializers.py: -------------------------------------------------------------------------------- 1 | from rest_framework_extensions.utils import 
get_model_opts_concrete_fields 2 | 3 | 4 | def get_fields_for_partial_update(opts, init_data, fields, init_files=None): 5 | opts = opts.model._meta.concrete_model._meta 6 | partial_fields = list((init_data or {}).keys()) + \ 7 | list((init_files or {}).keys()) 8 | concrete_field_names = [] 9 | for field in get_model_opts_concrete_fields(opts): 10 | if not field.primary_key: 11 | concrete_field_names.append(field.name) 12 | if field.name != field.attname: 13 | concrete_field_names.append(field.attname) 14 | update_fields = [] 15 | for field_name in partial_fields: 16 | if field_name in fields: 17 | model_field_name = getattr( 18 | fields[field_name], 'source') or field_name 19 | if model_field_name in concrete_field_names: 20 | update_fields.append(model_field_name) 21 | 22 | # recurse on nested fields of same ('*') instance 23 | for k, v in (init_data or {}).items(): 24 | if isinstance(v, dict) and k in fields and fields[k].source == '*': 25 | recursive_fields = get_fields_for_partial_update( 26 | opts, v, fields[k].fields.fields) 27 | update_fields.extend(recursive_fields) 28 | 29 | return sorted(set(update_fields)) 30 | 31 | 32 | class PartialUpdateSerializerMixin: 33 | def save(self, **kwargs): 34 | self._update_fields = kwargs.get('update_fields', None) 35 | return super().save(**kwargs) 36 | 37 | def update(self, instance, validated_attrs): 38 | for attr, value in validated_attrs.items(): 39 | if hasattr(getattr(instance, attr, None), 'set'): 40 | getattr(instance, attr).set(value) 41 | else: 42 | setattr(instance, attr, value) 43 | if self.partial and isinstance(instance, self.Meta.model): 44 | instance.save( 45 | update_fields=getattr(self, '_update_fields') or get_fields_for_partial_update( 46 | opts=self.Meta, 47 | init_data=self.get_initial(), 48 | fields=self.fields.fields 49 | ) 50 | ) 51 | else: 52 | instance.save() 53 | return instance 54 | -------------------------------------------------------------------------------- 
/rest_framework_extensions/settings.py:
--------------------------------------------------------------------------------
1 | from django.conf import settings
2 | 
3 | from rest_framework.settings import APISettings
4 | 
5 | 
6 | # Project-level overrides read from the REST_FRAMEWORK_EXTENSIONS Django
7 | # setting (None when that setting is not defined).
8 | USER_SETTINGS = getattr(settings, 'REST_FRAMEWORK_EXTENSIONS', None)
9 | 
10 | # Default values for the REST_FRAMEWORK_EXTENSIONS settings.
11 | DEFAULTS = {
12 |     # caching
13 |     'DEFAULT_USE_CACHE': 'default',
14 |     'DEFAULT_CACHE_RESPONSE_TIMEOUT': None,
15 |     'DEFAULT_CACHE_ERRORS': True,
16 |     'DEFAULT_CACHE_KEY_FUNC': 'rest_framework_extensions.utils.default_cache_key_func',
17 |     'DEFAULT_OBJECT_CACHE_KEY_FUNC': 'rest_framework_extensions.utils.default_object_cache_key_func',
18 |     'DEFAULT_LIST_CACHE_KEY_FUNC': 'rest_framework_extensions.utils.default_list_cache_key_func',
19 | 
20 |     # ETAG
21 |     'DEFAULT_ETAG_FUNC': 'rest_framework_extensions.utils.default_etag_func',
22 |     'DEFAULT_OBJECT_ETAG_FUNC': 'rest_framework_extensions.utils.default_object_etag_func',
23 |     'DEFAULT_LIST_ETAG_FUNC': 'rest_framework_extensions.utils.default_list_etag_func',
24 | 
25 |     # API - ETAG
26 |     'DEFAULT_API_OBJECT_ETAG_FUNC': 'rest_framework_extensions.utils.default_api_object_etag_func',
27 |     'DEFAULT_API_LIST_ETAG_FUNC': 'rest_framework_extensions.utils.default_api_list_etag_func',
28 | 
29 |     # other
30 |     'DEFAULT_KEY_CONSTRUCTOR_MEMOIZE_FOR_REQUEST': False,
31 |     'DEFAULT_BULK_OPERATION_HEADER_NAME': 'X-BULK-OPERATION',
32 |     'DEFAULT_PARENT_LOOKUP_KWARG_NAME_PREFIX': 'parent_lookup_'
33 | }
34 | 
35 | # Settings whose values are dotted import paths (every *_FUNC entry above);
36 | # this list is passed to APISettings as its third argument.
37 | IMPORT_STRINGS = [
38 |     'DEFAULT_CACHE_KEY_FUNC',
39 |     'DEFAULT_OBJECT_CACHE_KEY_FUNC',
40 |     'DEFAULT_LIST_CACHE_KEY_FUNC',
41 |     'DEFAULT_ETAG_FUNC',
42 |     'DEFAULT_OBJECT_ETAG_FUNC',
43 |     'DEFAULT_LIST_ETAG_FUNC',
44 |     # API - ETAG
45 |     'DEFAULT_API_OBJECT_ETAG_FUNC',
46 |     'DEFAULT_API_LIST_ETAG_FUNC',
47 | ]
48 | 
49 | 
50 | # Settings object used across the package, accessed by attribute,
51 | # e.g. extensions_api_settings.DEFAULT_ETAG_FUNC.
52 | extensions_api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS)
53 | 
--------------------------------------------------------------------------------
/rest_framework_extensions/test.py:
-------------------------------------------------------------------------------- 1 | # Leaving this module here for backwards compatibility but this is just proxy 2 | # for rest_framework.test 3 | 4 | import warnings 5 | 6 | from rest_framework.test import ( 7 | force_authenticate, 8 | APIRequestFactory, 9 | ForceAuthClientHandler, 10 | APIClient, 11 | APITransactionTestCase, 12 | APITestCase 13 | ) 14 | 15 | 16 | __all__ = ( 17 | 'force_authenticate,' 18 | 'APIRequestFactory,' 19 | 'ForceAuthClientHandler,' 20 | 'APIClient,' 21 | 'APITransactionTestCase,' 22 | 'APITestCase' 23 | ) 24 | 25 | warnings.warn( 26 | 'Use of this module is deprecated! Use rest_framework.test instead.', 27 | DeprecationWarning 28 | ) 29 | -------------------------------------------------------------------------------- /rest_framework_extensions/utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from packaging.version import Version 3 | 4 | import rest_framework 5 | 6 | from rest_framework_extensions.key_constructor.constructors import ( 7 | DefaultKeyConstructor, 8 | DefaultObjectKeyConstructor, 9 | DefaultListKeyConstructor, 10 | DefaultAPIModelInstanceKeyConstructor, 11 | DefaultAPIModelListKeyConstructor 12 | ) 13 | from rest_framework_extensions.settings import extensions_api_settings 14 | 15 | 16 | def get_rest_framework_version(): 17 | return Version(rest_framework.VERSION).release 18 | 19 | 20 | def flatten(list_of_lists): 21 | """ 22 | Takes an iterable of iterables, 23 | returns a single iterable containing all items 24 | """ 25 | # todo: test me 26 | return itertools.chain(*list_of_lists) 27 | 28 | 29 | def prepare_header_name(name): 30 | """ 31 | >> prepare_header_name('Accept-Language') 32 | http_accept_language 33 | """ 34 | return 'http_{0}'.format(name.strip().replace('-', '_')).upper() 35 | 36 | 37 | def get_unique_method_id(view_instance, view_method): 38 | # todo: test me as UniqueMethodIdKeyBit 39 | return 
'.'.join([ 40 | view_instance.__module__, 41 | view_instance.__class__.__name__, 42 | view_method.__name__ 43 | ]) 44 | 45 | 46 | def get_model_opts_concrete_fields(opts): 47 | # todo: test me 48 | if not hasattr(opts, 'concrete_fields'): 49 | opts.concrete_fields = [f for f in opts.fields if f.column is not None] 50 | return opts.concrete_fields 51 | 52 | 53 | def compose_parent_pk_kwarg_name(value): 54 | return '{0}{1}'.format( 55 | extensions_api_settings.DEFAULT_PARENT_LOOKUP_KWARG_NAME_PREFIX, 56 | value 57 | ) 58 | 59 | 60 | default_cache_key_func = DefaultKeyConstructor() 61 | default_object_cache_key_func = DefaultObjectKeyConstructor() 62 | default_list_cache_key_func = DefaultListKeyConstructor() 63 | 64 | default_etag_func = default_cache_key_func 65 | default_object_etag_func = default_object_cache_key_func 66 | default_list_etag_func = default_list_cache_key_func 67 | 68 | # API (object-centered) functions 69 | default_api_object_etag_func = DefaultAPIModelInstanceKeyConstructor() 70 | default_api_list_etag_func = DefaultAPIModelListKeyConstructor() 71 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_wheel] 2 | universal = 1 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from setuptools import setup 3 | import re 4 | import os 5 | import sys 6 | 7 | 8 | def get_version(package): 9 | """ 10 | Return package version as listed in `__version__` in `init.py`. 11 | """ 12 | init_py = open(os.path.join(package, '__init__.py')).read() 13 | return re.match("__version__ = ['\"]([^'\"]+)['\"]", init_py).group(1) 14 | 15 | 16 | def get_packages(package): 17 | """ 18 | Return root package and all sub-packages. 
19 | """ 20 | return [dirpath 21 | for dirpath, dirnames, filenames in os.walk(package) 22 | if os.path.exists(os.path.join(dirpath, '__init__.py'))] 23 | 24 | 25 | def get_package_data(package): 26 | """ 27 | Return all files under the root package, that are not in a 28 | package themselves. 29 | """ 30 | walk = [(dirpath.replace(package + os.sep, '', 1), filenames) 31 | for dirpath, dirnames, filenames in os.walk(package) 32 | if not os.path.exists(os.path.join(dirpath, '__init__.py'))] 33 | 34 | filepaths = [] 35 | for base, filenames in walk: 36 | filepaths.extend([os.path.join(base, filename) 37 | for filename in filenames]) 38 | return {package: filepaths} 39 | 40 | 41 | version = get_version('rest_framework_extensions') 42 | 43 | 44 | if sys.argv[-1] == 'publish': 45 | os.system("python setup.py sdist upload") 46 | os.system("python setup.py bdist_wheel upload") 47 | print("You probably want to also tag the version now:") 48 | print(" git tag -a %s -m 'version %s'" % (version, version)) 49 | print(" git push --tags") 50 | sys.exit() 51 | 52 | 53 | setup( 54 | name='drf-extensions', 55 | version=version, 56 | url='http://github.com/chibisov/drf-extensions', 57 | download_url='https://pypi.python.org/pypi/drf-extensions/', 58 | license='BSD', 59 | install_requires=['djangorestframework>=3.10.3', 'packaging>=24.1', 'Django>=2.2,<6.0'], 60 | description='Extensions for Django REST Framework', 61 | long_description='DRF-extensions is a collection of custom extensions for Django REST Framework', 62 | author='Asif Saif Uddin, Gennady Chibisov', 63 | author_email='auvipy@gmail.com', 64 | packages=get_packages('rest_framework_extensions'), 65 | package_data=get_package_data('rest_framework_extensions'), 66 | test_suite='rest_framework_extensions.runtests.runtests.main', 67 | classifiers=[ 68 | 'Development Status :: 5 - Production/Stable', 69 | 'Environment :: Web Environment', 70 | 'Framework :: Django', 71 | 'Intended Audience :: Developers', 72 | 'License :: OSI 
Approved :: MIT License', 73 | 'Operating System :: OS Independent', 74 | 'Framework :: Django', 75 | 'Framework :: Django :: 3.2', 76 | 'Framework :: Django :: 4.2', 77 | 'Framework :: Django :: 5.2', 78 | 'Programming Language :: Python', 79 | 'Programming Language :: Python :: 3', 80 | 'Programming Language :: Python :: 3.9', 81 | 'Programming Language :: Python :: 3.10', 82 | 'Programming Language :: Python :: 3.11', 83 | 'Programming Language :: Python :: 3.12', 84 | 'Topic :: Internet :: WWW/HTTP', 85 | ] 86 | ) 87 | -------------------------------------------------------------------------------- /tests_app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/__init__.py -------------------------------------------------------------------------------- /tests_app/plugins.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | 4 | from django_nose.plugin import AlwaysOnPlugin 5 | 6 | from django.test import TestCase 7 | from django.core.cache import cache 8 | from django.conf import settings 9 | 10 | 11 | class UnitTestDiscoveryPlugin(AlwaysOnPlugin): 12 | """ 13 | Enables unittest compatibility mode (dont test functions, only TestCase 14 | subclasses, and only methods that start with [Tt]est). 
15 | """ 16 | enabled = True 17 | 18 | def wantModule(self, module): 19 | return True 20 | 21 | def wantFile(self, file): 22 | if file.endswith('.py'): 23 | return True 24 | 25 | def wantClass(self, cls): 26 | if not issubclass(cls, TestCase): 27 | return False 28 | 29 | def wantMethod(self, method): 30 | if not method.__name__.lower().startswith('test'): 31 | return False 32 | 33 | def wantFunction(self, function): 34 | return False 35 | 36 | 37 | class PrepareRestFrameworkSettingsPlugin(AlwaysOnPlugin): 38 | def begin(self): 39 | self._monkeypatch_default_settings() 40 | 41 | def _monkeypatch_default_settings(self): 42 | from rest_framework import settings 43 | 44 | PATCH_REST_FRAMEWORK = { 45 | # Testing 46 | 'TEST_REQUEST_RENDERER_CLASSES': ( 47 | 'rest_framework.renderers.MultiPartRenderer', 48 | 'rest_framework.renderers.JSONRenderer' 49 | ), 50 | 'TEST_REQUEST_DEFAULT_FORMAT': 'multipart', 51 | } 52 | 53 | for key, value in PATCH_REST_FRAMEWORK.items(): 54 | if key not in settings.DEFAULTS: 55 | settings.DEFAULTS[key] = value 56 | 57 | 58 | class PrepareFileStorageDir(AlwaysOnPlugin): 59 | def begin(self): 60 | if not os.path.isdir(settings.MEDIA_ROOT): 61 | os.makedirs(settings.MEDIA_ROOT) 62 | 63 | def finalize(self, result): 64 | shutil.rmtree(settings.MEDIA_ROOT, ignore_errors=True) 65 | 66 | 67 | class FlushCache(AlwaysOnPlugin): 68 | # startTest didn't work :( 69 | def begin(self): 70 | self._monkeypatch_testcase() 71 | 72 | def _monkeypatch_testcase(self): 73 | old_run = TestCase.run 74 | 75 | def new_run(*args, **kwargs): 76 | cache.clear() 77 | return old_run(*args, **kwargs) 78 | TestCase.run = new_run 79 | -------------------------------------------------------------------------------- /tests_app/requirements.txt: -------------------------------------------------------------------------------- 1 | pynose 2 | django-nose 3 | django-filter>=2.1.0 4 | mock 5 | ipdb 6 | uuid 
-------------------------------------------------------------------------------- /tests_app/settings.py: -------------------------------------------------------------------------------- 1 | # Django settings for testproject project. 2 | import multiprocessing 3 | 4 | DEBUG = True 5 | DEBUG_PROPAGATE_EXCEPTIONS = True 6 | 7 | ALLOWED_HOSTS = ['*'] 8 | 9 | ADMINS = ( 10 | # ('Your Name', 'your_email@domain.com'), 11 | ) 12 | 13 | MANAGERS = ADMINS 14 | 15 | DATABASES = { 16 | 'default': { 17 | 'ENGINE': 'django.db.backends.sqlite3', 18 | 'NAME': 'drf_extensions', 19 | 'TEST_CHARSET': 'utf8',  # NOTE(review): legacy key; newer Django reads TEST = {'CHARSET': ...} instead, so this is presumably ignored 20 | }, 21 | } 22 | 23 | CACHES = { 24 | 'default': { 25 | 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache', 26 | }, 27 | 'special_cache': { 28 | 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache', 29 | }, 30 | 'another_special_cache': { 31 | 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache', 32 | }, 33 | } 34 | 35 | # Local time zone for this installation. Choices can be found here: 36 | # http://en.wikipedia.org/wiki/List_of_tz_zones_by_name 37 | # although not all choices may be available on all operating systems. 38 | # On Unix systems, a value of None will cause Django to use the same 39 | # timezone as the operating system. 40 | # If running in a Windows environment this must be set to the same as your 41 | # system time zone. 42 | TIME_ZONE = 'Europe/London' 43 | 44 | # Language code for this installation. All choices can be found here: 45 | # http://www.i18nguy.com/unicode/language-identifiers.html 46 | LANGUAGE_CODE = 'en-uk'  # NOTE(review): the standard tag for British English is 'en-gb'; left as-is because locale formats could affect tests 47 | 48 | SITE_ID = 1 49 | 50 | # If you set this to False, Django will make some optimizations so as not 51 | # to load the internationalization machinery. 52 | USE_I18N = True 53 | 54 | # If you set this to False, Django will not format dates, numbers and 55 | # calendars according to the current locale 56 | USE_L10N = True  # NOTE(review): deprecated in Django 4.0 and removed in 5.0, where localized formatting is always on 57 | 58 | # Absolute filesystem path to the directory that will hold user-uploaded files.
59 | # Example: "/home/media/media.lawrence.com/" 60 | MEDIA_ROOT = 'tests_app/tests/files' 61 | 62 | # URL that handles the media served from MEDIA_ROOT. Make sure to use a 63 | # trailing slash if there is a path component (optional in other cases). 64 | # Examples: "http://media.lawrence.com", "http://example.com/media/" 65 | MEDIA_URL = '' 66 | 67 | # Make this unique, and don't share it with anybody. 68 | SECRET_KEY = 'u@x-aj9(hoh#rb-^ymf#g2jx_hp0vj7u5#b@ag1n^seu9e!%cy'  # test-project value only 69 | 70 | # Template engine configuration (Django template backend, app template directories). 71 | TEMPLATES = [ 72 | { 73 | 'BACKEND': 'django.template.backends.django.DjangoTemplates', 74 | 'DIRS': [], 75 | 'APP_DIRS': True, 76 | 'OPTIONS': { 77 | 'context_processors': [ 78 | 'django.template.context_processors.debug', 79 | 'django.template.context_processors.request', 80 | 'django.contrib.auth.context_processors.auth', 81 | 'django.contrib.messages.context_processors.messages', 82 | ], 83 | }, 84 | }, 85 | ] 86 | 87 | 88 | MIDDLEWARE = [ 89 | 'django.middleware.security.SecurityMiddleware', 90 | 'django.contrib.sessions.middleware.SessionMiddleware', 91 | 'django.middleware.common.CommonMiddleware', 92 | 'django.middleware.csrf.CsrfViewMiddleware', 93 | 'django.contrib.auth.middleware.AuthenticationMiddleware', 94 | 'django.contrib.messages.middleware.MessageMiddleware', 95 | 'django.middleware.clickjacking.XFrameOptionsMiddleware', 96 | ] 97 | 98 | ROOT_URLCONF = 'urls' 99 | 100 | 101 | INSTALLED_APPS = ( 102 | 'django.contrib.auth', 103 | 'django.contrib.contenttypes', 104 | 'django.contrib.sessions', 105 | 'django.contrib.sites', 106 | 'django.contrib.messages', 107 | 'django.contrib.admin', 108 | # (the admin app is already enabled by the entry above) 109 | # 'django.contrib.admin', 110 | # Uncomment the next line to enable admin documentation: 111 | # 'django.contrib.admindocs', 112 | 'django_nose', 113 | 'guardian', 114 | 'rest_framework_extensions', 115 | 116 | 'tests_app.tests.functional', 117 |
'tests_app.tests.unit', 118 | ) 119 | 120 | STATIC_URL = '/static/' 121 | 122 | # Password validation 123 | # https://docs.djangoproject.com/en/stable/ref/settings/#auth-password-validators 124 | 125 | AUTH_PASSWORD_VALIDATORS = [ 126 | { 127 | 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', 128 | }, 129 | { 130 | 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', 131 | }, 132 | { 133 | 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', 134 | }, 135 | { 136 | 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', 137 | }, 138 | ] 139 | 140 | AUTH_USER_MODEL = 'auth.User' 141 | 142 | 143 | TEST_RUNNER = 'django.test.runner.DiscoverRunner'  # NOTE(review): NOSE_ARGS and NOSE_PLUGINS below are options for the django_nose runner and are presumably unused with DiscoverRunner; confirm 144 | 145 | NOSE_ARGS = [ 146 | '--processes=%s' % multiprocessing.cpu_count(), 147 | '--process-timeout=100', 148 | '--nocapture', 149 | ] 150 | 151 | NOSE_PLUGINS = [ 152 | 'plugins.UnitTestDiscoveryPlugin', 153 | 'plugins.PrepareRestFrameworkSettingsPlugin', 154 | 'plugins.FlushCache', 155 | 'plugins.PrepareFileStorageDir' 156 | ] 157 | 158 | # guardian 159 | ANONYMOUS_USER_ID = -1 160 | 161 | AUTHENTICATION_BACKENDS = ( 162 | 'django.contrib.auth.backends.ModelBackend', # this is default 163 | 'guardian.backends.ObjectPermissionBackend', 164 | ) 165 | 166 | DEFAULT_AUTO_FIELD = 'django.db.models.AutoField' 167 | -------------------------------------------------------------------------------- /tests_app/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/tests/__init__.py -------------------------------------------------------------------------------- /tests_app/tests/functional/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/tests/functional/__init__.py -------------------------------------------------------------------------------- /tests_app/tests/functional/_concurrency/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/tests/functional/_concurrency/__init__.py -------------------------------------------------------------------------------- /tests_app/tests/functional/_concurrency/conditional_request/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/tests/functional/_concurrency/conditional_request/__init__.py -------------------------------------------------------------------------------- /tests_app/tests/functional/_concurrency/conditional_request/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | 3 | 4 | class Book(models.Model): 5 | """A sample model for conditional requests.""" 6 | 7 | name = models.CharField(max_length=100, default=None, blank=True, null=True) 8 | author = models.CharField(max_length=100, default=None, blank=True, null=True) 9 | issn = models.CharField(max_length=100, default=None, blank=True, null=True) 10 | -------------------------------------------------------------------------------- /tests_app/tests/functional/_concurrency/conditional_request/serializers.py: -------------------------------------------------------------------------------- 1 | from rest_framework import serializers 2 | from .models import Book 3 | 4 | 5 | class BookSerializer(serializers.ModelSerializer):  # exposes every Book model field (fields = '__all__') 6 | class Meta: 7 | model = Book 8 | fields = '__all__' 9 |
-------------------------------------------------------------------------------- /tests_app/tests/functional/_concurrency/conditional_request/urls.py: -------------------------------------------------------------------------------- 1 | from django.urls import re_path, include 2 | from rest_framework import routers 3 | from .views import (BookViewSet, BookListCreateView, BookChangeView, BookCustomDestroyView, 4 | BookUnconditionalDestroyView, BookUnconditionalUpdateView) 5 | 6 | router = routers.DefaultRouter() 7 | router.register(r'books', BookViewSet) 8 | 9 | urlpatterns = [ 10 | # manually add endpoints for APIView instances 11 | re_path(r'books_view/(?P[0-9]+)/custom/delete/', BookCustomDestroyView.as_view(), name='book_view-custom_delete'), 12 | re_path(r'books_view/(?P[0-9]+)/unconditional/delete/', BookUnconditionalDestroyView.as_view(), 13 | name='book_view-unconditional_delete'), 14 | re_path(r'books_view/(?P[0-9]+)/unconditional/update/', BookUnconditionalUpdateView.as_view(), 15 | name='book_view-unconditional_update'), 16 | re_path(r'books_view/', BookListCreateView.as_view(), name='book_view-list'), 17 | re_path(r'books_view/(?P[0-9]+)/', BookChangeView.as_view(), name='book_view-detail'), 18 | 19 | # include the URLs from the default viewset 20 | re_path(r'^', include(router.urls)), 21 | ] 22 | -------------------------------------------------------------------------------- /tests_app/tests/functional/_concurrency/conditional_request/views.py: -------------------------------------------------------------------------------- 1 | from rest_framework import viewsets 2 | from rest_framework import generics 3 | from rest_framework import status 4 | from rest_framework.response import Response 5 | from rest_framework_extensions.etag.mixins import APIETAGMixin 6 | from rest_framework_extensions.etag.decorators import api_etag 7 | from rest_framework_extensions.utils import default_api_object_etag_func 8 | from .models import Book 9 | from .serializers import 
BookSerializer 10 | 11 | 12 | class BookViewSet(APIETAGMixin, 13 | viewsets.ModelViewSet): 14 | """Test the mixin with DRF viewset.""" 15 | 16 | queryset = Book.objects.all() 17 | serializer_class = BookSerializer 18 | 19 | 20 | class BookChangeView(APIETAGMixin, 21 | generics.RetrieveUpdateDestroyAPIView): 22 | """Test the mixin with DRF generic API views.""" 23 | 24 | queryset = Book.objects.all() 25 | serializer_class = BookSerializer 26 | 27 | 28 | class BookListCreateView(APIETAGMixin, 29 | generics.ListCreateAPIView): 30 | """Test the mixin with DRF generic API views.""" 31 | 32 | queryset = Book.objects.all() 33 | serializer_class = BookSerializer 34 | 35 | 36 | class BookCustomDestroyView(generics.DestroyAPIView): 37 | """Test the decorator with DRF generic API views.""" 38 | 39 | # include the queryset here to enable the object lookup in `@api_etag` 40 | queryset = Book.objects.all() 41 | 42 | @api_etag(etag_func=default_api_object_etag_func) 43 | def delete(self, request, *args, **kwargs): 44 | obj = Book.objects.get(id=kwargs['pk']) 45 | obj.delete() 46 | return Response(status=status.HTTP_204_NO_CONTENT) 47 | 48 | 49 | class BookUnconditionalDestroyView(generics.DestroyAPIView): 50 | """Test the decorator with DRF generic API views.""" 51 | 52 | # include the queryset here to enable the object lookup in `@api_etag` 53 | queryset = Book.objects.all() 54 | 55 | @api_etag(etag_func=default_api_object_etag_func, precondition_map={}) 56 | def delete(self, request, *args, **kwargs): 57 | obj = Book.objects.get(id=kwargs['pk']) 58 | obj.delete() 59 | return Response(status=status.HTTP_204_NO_CONTENT) 60 | 61 | 62 | class BookUnconditionalUpdateView(generics.UpdateAPIView): 63 | """Test the decorator with DRF generic API views.""" 64 | 65 | # include the queryset here to enable the object lookup in `@api_etag` 66 | queryset = Book.objects.all() 67 | serializer_class = BookSerializer 68 | 69 | @api_etag(etag_func=default_api_object_etag_func, 
precondition_map={}) 70 | def update(self, request, *args, **kwargs): 71 | return super().update(request, *args, **kwargs) -------------------------------------------------------------------------------- /tests_app/tests/functional/_examples/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests_app/tests/functional/_examples/etags/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests_app/tests/functional/_examples/etags/remove_etag_gzip_postfix/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests_app/tests/functional/_examples/etags/remove_etag_gzip_postfix/middleware.py: -------------------------------------------------------------------------------- 1 | try: 2 | from django.utils.deprecation import MiddlewareMixin 3 | except ImportError: 4 | MiddlewareMixin = object 5 | 6 | 7 | class RemoveEtagGzipPostfix(MiddlewareMixin): 8 | def process_response(self, request, response): 9 | if response.has_header('ETag') and response['ETag'][-6:] == ';gzip"': 10 | response['ETag'] = response['ETag'][:-6] + '"' 11 | return response 12 | -------------------------------------------------------------------------------- /tests_app/tests/functional/_examples/etags/remove_etag_gzip_postfix/tests.py: -------------------------------------------------------------------------------- 1 | from django.test import TestCase, override_settings 2 | from django.conf import settings 3 | 4 | @override_settings(ROOT_URLCONF='tests_app.tests.functional._examples.etags.remove_etag_gzip_postfix.urls') 5 | class RemoveEtagGzipPostfixTest(TestCase): 6 | 7 | @override_settings(MIDDLEWARE_CLASSES=( 8 | 
'django.middleware.common.CommonMiddleware', 9 | 'django.middleware.gzip.GZipMiddleware', 10 | )) 11 | def test_without_middleware(self): 12 | response = self.client.get('/remove-etag-gzip-postfix/', **{ 13 | 'HTTP_ACCEPT_ENCODING': 'gzip' 14 | }) 15 | 16 | self.assertEqual(response.status_code, 200) 17 | # previously it was '"etag_value;gzip"' , instead of '"etag_value"', gzip don't append ;gzip suffix after encoding 18 | self.assertEqual(response['ETag'], '"etag_value"') 19 | 20 | @override_settings(MIDDLEWARE_CLASSES=( 21 | 'django.middleware.common.CommonMiddleware', 22 | 'django.middleware.gzip.GZipMiddleware', 23 | 'tests_app.tests.functional._examples.etags.remove_etag_gzip_postfix.middleware.RemoveEtagGzipPostfix', 24 | )) 25 | def test_with_middleware(self): 26 | response = self.client.get('/remove-etag-gzip-postfix/', **{ 27 | 'HTTP_ACCEPT_ENCODING': 'gzip' 28 | }) 29 | self.assertEqual(response.status_code, 200) 30 | self.assertEqual(response['ETag'], '"etag_value"') 31 | -------------------------------------------------------------------------------- /tests_app/tests/functional/_examples/etags/remove_etag_gzip_postfix/urls.py: -------------------------------------------------------------------------------- 1 | from django.urls import re_path 2 | 3 | from .views import MyView 4 | 5 | 6 | urlpatterns = [ 7 | re_path(r'^remove-etag-gzip-postfix/$', MyView.as_view()), 8 | ] 9 | -------------------------------------------------------------------------------- /tests_app/tests/functional/_examples/etags/remove_etag_gzip_postfix/views.py: -------------------------------------------------------------------------------- 1 | from django.views import View 2 | from django.http import HttpResponse 3 | 4 | 5 | class MyView(View): 6 | def get(self, request): 7 | """ 8 | GZipMiddleware will NOT compress content if any of the following are true: 9 | * The content body is less than 200 bytes long. 10 | * The response has already set the Content-Encoding header. 
11 | * The request (the browser) hasn’t sent an Accept-Encoding header containing gzip. 12 | """ 13 | response = HttpResponse('r' * 300) 14 | response['ETag'] = '"etag_value"' 15 | return response 16 | -------------------------------------------------------------------------------- /tests_app/tests/functional/cache/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests_app/tests/functional/cache/decorators/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests_app/tests/functional/cache/decorators/tests.py: -------------------------------------------------------------------------------- 1 | from django.test import TestCase, override_settings 2 | from django.utils.encoding import force_str 3 | 4 | 5 | @override_settings(ROOT_URLCONF='tests_app.tests.functional.cache.decorators.urls') 6 | class TestCacheResponseFunctionally(TestCase): 7 | 8 | def test_should_return_response(self): 9 | resp = self.client.get('/hello/') 10 | self.assertEqual(force_str(resp.content), '"Hello world"') 11 | 12 | def test_should_return_same_response_if_cached(self): 13 | resp_1 = self.client.get('/hello/') 14 | resp_2 = self.client.get('/hello/') 15 | self.assertEqual(resp_1.content, resp_2.content) 16 | -------------------------------------------------------------------------------- /tests_app/tests/functional/cache/decorators/urls.py: -------------------------------------------------------------------------------- 1 | from django.urls import re_path 2 | 3 | from .views import HelloView 4 | 5 | 6 | urlpatterns = [ 7 | re_path(r'^hello/$', HelloView.as_view(), name='hello'), 8 | ] 9 | -------------------------------------------------------------------------------- /tests_app/tests/functional/cache/decorators/views.py: 
-------------------------------------------------------------------------------- 1 | from rest_framework import views 2 | from rest_framework.response import Response 3 | 4 | from rest_framework_extensions.cache.decorators import cache_response 5 | 6 | 7 | class HelloView(views.APIView): 8 | @cache_response()  # no arguments: presumably falls back to the DEFAULT_* cache settings 9 | def get(self, request, *args, **kwargs): 10 | return Response('Hello world') 11 | -------------------------------------------------------------------------------- /tests_app/tests/functional/key_constructor/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests_app/tests/functional/key_constructor/bits/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests_app/tests/functional/key_constructor/bits/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | 3 | 4 | class KeyConstructorUserProperty(models.Model): 5 | name = models.CharField(max_length=100) 6 | 7 | 8 | class KeyConstructorUserModel(models.Model): 9 | property = models.ForeignKey( 10 | KeyConstructorUserProperty, 11 | on_delete=models.CASCADE 12 | ) 13 | -------------------------------------------------------------------------------- /tests_app/tests/functional/key_constructor/bits/serializers.py: -------------------------------------------------------------------------------- 1 | from rest_framework import serializers 2 | 3 | from .models import KeyConstructorUserModel 4 | 5 | 6 | class UserModelSerializer(serializers.ModelSerializer): 7 | class Meta: 8 | model = KeyConstructorUserModel 9 | fields = '__all__' 10 | -------------------------------------------------------------------------------- /tests_app/tests/functional/key_constructor/bits/tests.py:
-------------------------------------------------------------------------------- 1 | from django.test import override_settings 2 | 3 | from rest_framework.test import APITestCase 4 | 5 | from .models import KeyConstructorUserProperty 6 | 7 | 8 | @override_settings(ROOT_URLCONF='tests_app.tests.functional.key_constructor.bits.urls') 9 | class ListSqlQueryKeyBitTestBehaviour(APITestCase): 10 | """Regression tests for https://github.com/chibisov/drf-extensions/issues/28#issuecomment-51711927 11 | 12 | `django_filters.rest_framework.DjangoFilterBackend` uses default `FilterSet`. 13 | When there is no filtered fk in db, then `FilterSet.form` is invalid with errors: 14 | {'property': [u'Select a valid choice. That choice is not one of the available choices.']} 15 | In that case both requests in `test_without_fk_in_db` are expected to fail with status 400. 16 | """ 17 | 18 | def test_with_fk_in_db(self): 19 | KeyConstructorUserProperty.objects.create(name='some property') 20 | 21 | # list 22 | response = self.client.get('/users/?property=1') 23 | self.assertEqual(response.status_code, 200) 24 | 25 | # retrieve 26 | response = self.client.get('/users/1/?property=1') 27 | self.assertEqual(response.status_code, 404) 28 | 29 | def test_without_fk_in_db(self): 30 | # list 31 | response = self.client.get('/users/?property=1') 32 | self.assertEqual(response.status_code, 400) 33 | 34 | # retrieve 35 | response = self.client.get('/users/1/?property=1') 36 | self.assertEqual(response.status_code, 400) 37 | -------------------------------------------------------------------------------- /tests_app/tests/functional/key_constructor/bits/urls.py: -------------------------------------------------------------------------------- 1 | from rest_framework import routers 2 | 3 | from .views import UserModelViewSet 4 | 5 | 6 | viewset_router = routers.DefaultRouter() 7 | viewset_router.register('users', UserModelViewSet) 8 | urlpatterns = viewset_router.urls 9 | --------------------------------------------------------------------------------
-------------------------------------------------------------------------------- /tests_app/tests/functional/key_constructor/bits/views.py: -------------------------------------------------------------------------------- 1 | import django_filters 2 | from rest_framework import viewsets 3 | 4 | from .models import KeyConstructorUserModel as UserModel 5 | from .serializers import UserModelSerializer 6 | 7 | 8 | class UserModelViewSet(viewsets.ModelViewSet): 9 | queryset = UserModel.objects.all() 10 | serializer_class = UserModelSerializer 11 | filter_backends = (django_filters.rest_framework.DjangoFilterBackend,) 12 | filterset_fields = ('property',) 13 | -------------------------------------------------------------------------------- /tests_app/tests/functional/migrations/0001_initial.py: -------------------------------------------------------------------------------- 1 | # Generated by Django 2.2 on 2019-04-16 11:51 2 | 3 | from django.db import migrations, models 4 | import django.db.models.deletion 5 | 6 | 7 | class Migration(migrations.Migration): 8 | 9 | initial = True 10 | 11 | dependencies = [ 12 | ('contenttypes', '0002_remove_content_type_name'), 13 | ] 14 | 15 | operations = [ 16 | migrations.CreateModel( 17 | name='Comment', 18 | fields=[ 19 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 20 | ('email', models.EmailField(max_length=254)), 21 | ('content', models.CharField(max_length=200)), 22 | ('created', models.DateTimeField(auto_now_add=True)), 23 | ], 24 | ), 25 | migrations.CreateModel( 26 | name='CommentForListDestroyModelMixin', 27 | fields=[ 28 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 29 | ('email', models.EmailField(max_length=254)), 30 | ], 31 | ), 32 | migrations.CreateModel( 33 | name='CommentForListUpdateModelMixin', 34 | fields=[ 35 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, 
verbose_name='ID')), 36 | ('email', models.EmailField(max_length=254)), 37 | ], 38 | ), 39 | migrations.CreateModel( 40 | name='CommentForPaginateByMaxMixin', 41 | fields=[ 42 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 43 | ('email', models.EmailField(max_length=254)), 44 | ('content', models.CharField(max_length=200)), 45 | ('created', models.DateTimeField(auto_now_add=True)), 46 | ], 47 | options={ 48 | 'ordering': ['id'], 49 | }, 50 | ), 51 | migrations.CreateModel( 52 | name='DefaultRouterGroupModel', 53 | fields=[ 54 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 55 | ('name', models.CharField(max_length=10)), 56 | ], 57 | ), 58 | migrations.CreateModel( 59 | name='DefaultRouterPermissionModel', 60 | fields=[ 61 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 62 | ('name', models.CharField(max_length=10)), 63 | ], 64 | ), 65 | migrations.CreateModel( 66 | name='KeyConstructorUserProperty', 67 | fields=[ 68 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 69 | ('name', models.CharField(max_length=100)), 70 | ], 71 | ), 72 | migrations.CreateModel( 73 | name='NestedRouterMixinBookModel', 74 | fields=[ 75 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 76 | ('title', models.CharField(max_length=30)), 77 | ], 78 | ), 79 | migrations.CreateModel( 80 | name='NestedRouterMixinGroupModel', 81 | fields=[ 82 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 83 | ('name', models.CharField(max_length=10)), 84 | ], 85 | ), 86 | migrations.CreateModel( 87 | name='NestedRouterMixinPermissionModel', 88 | fields=[ 89 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 90 | ('name', 
models.CharField(max_length=10)), 91 | ], 92 | ), 93 | migrations.CreateModel( 94 | name='NestedRouterMixinTaskModel', 95 | fields=[ 96 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 97 | ('title', models.CharField(max_length=30)), 98 | ], 99 | ), 100 | migrations.CreateModel( 101 | name='PermissionsComment', 102 | fields=[ 103 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 104 | ('text', models.CharField(max_length=100)), 105 | ], 106 | ), 107 | migrations.CreateModel( 108 | name='RouterTestModel', 109 | fields=[ 110 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 111 | ('uuid', models.CharField(max_length=20)), 112 | ('text', models.CharField(max_length=200)), 113 | ], 114 | ), 115 | migrations.CreateModel( 116 | name='UserForListUpdateModelMixin', 117 | fields=[ 118 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 119 | ('email', models.EmailField(max_length=254)), 120 | ('name', models.CharField(max_length=10)), 121 | ('age', models.IntegerField()), 122 | ('last_name', models.CharField(max_length=10)), 123 | ('password', models.CharField(max_length=100)), 124 | ], 125 | ), 126 | migrations.CreateModel( 127 | name='NestedRouterMixinUserModel', 128 | fields=[ 129 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 130 | ('email', models.EmailField(blank=True, max_length=254, null=True)), 131 | ('name', models.CharField(max_length=10)), 132 | ('groups', models.ManyToManyField(related_name='user_groups', to='functional.NestedRouterMixinGroupModel')), 133 | ], 134 | ), 135 | migrations.AddField( 136 | model_name='nestedroutermixingroupmodel', 137 | name='permissions', 138 | field=models.ManyToManyField(to='functional.NestedRouterMixinPermissionModel'), 139 | ), 140 | migrations.CreateModel( 141 | 
name='NestedRouterMixinCommentModel', 142 | fields=[ 143 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 144 | ('object_id', models.PositiveIntegerField(blank=True, null=True)), 145 | ('text', models.CharField(max_length=30)), 146 | ('content_type', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to='contenttypes.ContentType')), 147 | ], 148 | ), 149 | migrations.CreateModel( 150 | name='KeyConstructorUserModel', 151 | fields=[ 152 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 153 | ('property', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='functional.KeyConstructorUserProperty')), 154 | ], 155 | ), 156 | migrations.CreateModel( 157 | name='DefaultRouterUserModel', 158 | fields=[ 159 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 160 | ('name', models.CharField(max_length=10)), 161 | ('groups', models.ManyToManyField(related_name='user_groups', to='functional.DefaultRouterGroupModel')), 162 | ], 163 | ), 164 | migrations.AddField( 165 | model_name='defaultroutergroupmodel', 166 | name='permissions', 167 | field=models.ManyToManyField(to='functional.DefaultRouterPermissionModel'), 168 | ), 169 | migrations.CreateModel( 170 | name='Book', 171 | fields=[ 172 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 173 | ('name', models.CharField(max_length=100, default=None, blank=True, null=True)), 174 | ('author', models.CharField(max_length=100, default=None, blank=True, null=True)), 175 | ('issn', models.CharField(max_length=100, default=None, blank=True, null=True)), 176 | ], 177 | ), 178 | ] 179 | -------------------------------------------------------------------------------- /tests_app/tests/functional/migrations/0002_nestedroutermixinusermodel_code.py: 
-------------------------------------------------------------------------------- 1 | # Generated by Django 3.2.18 on 2023-03-26 11:14 2 | 3 | from django.db import migrations, models 4 | import uuid 5 | 6 | 7 | class Migration(migrations.Migration): 8 | 9 | dependencies = [ 10 | ('functional', '0001_initial'), 11 | ] 12 | 13 | operations = [ 14 | migrations.AddField( 15 | model_name='nestedroutermixinusermodel', 16 | name='code', 17 | field=models.UUIDField(default=uuid.uuid4), 18 | ), 19 | ] 20 | -------------------------------------------------------------------------------- /tests_app/tests/functional/migrations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/tests/functional/migrations/__init__.py -------------------------------------------------------------------------------- /tests_app/tests/functional/mixins/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/tests/functional/mixins/__init__.py -------------------------------------------------------------------------------- /tests_app/tests/functional/mixins/detail_serializer_mixin/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests_app/tests/functional/mixins/detail_serializer_mixin/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | 3 | 4 | class Comment(models.Model): 5 | email = models.EmailField() 6 | content = models.CharField(max_length=200) 7 | created = models.DateTimeField(auto_now_add=True) 8 | -------------------------------------------------------------------------------- 
/tests_app/tests/functional/mixins/detail_serializer_mixin/serializers.py: -------------------------------------------------------------------------------- 1 | from rest_framework import serializers 2 | 3 | from .models import Comment 4 | 5 | 6 | class CommentSerializer(serializers.ModelSerializer): 7 | class Meta: 8 | model = Comment 9 | fields = ( 10 | 'id', 11 | 'email', 12 | ) 13 | 14 | 15 | class CommentDetailSerializer(serializers.ModelSerializer): 16 | class Meta: 17 | model = Comment 18 | fields = ( 19 | 'id', 20 | 'email', 21 | 'content', 22 | ) -------------------------------------------------------------------------------- /tests_app/tests/functional/mixins/detail_serializer_mixin/tests.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | from django.test import TestCase, override_settings 4 | 5 | # todo: use from rest_framework when released 6 | from rest_framework.test import APIRequestFactory 7 | from .models import Comment 8 | 9 | 10 | factory = APIRequestFactory() 11 | 12 | 13 | @override_settings(ROOT_URLCONF='tests_app.tests.functional.mixins.detail_serializer_mixin.urls') 14 | class DetailSerializerMixinTest_serializer_detail_class(TestCase): 15 | 16 | def setUp(self): 17 | self.comment = Comment.objects.create( 18 | id=1, 19 | email='example@ya.ru', 20 | content='Hello world', 21 | created=datetime.datetime.now() 22 | ) 23 | 24 | def test_serializer_class_response(self): 25 | resp = self.client.get('/comments/') 26 | expected = [{ 27 | 'id': 1, 28 | 'email': 'example@ya.ru' 29 | }] 30 | self.assertEqual(resp.data, expected) 31 | 32 | def test_serializer_detail_class_response(self): 33 | resp = self.client.get('/comments/1/') 34 | expected = { 35 | 'id': 1, 36 | 'email': 'example@ya.ru', 37 | 'content': 'Hello world', 38 | } 39 | self.assertEqual(resp.data, expected, 'should use detail serializer for detail endpoint') 40 | 41 | def 
test_view_with_mixin_and_without__serializer_detail_class__should_raise_exception(self): 42 | msg = "'CommentWithoutDetailSerializerClassViewSet' should include a 'serializer_detail_class' attribute" 43 | self.assertRaisesMessage(AssertionError, msg, self.client.get, '/comments-2/') 44 | 45 | 46 | @override_settings(ROOT_URLCONF='tests_app.tests.functional.mixins.detail_serializer_mixin.urls') 47 | class DetailSerializerMixin_queryset_detail(TestCase): 48 | 49 | def setUp(self): 50 | self.comments = [ 51 | Comment.objects.create( 52 | id=1, 53 | email='example@ya.ru', 54 | content='Hello world', 55 | created=datetime.datetime.now() 56 | ), 57 | Comment.objects.create( 58 | id=2, 59 | email='example2@ya.ru', 60 | content='Hello world 2', 61 | created=datetime.datetime.now() 62 | ), 63 | ] 64 | 65 | def test_list_should_use_default_queryset_method(self): 66 | resp = self.client.get('/comments-3/') 67 | expected = [{ 68 | 'id': 2, 69 | 'email': 'example2@ya.ru' 70 | }] 71 | self.assertEqual(resp.data, expected) 72 | 73 | def test_detail_view_should_use_default_queryset_if_queryset_detail_not_specified(self): 74 | resp = self.client.get('/comments-3/1/') 75 | self.assertEqual(resp.status_code, 404) 76 | 77 | resp = self.client.get('/comments-3/2/') 78 | expected = { 79 | 'id': 2, 80 | 'email': 'example2@ya.ru', 81 | 'content': 'Hello world 2', 82 | } 83 | self.assertEqual(resp.data, expected) 84 | 85 | def test_list_should_use_default_queryset_method_if_queryset_detail_specified(self): 86 | resp = self.client.get('/comments-4/') 87 | expected = [{ 88 | 'id': 2, 89 | 'email': 'example2@ya.ru' 90 | }] 91 | self.assertEqual(resp.data, expected) 92 | 93 | def test_detail_view_should_use_custom_queryset_if_queryset_detail_specified(self): 94 | resp = self.client.get('/comments-4/2/') 95 | self.assertEqual(resp.status_code, 404) 96 | 97 | resp = self.client.get('/comments-4/1/') 98 | expected = { 99 | 'id': 1, 100 | 'email': 'example@ya.ru', 101 | 'content': 'Hello world', 
102 | } 103 | self.assertEqual(resp.data, expected) 104 | 105 | def test_nested_model_view_with_mixin_should_use_get_detail_queryset(self): 106 | """ 107 | Regression tests for https://github.com/chibisov/drf-extensions/pull/24 108 | """ 109 | resp = self.client.get('/comments-5/1/') 110 | expected = { 111 | 'id': 1, 112 | 'email': 'example@ya.ru', 113 | 'content': 'Hello world', 114 | } 115 | self.assertEqual(resp.status_code, 200) 116 | self.assertEqual(resp.data, expected) 117 | -------------------------------------------------------------------------------- /tests_app/tests/functional/mixins/detail_serializer_mixin/urls.py: -------------------------------------------------------------------------------- 1 | from rest_framework import routers 2 | 3 | from .views import ( 4 | CommentViewSet, 5 | CommentWithoutDetailSerializerClassViewSet, 6 | CommentWithIdTwoViewSet, 7 | CommentWithIdTwoAndIdOneForDetailViewSet, 8 | CommentWithDetailSerializerAndNoArgsForGetQuerySetViewSet 9 | ) 10 | 11 | 12 | viewset_router = routers.DefaultRouter() 13 | viewset_router.register('comments', CommentViewSet, basename='alt1') 14 | viewset_router.register('comments-2', CommentWithoutDetailSerializerClassViewSet, basename='alt2') 15 | viewset_router.register('comments-3', CommentWithIdTwoViewSet, basename='alt3') 16 | viewset_router.register('comments-4', CommentWithIdTwoAndIdOneForDetailViewSet, basename='alt4') 17 | viewset_router.register('comments-5', CommentWithDetailSerializerAndNoArgsForGetQuerySetViewSet, basename='alt5') 18 | 19 | urlpatterns = viewset_router.urls 20 | -------------------------------------------------------------------------------- /tests_app/tests/functional/mixins/detail_serializer_mixin/views.py: -------------------------------------------------------------------------------- 1 | from rest_framework import viewsets 2 | from rest_framework_extensions.mixins import DetailSerializerMixin 3 | 4 | from .models import Comment 5 | from .serializers import 
CommentSerializer, CommentDetailSerializer 6 | 7 | 8 | class CommentViewSet(DetailSerializerMixin, viewsets.ReadOnlyModelViewSet): 9 | serializer_class = CommentSerializer 10 | serializer_detail_class = CommentDetailSerializer 11 | queryset = Comment.objects.all() 12 | 13 | 14 | class CommentWithoutDetailSerializerClassViewSet(DetailSerializerMixin, viewsets.ReadOnlyModelViewSet): 15 | serializer_class = CommentSerializer 16 | queryset = Comment.objects.all() 17 | 18 | 19 | class CommentWithIdTwoViewSet(DetailSerializerMixin, viewsets.ReadOnlyModelViewSet): 20 | serializer_class = CommentSerializer 21 | serializer_detail_class = CommentDetailSerializer 22 | queryset = Comment.objects.filter(id=2) 23 | 24 | 25 | class CommentWithIdTwoAndIdOneForDetailViewSet(DetailSerializerMixin, viewsets.ReadOnlyModelViewSet): 26 | serializer_class = CommentSerializer 27 | serializer_detail_class = CommentDetailSerializer 28 | queryset = Comment.objects.filter(id=2) 29 | queryset_detail = Comment.objects.filter(id=1) 30 | 31 | 32 | class CommentWithDetailSerializerAndNoArgsForGetQuerySetViewSet(DetailSerializerMixin, viewsets.ModelViewSet): 33 | """ 34 | For regression tests https://github.com/chibisov/drf-extensions/pull/24 35 | """ 36 | serializer_class = CommentSerializer 37 | serializer_detail_class = CommentDetailSerializer 38 | queryset = Comment.objects.all() 39 | queryset_detail = Comment.objects.filter(id=1) 40 | 41 | def get_queryset(self): 42 | return super().get_queryset() 43 | -------------------------------------------------------------------------------- /tests_app/tests/functional/mixins/list_destroy_model_mixin/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests_app/tests/functional/mixins/list_destroy_model_mixin/models.py: -------------------------------------------------------------------------------- 1 | from django.db 
import models 2 | 3 | 4 | class CommentForListDestroyModelMixin(models.Model): 5 | email = models.EmailField() 6 | -------------------------------------------------------------------------------- /tests_app/tests/functional/mixins/list_destroy_model_mixin/tests.py: -------------------------------------------------------------------------------- 1 | from django.test import override_settings 2 | 3 | from rest_framework.test import APITestCase 4 | from rest_framework_extensions.settings import extensions_api_settings 5 | from rest_framework_extensions import utils 6 | 7 | from .models import CommentForListDestroyModelMixin as Comment 8 | from tests_app.testutils import override_extensions_api_settings 9 | 10 | 11 | @override_settings(ROOT_URLCONF='tests_app.tests.functional.mixins.list_destroy_model_mixin.urls') 12 | class ListDestroyModelMixinTest(APITestCase): 13 | 14 | def setUp(self): 15 | self.comments = [ 16 | Comment.objects.create( 17 | id=1, 18 | email='example@ya.ru' 19 | ), 20 | Comment.objects.create( 21 | id=2, 22 | email='example@gmail.com' 23 | ) 24 | ] 25 | self.protection_headers = { 26 | utils.prepare_header_name(extensions_api_settings.DEFAULT_BULK_OPERATION_HEADER_NAME): 'true' 27 | } 28 | 29 | def test_simple_response(self): 30 | resp = self.client.get('/comments/') 31 | expected = [ 32 | { 33 | 'id': 1, 34 | 'email': 'example@ya.ru' 35 | }, 36 | { 37 | 'id': 2, 38 | 'email': 'example@gmail.com' 39 | } 40 | ] 41 | self.assertEqual(resp.data, expected) 42 | 43 | def test_filter_works(self): 44 | resp = self.client.get('/comments/?id=1') 45 | expected = [ 46 | { 47 | 'id': 1, 48 | 'email': 'example@ya.ru' 49 | } 50 | ] 51 | self.assertEqual(resp.data, expected) 52 | 53 | def test_destroy_instance(self): 54 | resp = self.client.delete('/comments/1/') 55 | self.assertEqual(resp.status_code, 204) 56 | self.assertFalse(1 in Comment.objects.values_list('pk', flat=True)) 57 | 58 | def test_bulk_destroy__without_protection_header(self): 59 | resp = 
self.client.delete('/comments/') 60 | self.assertEqual(resp.status_code, 400) 61 | expected_message = { 62 | 'detail': 'Header \'{0}\' should be provided for bulk operation.'.format( 63 | extensions_api_settings.DEFAULT_BULK_OPERATION_HEADER_NAME 64 | ) 65 | } 66 | self.assertEqual(resp.data, expected_message) 67 | 68 | def test_bulk_destroy__with_protection_header(self): 69 | resp = self.client.delete('/comments/', **self.protection_headers) 70 | self.assertEqual(resp.status_code, 204) 71 | self.assertEqual(Comment.objects.count(), 0) 72 | 73 | @override_extensions_api_settings(DEFAULT_BULK_OPERATION_HEADER_NAME=None) 74 | def test_bulk_destroy__without_protection_header__and_with_turned_off_protection_header(self): 75 | resp = self.client.delete('/comments/') 76 | self.assertEqual(resp.status_code, 204) 77 | self.assertEqual(Comment.objects.count(), 0) 78 | 79 | def test_bulk_destroy__should_destroy_filtered_queryset(self): 80 | resp = self.client.delete('/comments/?id=1', **self.protection_headers) 81 | self.assertEqual(resp.status_code, 204) 82 | self.assertEqual(Comment.objects.count(), 1) 83 | self.assertEqual(Comment.objects.all()[0], self.comments[1]) 84 | 85 | def test_bulk_destroy__should_not_destroy_if_client_has_no_permissions(self): 86 | resp = self.client.delete('/comments-with-permission/', **self.protection_headers) 87 | self.assertEqual(resp.status_code, 404) 88 | self.assertEqual(Comment.objects.count(), 2) 89 | -------------------------------------------------------------------------------- /tests_app/tests/functional/mixins/list_destroy_model_mixin/urls.py: -------------------------------------------------------------------------------- 1 | from rest_framework import routers 2 | 3 | from .views import CommentViewSet, CommentViewSetWithPermissions 4 | 5 | 6 | viewset_router = routers.DefaultRouter() 7 | viewset_router.register('comments', CommentViewSet, basename='alt1') 8 | viewset_router.register('comments-with-permissions', 
CommentViewSetWithPermissions, basename='alt2') 9 | urlpatterns = viewset_router.urls 10 | -------------------------------------------------------------------------------- /tests_app/tests/functional/mixins/list_destroy_model_mixin/views.py: -------------------------------------------------------------------------------- 1 | import django_filters 2 | from rest_framework import viewsets, serializers 3 | from rest_framework import filters 4 | from rest_framework.permissions import DjangoModelPermissions 5 | from rest_framework_extensions.bulk_operations.mixins import ListDestroyModelMixin 6 | 7 | from .models import CommentForListDestroyModelMixin as Comment 8 | 9 | 10 | class CommentFilter(django_filters.FilterSet): 11 | class Meta: 12 | model = Comment 13 | fields = [ 14 | 'id' 15 | ] 16 | 17 | 18 | class CommentSerializer(serializers.ModelSerializer): 19 | class Meta: 20 | model = Comment 21 | fields = '__all__' 22 | 23 | 24 | class CommentViewSet(ListDestroyModelMixin, viewsets.ModelViewSet): 25 | queryset = Comment.objects.all() 26 | serializer_class = CommentSerializer 27 | filter_backends = (django_filters.rest_framework.DjangoFilterBackend,) 28 | filterset_class = CommentFilter 29 | 30 | 31 | class CommentViewSetWithPermissions(CommentViewSet): 32 | permission_classes = (DjangoModelPermissions,) 33 | -------------------------------------------------------------------------------- /tests_app/tests/functional/mixins/list_update_model_mixin/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests_app/tests/functional/mixins/list_update_model_mixin/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | 3 | 4 | class CommentForListUpdateModelMixin(models.Model): 5 | email = models.EmailField() 6 | 7 | 8 | class UserForListUpdateModelMixin(models.Model): 9 | 
email = models.EmailField() 10 | name = models.CharField(max_length=10) 11 | age = models.IntegerField() 12 | last_name = models.CharField(max_length=10) 13 | password = models.CharField(max_length=100) 14 | -------------------------------------------------------------------------------- /tests_app/tests/functional/mixins/list_update_model_mixin/serializers.py: -------------------------------------------------------------------------------- 1 | from rest_framework import serializers 2 | from .models import ( 3 | UserForListUpdateModelMixin as User, 4 | CommentForListUpdateModelMixin as Comment, 5 | ) 6 | 7 | 8 | class UserSerializer(serializers.ModelSerializer): 9 | surname = serializers.CharField(source='last_name') 10 | 11 | class Meta: 12 | model = User 13 | extra_kwargs = {'password': {'write_only': True}} 14 | read_only_fields = ('name',) 15 | fields = [ 16 | 'id', 17 | 'age', 18 | 'name', 19 | 'surname', 20 | 'password' 21 | ] 22 | 23 | 24 | class CommentSerializer(serializers.ModelSerializer): 25 | class Meta: 26 | model = Comment 27 | fields = '__all__' 28 | -------------------------------------------------------------------------------- /tests_app/tests/functional/mixins/list_update_model_mixin/tests.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import unittest 4 | 5 | import django 6 | from django.test import override_settings 7 | 8 | from rest_framework.test import APITestCase 9 | from rest_framework_extensions.settings import extensions_api_settings 10 | from rest_framework_extensions import utils 11 | 12 | from .models import ( 13 | CommentForListUpdateModelMixin as Comment, 14 | UserForListUpdateModelMixin as User 15 | ) 16 | from tests_app.testutils import override_extensions_api_settings 17 | 18 | 19 | @override_settings(ROOT_URLCONF='tests_app.tests.functional.mixins.list_update_model_mixin.urls') 20 | class ListUpdateModelMixinTest(APITestCase): 21 | 22 | def setUp(self): 23 | 
self.comments = [ 24 | Comment.objects.create( 25 | id=1, 26 | email='example@ya.ru' 27 | ), 28 | Comment.objects.create( 29 | id=2, 30 | email='example@gmail.com' 31 | ) 32 | ] 33 | self.protection_headers = { 34 | utils.prepare_header_name(extensions_api_settings.DEFAULT_BULK_OPERATION_HEADER_NAME): 'true' 35 | } 36 | self.patch_data = { 37 | 'email': 'example@yandex.ru' 38 | } 39 | 40 | def test_simple_response(self): 41 | resp = self.client.get('/comments/') 42 | expected = [ 43 | { 44 | 'id': 1, 45 | 'email': 'example@ya.ru' 46 | }, 47 | { 48 | 'id': 2, 49 | 'email': 'example@gmail.com' 50 | } 51 | ] 52 | self.assertEqual(resp.data, expected) 53 | 54 | def test_filter_works(self): 55 | resp = self.client.get('/comments/?id=1') 56 | expected = [ 57 | { 58 | 'id': 1, 59 | 'email': 'example@ya.ru' 60 | } 61 | ] 62 | self.assertEqual(resp.data, expected) 63 | 64 | def test_update_instance(self): 65 | data = { 66 | 'id': 1, 67 | 'email': 'example@yandex.ru' 68 | } 69 | resp = self.client.put('/comments/1/', data=json.dumps(data), content_type='application/json') 70 | self.assertEqual(resp.status_code, 200) 71 | self.assertEqual(Comment.objects.get(pk=1).email, 'example@yandex.ru') 72 | 73 | def test_partial_update_instance(self): 74 | data = { 75 | 'id': 1, 76 | 'email': 'example@yandex.ru' 77 | } 78 | resp = self.client.patch('/comments/1/', data=json.dumps(data), content_type='application/json') 79 | self.assertEqual(resp.status_code, 200) 80 | self.assertEqual(Comment.objects.get(pk=1).email, 'example@yandex.ru') 81 | 82 | def test_bulk_partial_update__without_protection_header(self): 83 | resp = self.client.patch('/comments/', data=json.dumps(self.patch_data), content_type='application/json') 84 | self.assertEqual(resp.status_code, 400) 85 | expected_message = { 86 | 'detail': 'Header \'{0}\' should be provided for bulk operation.'.format( 87 | extensions_api_settings.DEFAULT_BULK_OPERATION_HEADER_NAME 88 | ) 89 | } 90 | self.assertEqual(resp.data, 
expected_message) 91 | 92 | def test_bulk_partial_update__with_protection_header(self): 93 | resp = self.client.patch('/comments/', data=json.dumps(self.patch_data), content_type='application/json', **self.protection_headers) 94 | self.assertEqual(resp.status_code, 204) 95 | for comment in Comment.objects.all(): 96 | self.assertEqual(comment.email, self.patch_data['email']) 97 | 98 | @override_extensions_api_settings(DEFAULT_BULK_OPERATION_HEADER_NAME=None) 99 | def test_bulk_partial_update__without_protection_header__and_with_turned_off_protection_header(self): 100 | resp = self.client.patch('/comments/', data=json.dumps(self.patch_data), content_type='application/json', **self.protection_headers) 101 | self.assertEqual(resp.status_code, 204) 102 | for comment in Comment.objects.all(): 103 | self.assertEqual(comment.email, self.patch_data['email']) 104 | 105 | def test_bulk_partial_update__should_update_filtered_queryset(self): 106 | resp = self.client.patch('/comments/?id=1', data=json.dumps(self.patch_data), content_type='application/json', **self.protection_headers) 107 | self.assertEqual(resp.status_code, 204) 108 | self.assertEqual(Comment.objects.get(pk=1).email, self.patch_data['email']) 109 | self.assertEqual(Comment.objects.get(pk=2).email, self.comments[1].email) 110 | 111 | def test_bulk_partial_update__should_not_update_if_client_has_no_permissions(self): 112 | resp = self.client.patch('/comments-with-permission/', data=json.dumps(self.patch_data), content_type='application/json', **self.protection_headers) 113 | self.assertEqual(resp.status_code, 404) 114 | for i, comment in enumerate(Comment.objects.all()): 115 | self.assertEqual(comment.email, self.comments[i].email) 116 | 117 | 118 | @override_settings(ROOT_URLCONF='tests_app.tests.functional.mixins.list_update_model_mixin.urls') 119 | class ListUpdateModelMixinTestBehaviour__serializer_fields(APITestCase): 120 | 121 | def setUp(self): 122 | self.user = User.objects.create( 123 | id=1, 124 | 
name='Gennady', 125 | age=24, 126 | last_name='Chibisov', 127 | email='example@ya.ru', 128 | password='somepassword' 129 | ) 130 | self.headers = { 131 | utils.prepare_header_name(extensions_api_settings.DEFAULT_BULK_OPERATION_HEADER_NAME): 'true' 132 | } 133 | 134 | def get_fresh_user(self): 135 | return User.objects.get(pk=self.user.pk) 136 | 137 | def test_simple_response(self): 138 | resp = self.client.get('/users/') 139 | expected = [ 140 | { 141 | 'id': 1, 142 | 'age': 24, 143 | 'name': 'Gennady', 144 | 'surname': 'Chibisov' 145 | } 146 | ] 147 | self.assertEqual(resp.data, expected) 148 | 149 | def test_invalid_for_db_data(self): 150 | data = { 151 | 'age': 'Not integer value' 152 | } 153 | try: 154 | resp = self.client.patch('/users/', data=json.dumps(data), content_type='application/json', **self.headers) 155 | except ValueError: 156 | self.fail('Errors with invalid for DB data should be caught') 157 | else: 158 | self.assertEqual(resp.status_code, 400) 159 | if django.VERSION < (3, 0, 0): 160 | expected_message = { 161 | 'detail': "invalid literal for int() with base 10: 'Not integer value'" 162 | } 163 | else: 164 | expected_message = { 165 | 'detail': "Field 'age' expected a number but got 'Not integer value'." 
166 | } 167 | self.assertEqual(resp.data, expected_message) 168 | 169 | def test_should_use_source_if_it_set_in_serializer(self): 170 | data = { 171 | 'surname': 'Ivanov' 172 | } 173 | resp = self.client.patch('/users/', data=json.dumps(data), content_type='application/json', **self.headers) 174 | self.assertEqual(resp.status_code, 204) 175 | self.assertEqual(self.get_fresh_user().last_name, data['surname']) 176 | 177 | def test_should_update_write_only_fields(self): 178 | data = { 179 | 'password': '123' 180 | } 181 | resp = self.client.patch('/users/', data=json.dumps(data), content_type='application/json', **self.headers) 182 | self.assertEqual(resp.status_code, 204) 183 | self.assertEqual(self.get_fresh_user().password, data['password']) 184 | 185 | def test_should_not_update_read_only_fields(self): 186 | data = { 187 | 'name': 'Ivan' 188 | } 189 | resp = self.client.patch('/users/', data=json.dumps(data), content_type='application/json', **self.headers) 190 | self.assertEqual(resp.status_code, 204) 191 | self.assertEqual(self.get_fresh_user().name, self.user.name) 192 | 193 | def test_should_not_update_hidden_fields(self): 194 | data = { 195 | 'email': 'example@gmail.com' 196 | } 197 | resp = self.client.patch('/users/', data=json.dumps(data), content_type='application/json', **self.headers) 198 | self.assertEqual(resp.status_code, 204) 199 | self.assertEqual(self.get_fresh_user().email, self.user.email) 200 | -------------------------------------------------------------------------------- /tests_app/tests/functional/mixins/list_update_model_mixin/urls.py: -------------------------------------------------------------------------------- 1 | from rest_framework import routers 2 | 3 | from .views import CommentViewSet, CommentViewSetWithPermissions, UserViewSet 4 | 5 | 6 | viewset_router = routers.DefaultRouter() 7 | viewset_router.register('comments', CommentViewSet, basename='alt1') 8 | viewset_router.register('comments-with-permissions', 
CommentViewSetWithPermissions, basename='alt2') 9 | viewset_router.register('users', UserViewSet, basename='alt3') 10 | urlpatterns = viewset_router.urls 11 | -------------------------------------------------------------------------------- /tests_app/tests/functional/mixins/list_update_model_mixin/views.py: -------------------------------------------------------------------------------- 1 | import django_filters 2 | from rest_framework import viewsets 3 | from rest_framework import filters 4 | from rest_framework.permissions import DjangoModelPermissions 5 | from rest_framework_extensions.mixins import ListUpdateModelMixin 6 | 7 | from .models import ( 8 | CommentForListUpdateModelMixin as Comment, 9 | UserForListUpdateModelMixin as User 10 | ) 11 | from .serializers import UserSerializer, CommentSerializer 12 | 13 | 14 | class CommentFilter(django_filters.FilterSet): 15 | class Meta: 16 | model = Comment 17 | fields = [ 18 | 'id' 19 | ] 20 | 21 | 22 | class CommentViewSet(ListUpdateModelMixin, viewsets.ModelViewSet): 23 | queryset = Comment.objects.all() 24 | serializer_class = CommentSerializer 25 | filter_backends = (django_filters.rest_framework.DjangoFilterBackend,) 26 | filterset_class = CommentFilter 27 | 28 | 29 | class CommentViewSetWithPermissions(CommentViewSet): 30 | permission_classes = (DjangoModelPermissions,) 31 | 32 | 33 | class UserViewSet(ListUpdateModelMixin, viewsets.ModelViewSet): 34 | queryset = User.objects.all() 35 | serializer_class = UserSerializer -------------------------------------------------------------------------------- /tests_app/tests/functional/mixins/paginate_by_max_mixin/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests_app/tests/functional/mixins/paginate_by_max_mixin/models.py: -------------------------------------------------------------------------------- 1 | from django.db import 
models 2 | 3 | 4 | class CommentForPaginateByMaxMixin(models.Model): 5 | email = models.EmailField() 6 | content = models.CharField(max_length=200) 7 | created = models.DateTimeField(auto_now_add=True) 8 | 9 | class Meta: 10 | ordering = ['id'] 11 | -------------------------------------------------------------------------------- /tests_app/tests/functional/mixins/paginate_by_max_mixin/pagination.py: -------------------------------------------------------------------------------- 1 | from rest_framework_extensions.mixins import PaginateByMaxMixin 2 | from rest_framework.pagination import PageNumberPagination 3 | 4 | 5 | class WithMaxPagination(PaginateByMaxMixin, PageNumberPagination): 6 | page_size = 10 7 | page_size_query_param = 'limit' 8 | max_page_size = 20 9 | 10 | 11 | class FlexiblePagination(PageNumberPagination): 12 | page_size = 10 13 | page_size_query_param = 'page_size' 14 | 15 | 16 | class FixedPagination(PageNumberPagination): 17 | page_size = 10 18 | -------------------------------------------------------------------------------- /tests_app/tests/functional/mixins/paginate_by_max_mixin/serializers.py: -------------------------------------------------------------------------------- 1 | from rest_framework import serializers 2 | 3 | from .models import CommentForPaginateByMaxMixin 4 | 5 | 6 | class CommentSerializer(serializers.ModelSerializer): 7 | class Meta: 8 | model = CommentForPaginateByMaxMixin 9 | fields = ( 10 | 'id', 11 | 'email', 12 | ) 13 | -------------------------------------------------------------------------------- /tests_app/tests/functional/mixins/paginate_by_max_mixin/tests.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | from django.test import TestCase, override_settings 4 | 5 | from .models import CommentForPaginateByMaxMixin 6 | 7 | 8 | @override_settings(ROOT_URLCONF='tests_app.tests.functional.mixins.paginate_by_max_mixin.urls') 9 | class 
PaginateByMaxMixinTest(TestCase): 10 | 11 | def setUp(self): 12 | for i in range(30): 13 | CommentForPaginateByMaxMixin.objects.create( 14 | email='example@ya.ru', 15 | content='Hello world', 16 | created=datetime.datetime.now() 17 | ) 18 | 19 | def test_default_page_size(self): 20 | resp = self.client.get('/comments/') 21 | self.assertEqual(len(resp.data['results']), 10) 22 | 23 | def test_custom_page_size__less_then_maximum(self): 24 | resp = self.client.get('/comments/?limit=15') 25 | self.assertEqual(len(resp.data['results']), 15) 26 | 27 | def test_custom_page_size__more_then_maximum(self): 28 | resp = self.client.get('/comments/?limit=25') 29 | self.assertEqual(len(resp.data['results']), 20) 30 | 31 | def test_custom_page_size_with_max_value(self): 32 | resp = self.client.get('/comments/?limit=max') 33 | self.assertEqual(len(resp.data['results']), 20) 34 | 35 | def test_custom_page_size_with_max_value__for_view_without__paginate_by_param__attribute(self): 36 | resp = self.client.get( 37 | '/comments-without-paginate-by-param-attribute/?page_size=max') 38 | self.assertEqual(len(resp.data['results']), 10) 39 | 40 | def test_custom_page_size_with_max_value__for_view_without__max_paginate_by__attribute(self): 41 | resp = self.client.get( 42 | '/comments-without-max-paginate-by-attribute/?page_size=max') 43 | self.assertEqual(len(resp.data['results']), 10) 44 | 45 | 46 | @override_settings(ROOT_URLCONF='tests_app.tests.functional.mixins.paginate_by_max_mixin.urls') 47 | class PaginateByMaxMixinTestBehavior__should_not_affect_view_if_DRF_does_not_supports__max_paginate_by(TestCase): 48 | 49 | def setUp(self): 50 | for i in range(30): 51 | CommentForPaginateByMaxMixin.objects.create( 52 | email='example@ya.ru', 53 | content='Hello world', 54 | created=datetime.datetime.now() 55 | ) 56 | 57 | def test_default_page_size(self): 58 | resp = self.client.get('/comments/') 59 | self.assertEqual(len(resp.data['results']), 10) 60 | 61 | def 
test_custom_page_size__less_then_maximum(self): 62 | resp = self.client.get('/comments/?limit=15') 63 | self.assertEqual(len(resp.data['results']), 15) 64 | 65 | def test_custom_page_size__more_then_maximum(self): 66 | resp = self.client.get('/comments/?limit=25') 67 | self.assertEqual(len(resp.data['results']), 20) 68 | 69 | def test_custom_page_size_with_max_value(self): 70 | resp = self.client.get('/comments/?limit=max') 71 | self.assertEqual(len(resp.data['results']), 20) 72 | -------------------------------------------------------------------------------- /tests_app/tests/functional/mixins/paginate_by_max_mixin/urls.py: -------------------------------------------------------------------------------- 1 | from rest_framework import routers 2 | 3 | from .views import ( 4 | CommentViewSet, 5 | CommentWithoutPaginateByParamViewSet, 6 | CommentWithoutMaxPaginateByAttributeViewSet, 7 | ) 8 | 9 | 10 | viewset_router = routers.DefaultRouter() 11 | viewset_router.register('comments', CommentViewSet, basename='1') 12 | viewset_router.register('comments-without-paginate-by-param-attribute', CommentWithoutPaginateByParamViewSet, basename='2') 13 | viewset_router.register('comments-without-max-paginate-by-attribute', CommentWithoutMaxPaginateByAttributeViewSet, basename='3') 14 | urlpatterns = viewset_router.urls 15 | -------------------------------------------------------------------------------- /tests_app/tests/functional/mixins/paginate_by_max_mixin/views.py: -------------------------------------------------------------------------------- 1 | from rest_framework import viewsets 2 | 3 | from .pagination import WithMaxPagination, FixedPagination, FlexiblePagination 4 | from .models import CommentForPaginateByMaxMixin 5 | from .serializers import CommentSerializer 6 | 7 | 8 | class CommentViewSet(viewsets.ReadOnlyModelViewSet): 9 | serializer_class = CommentSerializer 10 | pagination_class = WithMaxPagination 11 | queryset = 
CommentForPaginateByMaxMixin.objects.all().order_by('id') 12 | 13 | 14 | class CommentWithoutPaginateByParamViewSet(viewsets.ReadOnlyModelViewSet): 15 | serializer_class = CommentSerializer 16 | pagination_class = FixedPagination 17 | queryset = CommentForPaginateByMaxMixin.objects.all() 18 | 19 | 20 | class CommentWithoutMaxPaginateByAttributeViewSet(viewsets.ReadOnlyModelViewSet): 21 | pagination_class = FlexiblePagination 22 | serializer_class = CommentSerializer 23 | queryset = CommentForPaginateByMaxMixin.objects.all() 24 | -------------------------------------------------------------------------------- /tests_app/tests/functional/models.py: -------------------------------------------------------------------------------- 1 | # from .concurrency.conditional_request.models import * 2 | from .key_constructor.bits.models import * 3 | from .mixins.detail_serializer_mixin.models import * 4 | from .mixins.list_destroy_model_mixin.models import * 5 | from .mixins.list_update_model_mixin.models import * 6 | from .mixins.paginate_by_max_mixin.models import * 7 | from .permissions.extended_django_object_permissions.models import * 8 | from .routers.models import * 9 | from .routers.extended_default_router.models import * 10 | from .routers.nested_router_mixin.models import * 11 | from ._concurrency.conditional_request.models import * 12 | -------------------------------------------------------------------------------- /tests_app/tests/functional/permissions/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests_app/tests/functional/permissions/extended_django_object_permissions/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- 
/tests_app/tests/functional/permissions/extended_django_object_permissions/models.py: -------------------------------------------------------------------------------- 1 | import django 2 | from django.db import models 3 | 4 | 5 | class PermissionsComment(models.Model): 6 | text = models.CharField(max_length=100) 7 | 8 | class Meta: 9 | if django.VERSION < (2, 1): 10 | permissions = ( 11 | ('view_permissionscomment', 'Can view comment'), 12 | # add, change, delete built in to django 13 | ) 14 | -------------------------------------------------------------------------------- /tests_app/tests/functional/permissions/extended_django_object_permissions/tests.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from django.contrib.auth.models import User, Group, Permission 4 | from django.contrib.contenttypes.models import ContentType 5 | from django.test import override_settings 6 | 7 | from rest_framework import status 8 | from rest_framework.test import APITestCase 9 | 10 | from tests_app.testutils import basic_auth_header 11 | from .models import PermissionsComment 12 | 13 | 14 | class ExtendedDjangoObjectPermissionTestMixin: 15 | def setUp(self): 16 | from guardian.shortcuts import assign_perm 17 | 18 | # create users 19 | create = User.objects.create_user 20 | users = { 21 | 'fullaccess': create('fullaccess', 'fullaccess@example.com', 'password'), 22 | 'readonly': create('readonly', 'readonly@example.com', 'password'), 23 | 'writeonly': create('writeonly', 'writeonly@example.com', 'password'), 24 | 'deleteonly': create('deleteonly', 'deleteonly@example.com', 'password'), 25 | } 26 | 27 | # create custom permission 28 | Permission.objects.get_or_create( 29 | codename='view_permissionscomment', 30 | content_type=ContentType.objects.get_for_model(PermissionsComment), 31 | defaults={'name': 'Can view comment'}, 32 | ) 33 | 34 | # give everyone model level permissions, as we are not testing those 35 | everyone = 
Group.objects.create(name='everyone') 36 | model_name = PermissionsComment._meta.model_name 37 | app_label = PermissionsComment._meta.app_label 38 | f = '{0}_{1}'.format 39 | perms = { 40 | 'view': f('view', model_name), 41 | 'change': f('change', model_name), 42 | 'delete': f('delete', model_name) 43 | } 44 | for perm in perms.values(): 45 | perm = '{0}.{1}'.format(app_label, perm) 46 | assign_perm(perm, everyone) 47 | everyone.user_set.add(*users.values()) 48 | 49 | # appropriate object level permissions 50 | readers = Group.objects.create(name='readers') 51 | writers = Group.objects.create(name='writers') 52 | deleters = Group.objects.create(name='deleters') 53 | 54 | model = PermissionsComment.objects.create(text='foo', id=1) 55 | 56 | assign_perm(perms['view'], readers, model) 57 | assign_perm(perms['change'], writers, model) 58 | assign_perm(perms['delete'], deleters, model) 59 | 60 | readers.user_set.add(users['fullaccess'], users['readonly']) 61 | writers.user_set.add(users['fullaccess'], users['writeonly']) 62 | deleters.user_set.add(users['fullaccess'], users['deleteonly']) 63 | 64 | self.credentials = {} 65 | for user in users.values(): 66 | self.credentials[user.username] = basic_auth_header(user.username, 'password') 67 | 68 | 69 | @override_settings(ROOT_URLCONF='tests_app.tests.functional.permissions.extended_django_object_permissions.urls') 70 | class ExtendedDjangoObjectPermissionsTest_should_inherit_standard(ExtendedDjangoObjectPermissionTestMixin, 71 | APITestCase): 72 | 73 | # Delete 74 | def test_can_delete_permissions(self): 75 | response = self.client.delete( 76 | '/comments/1/', 77 | **{'HTTP_AUTHORIZATION': self.credentials['deleteonly']}) 78 | self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) 79 | 80 | def test_cannot_delete_permissions(self): 81 | response = self.client.delete( 82 | '/comments/1/', 83 | **{'HTTP_AUTHORIZATION': self.credentials['readonly']}) 84 | self.assertEqual(response.status_code, 
status.HTTP_403_FORBIDDEN) 85 | 86 | # Update 87 | def test_can_update_permissions(self): 88 | response = self.client.patch( 89 | '/comments/1/', 90 | content_type='application/json', 91 | data=json.dumps({'text': 'foobar'}), 92 | **{ 93 | 'HTTP_AUTHORIZATION': self.credentials['writeonly'] 94 | } 95 | ) 96 | self.assertEqual(response.status_code, status.HTTP_200_OK) 97 | self.assertEqual(response.data.get('text'), 'foobar') 98 | 99 | def test_cannot_update_permissions(self): 100 | response = self.client.patch( 101 | '/comments/1/', 102 | content_type='application/json', 103 | data=json.dumps({'text': 'foobar'}), 104 | **{ 105 | 'HTTP_AUTHORIZATION': self.credentials['deleteonly'] 106 | } 107 | ) 108 | self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) 109 | 110 | def test_cannot_update_permissions_non_existing(self): 111 | response = self.client.patch( 112 | '/comments/999/', 113 | content_type='application/json', 114 | data=json.dumps({'text': 'foobar'}), 115 | **{ 116 | 'HTTP_AUTHORIZATION': self.credentials['deleteonly'] 117 | } 118 | ) 119 | self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) 120 | 121 | # Read 122 | def test_can_read_permissions(self): 123 | response = self.client.get( 124 | '/comments/1/', 125 | **{'HTTP_AUTHORIZATION': self.credentials['readonly']}) 126 | self.assertEqual(response.status_code, status.HTTP_200_OK) 127 | 128 | def test_cannot_read_permissions(self): 129 | response = self.client.get( 130 | '/comments/1/', 131 | **{'HTTP_AUTHORIZATION': self.credentials['writeonly']}) 132 | self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) 133 | 134 | # Read list 135 | def test_can_read_list_permissions(self): 136 | response = self.client.get( 137 | '/comments-permission-filter-backend/', 138 | **{'HTTP_AUTHORIZATION': self.credentials['readonly']} 139 | ) 140 | self.assertEqual(response.status_code, status.HTTP_200_OK) 141 | self.assertEqual(response.data[0].get('id'), 1) 142 | 143 | def 
test_cannot_read_list_permissions(self): 144 | response = self.client.get( 145 | '/comments-permission-filter-backend/', 146 | **{'HTTP_AUTHORIZATION': self.credentials['writeonly']} 147 | ) 148 | self.assertEqual(response.status_code, status.HTTP_200_OK) 149 | self.assertListEqual(response.data, []) 150 | 151 | 152 | @override_settings(ROOT_URLCONF='tests_app.tests.functional.permissions.extended_django_object_permissions.urls') 153 | class ExtendedDjangoObjectPermissionsTest_without_hiding_forbidden_objects(ExtendedDjangoObjectPermissionTestMixin, 154 | APITestCase): 155 | 156 | # Delete 157 | def test_can_delete_permissions(self): 158 | response = self.client.delete( 159 | '/comments-without-hiding-forbidden-objects/1/', 160 | **{'HTTP_AUTHORIZATION': self.credentials['deleteonly']} 161 | ) 162 | self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) 163 | 164 | def test_cannot_delete_permissions(self): 165 | response = self.client.delete( 166 | '/comments-without-hiding-forbidden-objects/1/', 167 | **{'HTTP_AUTHORIZATION': self.credentials['readonly']} 168 | ) 169 | self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) 170 | 171 | # Update 172 | def test_can_update_permissions(self): 173 | response = self.client.patch( 174 | '/comments-without-hiding-forbidden-objects/1/', 175 | content_type='application/json', 176 | data=json.dumps({'text': 'foobar'}), 177 | **{ 178 | 'HTTP_AUTHORIZATION': self.credentials['writeonly'] 179 | } 180 | ) 181 | self.assertEqual(response.status_code, status.HTTP_200_OK) 182 | self.assertEqual(response.data.get('text'), 'foobar') 183 | 184 | def test_cannot_update_permissions(self): 185 | response = self.client.patch( 186 | '/comments-without-hiding-forbidden-objects/1/', 187 | content_type='application/json', 188 | data=json.dumps({'text': 'foobar'}), 189 | **{ 190 | 'HTTP_AUTHORIZATION': self.credentials['deleteonly'] 191 | } 192 | ) 193 | self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) 
194 | 195 | def test_cannot_update_permissions_non_existing(self): 196 | response = self.client.patch( 197 | '/comments-without-hiding-forbidden-objects/999/', 198 | content_type='application/json', 199 | data=json.dumps({'text': 'foobar'}), 200 | **{ 201 | 'HTTP_AUTHORIZATION': self.credentials['deleteonly'] 202 | } 203 | ) 204 | self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) 205 | 206 | # Read 207 | def test_can_read_permissions(self): 208 | response = self.client.get( 209 | '/comments-without-hiding-forbidden-objects/1/', 210 | **{'HTTP_AUTHORIZATION': self.credentials['readonly']} 211 | ) 212 | self.assertEqual(response.status_code, status.HTTP_200_OK) 213 | 214 | def test_cannot_read_permissions(self): 215 | response = self.client.get( 216 | '/comments-without-hiding-forbidden-objects/1/', 217 | **{'HTTP_AUTHORIZATION': self.credentials['writeonly']} 218 | ) 219 | self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) 220 | 221 | # Read list 222 | def test_can_read_list_permissions(self): 223 | response = self.client.get( 224 | '/comments-without-hiding-forbidden-objects-permission-filter-backend/', 225 | **{'HTTP_AUTHORIZATION': self.credentials['readonly']} 226 | ) 227 | self.assertEqual(response.status_code, status.HTTP_200_OK) 228 | self.assertEqual(response.data[0].get('id'), 1) 229 | 230 | def test_cannot_read_list_permissions(self): 231 | response = self.client.get( 232 | '/comments-without-hiding-forbidden-objects-permission-filter-backend/', 233 | **{'HTTP_AUTHORIZATION': self.credentials['writeonly']} 234 | ) 235 | self.assertEqual(response.status_code, status.HTTP_200_OK) 236 | self.assertListEqual(response.data, []) 237 | -------------------------------------------------------------------------------- /tests_app/tests/functional/permissions/extended_django_object_permissions/urls.py: -------------------------------------------------------------------------------- 1 | from rest_framework import routers 2 | 3 | from 
.views import ( 4 | CommentViewSet, 5 | CommentViewSetPermissionFilterBackend, 6 | CommentViewSetWithoutHidingForbiddenObjects, 7 | CommentViewSetWithoutHidingForbiddenObjectsPermissionFilterBackend 8 | ) 9 | 10 | 11 | viewset_router = routers.DefaultRouter() 12 | viewset_router.register('comments', CommentViewSet, basename='alt1') 13 | viewset_router.register('comments-permission-filter-backend', CommentViewSetPermissionFilterBackend, basename='alt2') 14 | viewset_router.register('comments-without-hiding-forbidden-objects', CommentViewSetWithoutHidingForbiddenObjects, basename='alt3') 15 | viewset_router.register( 16 | 'comments-without-hiding-forbidden-objects-permission-filter-backend', 17 | CommentViewSetWithoutHidingForbiddenObjectsPermissionFilterBackend, 18 | basename='alt4' 19 | ) 20 | urlpatterns = viewset_router.urls 21 | -------------------------------------------------------------------------------- /tests_app/tests/functional/permissions/extended_django_object_permissions/views.py: -------------------------------------------------------------------------------- 1 | from rest_framework import viewsets, serializers 2 | from rest_framework import authentication 3 | 4 | try: 5 | # djangorestframework >= 3.9 6 | from rest_framework_guardian.filters import DjangoObjectPermissionsFilter 7 | except ImportError: 8 | from rest_framework.filters import DjangoObjectPermissionsFilter 9 | 10 | try: 11 | from rest_framework_extensions.permissions import ExtendedDjangoObjectPermissions 12 | except ImportError: 13 | class ExtendedDjangoObjectPermissions: 14 | pass 15 | 16 | from .models import PermissionsComment 17 | 18 | 19 | class CommentObjectPermissions(ExtendedDjangoObjectPermissions): 20 | perms_map = { 21 | 'GET': ['%(app_label)s.view_%(model_name)s'], 22 | 'OPTIONS': ['%(app_label)s.view_%(model_name)s'], 23 | 'HEAD': ['%(app_label)s.view_%(model_name)s'], 24 | 'POST': ['%(app_label)s.add_%(model_name)s'], 25 | 'PUT': ['%(app_label)s.change_%(model_name)s'], 26 
| 'PATCH': ['%(app_label)s.change_%(model_name)s'], 27 | 'DELETE': ['%(app_label)s.delete_%(model_name)s'], 28 | } 29 | 30 | 31 | class PermissionsCommentSerializer(serializers.ModelSerializer): 32 | class Meta: 33 | model = PermissionsComment 34 | fields = '__all__' 35 | 36 | 37 | class CommentObjectPermissionsWithoutHidingForbiddenObjects(CommentObjectPermissions): 38 | hide_forbidden_for_read_objects = False 39 | 40 | 41 | class CommentViewSet(viewsets.ModelViewSet): 42 | queryset = PermissionsComment.objects.all() 43 | serializer_class = PermissionsCommentSerializer 44 | authentication_classes = [authentication.BasicAuthentication] 45 | permission_classes = (CommentObjectPermissions,) 46 | 47 | 48 | class CommentViewSetPermissionFilterBackend(CommentViewSet): 49 | filter_backends = (DjangoObjectPermissionsFilter,) 50 | 51 | 52 | class CommentViewSetWithoutHidingForbiddenObjects(CommentViewSet): 53 | permission_classes = (CommentObjectPermissionsWithoutHidingForbiddenObjects,) 54 | 55 | 56 | class CommentViewSetWithoutHidingForbiddenObjectsPermissionFilterBackend(CommentViewSetWithoutHidingForbiddenObjects): 57 | filter_backends = (DjangoObjectPermissionsFilter,) -------------------------------------------------------------------------------- /tests_app/tests/functional/routers/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests_app/tests/functional/routers/extended_default_router/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests_app/tests/functional/routers/extended_default_router/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | 3 | 4 | class DefaultRouterUserModel(models.Model): 5 | name = 
models.CharField(max_length=10) 6 | groups = models.ManyToManyField('DefaultRouterGroupModel', related_name='user_groups') 7 | 8 | 9 | class DefaultRouterGroupModel(models.Model): 10 | name = models.CharField(max_length=10) 11 | permissions = models.ManyToManyField('DefaultRouterPermissionModel') 12 | 13 | 14 | class DefaultRouterPermissionModel(models.Model): 15 | name = models.CharField(max_length=10) 16 | -------------------------------------------------------------------------------- /tests_app/tests/functional/routers/extended_default_router/tests.py: -------------------------------------------------------------------------------- 1 | from django.test import override_settings 2 | from django.urls import NoReverseMatch 3 | 4 | from rest_framework.test import APITestCase 5 | 6 | 7 | @override_settings(ROOT_URLCONF='tests_app.tests.functional.routers.extended_default_router.urls') 8 | class ExtendedDefaultRouterTestBehaviour(APITestCase): 9 | 10 | def test_index_page(self): 11 | try: 12 | response = self.client.get('/') 13 | except NoReverseMatch: 14 | issue = 'https://github.com/chibisov/drf-extensions/issues/14' 15 | self.fail('DefaultRouter tries to reverse nested routes and breaks with error. NoReverseMatch should be ' 16 | 'handled for nested routes. They must be excluded from index page. 
' + issue) 17 | self.assertEqual(response.status_code, 200) 18 | 19 | expected = { 20 | 'users': 'http://testserver/users/', 21 | 'groups': 'http://testserver/groups/', 22 | 'permissions': 'http://testserver/permissions/', 23 | } 24 | self.assertEqual(response.data, expected) 25 | -------------------------------------------------------------------------------- /tests_app/tests/functional/routers/extended_default_router/urls.py: -------------------------------------------------------------------------------- 1 | from rest_framework_extensions.routers import ExtendedDefaultRouter 2 | 3 | from .views import ( 4 | UserViewSet, 5 | GroupViewSet, 6 | PermissionViewSet, 7 | ) 8 | 9 | 10 | router = ExtendedDefaultRouter() 11 | # nested routes 12 | ( 13 | router.register(r'users', UserViewSet) 14 | .register(r'groups', GroupViewSet, 'users-group', parents_query_lookups=['user_groups']) 15 | .register(r'permissions', PermissionViewSet, 'users-groups-permission', parents_query_lookups=['group__user', 'group']) 16 | ) 17 | # simple routes 18 | router.register(r'groups', GroupViewSet, 'group') 19 | router.register(r'permissions', PermissionViewSet, 'permission') 20 | 21 | urlpatterns = router.urls 22 | -------------------------------------------------------------------------------- /tests_app/tests/functional/routers/extended_default_router/views.py: -------------------------------------------------------------------------------- 1 | from rest_framework.viewsets import ModelViewSet 2 | 3 | from rest_framework_extensions.mixins import NestedViewSetMixin 4 | 5 | from .models import ( 6 | DefaultRouterUserModel, 7 | DefaultRouterGroupModel, 8 | DefaultRouterPermissionModel, 9 | ) 10 | 11 | 12 | class UserViewSet(NestedViewSetMixin, ModelViewSet): 13 | queryset = DefaultRouterUserModel.objects.all() 14 | 15 | 16 | class GroupViewSet(NestedViewSetMixin, ModelViewSet): 17 | queryset = DefaultRouterGroupModel.objects.all() 18 | 19 | 20 | class PermissionViewSet(NestedViewSetMixin, 
ModelViewSet): 21 | queryset = DefaultRouterPermissionModel.objects.all() 22 | -------------------------------------------------------------------------------- /tests_app/tests/functional/routers/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | 3 | 4 | class RouterTestModel(models.Model): 5 | uuid = models.CharField(max_length=20) 6 | text = models.CharField(max_length=200) -------------------------------------------------------------------------------- /tests_app/tests/functional/routers/nested_router_mixin/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests_app/tests/functional/routers/nested_router_mixin/models.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | from django.db import models 4 | from django.contrib.contenttypes.fields import GenericForeignKey 5 | 6 | 7 | class NestedRouterMixinUserModel(models.Model): 8 | email = models.EmailField(blank=True, null=True) 9 | code = models.UUIDField(default=uuid.uuid4) 10 | name = models.CharField(max_length=10) 11 | groups = models.ManyToManyField( 12 | 'NestedRouterMixinGroupModel', related_name='user_groups') 13 | 14 | 15 | class NestedRouterMixinGroupModel(models.Model): 16 | name = models.CharField(max_length=10) 17 | permissions = models.ManyToManyField('NestedRouterMixinPermissionModel') 18 | 19 | 20 | class NestedRouterMixinPermissionModel(models.Model): 21 | name = models.CharField(max_length=10) 22 | 23 | 24 | class NestedRouterMixinTaskModel(models.Model): 25 | title = models.CharField(max_length=30) 26 | 27 | 28 | class NestedRouterMixinBookModel(models.Model): 29 | title = models.CharField(max_length=30) 30 | 31 | 32 | class NestedRouterMixinCommentModel(models.Model): 33 | content_type = models.ForeignKey( 34 | 
"contenttypes.ContentType", 35 | blank=True, 36 | null=True, 37 | on_delete=models.CASCADE, 38 | ) 39 | object_id = models.PositiveIntegerField(blank=True, null=True) 40 | content_object = GenericForeignKey() 41 | text = models.CharField(max_length=30) 42 | -------------------------------------------------------------------------------- /tests_app/tests/functional/routers/nested_router_mixin/serializers.py: -------------------------------------------------------------------------------- 1 | from rest_framework import serializers 2 | 3 | from .models import ( 4 | NestedRouterMixinUserModel as UserModel, 5 | NestedRouterMixinGroupModel as GroupModel, 6 | NestedRouterMixinPermissionModel as PermissionModel, 7 | NestedRouterMixinTaskModel as TaskModel, 8 | NestedRouterMixinBookModel as BookModel, 9 | NestedRouterMixinCommentModel as CommentModel, 10 | ) 11 | 12 | 13 | class UserSerializer(serializers.ModelSerializer): 14 | class Meta: 15 | model = UserModel 16 | fields = ( 17 | 'id', 18 | 'name' 19 | ) 20 | 21 | 22 | class GroupSerializer(serializers.ModelSerializer): 23 | class Meta: 24 | model = GroupModel 25 | fields = ( 26 | 'id', 27 | 'name' 28 | ) 29 | 30 | 31 | class PermissionSerializer(serializers.ModelSerializer): 32 | class Meta: 33 | model = PermissionModel 34 | fields = ( 35 | 'id', 36 | 'name' 37 | ) 38 | 39 | 40 | class TaskSerializer(serializers.ModelSerializer): 41 | class Meta: 42 | model = TaskModel 43 | fields = ( 44 | 'id', 45 | 'title' 46 | ) 47 | 48 | 49 | class BookSerializer(serializers.ModelSerializer): 50 | class Meta: 51 | model = BookModel 52 | fields = ( 53 | 'id', 54 | 'title' 55 | ) 56 | 57 | 58 | class CommentSerializer(serializers.ModelSerializer): 59 | class Meta: 60 | model = CommentModel 61 | fields = ( 62 | 'id', 63 | 'content_type', 64 | 'object_id', 65 | 'text' 66 | ) -------------------------------------------------------------------------------- /tests_app/tests/functional/routers/nested_router_mixin/urls.py: 
-------------------------------------------------------------------------------- 1 | from rest_framework_extensions.routers import ExtendedSimpleRouter 2 | 3 | from .views import ( 4 | UserViewSet, 5 | GroupViewSet, 6 | PermissionViewSet, 7 | ) 8 | 9 | 10 | router = ExtendedSimpleRouter() 11 | # main routes 12 | ( 13 | router.register(r'users', UserViewSet) 14 | .register(r'groups', GroupViewSet, 'users-group', parents_query_lookups=['user_groups']) 15 | .register(r'permissions', PermissionViewSet, 'users-groups-permission', parents_query_lookups=['group__user', 'group']) 16 | ) 17 | 18 | # register on one depth 19 | permissions_routes = router.register(r'permissions', PermissionViewSet) 20 | permissions_routes.register(r'groups', GroupViewSet, 'permissions-group', parents_query_lookups=['permissions']) 21 | permissions_routes.register(r'users', UserViewSet, 'permissions-user', parents_query_lookups=['groups__permissions']) 22 | 23 | # simple routes 24 | router.register(r'groups', GroupViewSet, 'group') 25 | router.register(r'permissions', PermissionViewSet, 'permission') 26 | 27 | urlpatterns = router.urls 28 | -------------------------------------------------------------------------------- /tests_app/tests/functional/routers/nested_router_mixin/urls_generic_relations.py: -------------------------------------------------------------------------------- 1 | from rest_framework_extensions.routers import ExtendedSimpleRouter 2 | 3 | from .views import ( 4 | TaskViewSet, 5 | TaskCommentViewSet, 6 | BookViewSet, 7 | BookCommentViewSet 8 | ) 9 | 10 | 11 | router = ExtendedSimpleRouter() 12 | # tasks route 13 | ( 14 | router.register(r'tasks', TaskViewSet) 15 | .register(r'comments', TaskCommentViewSet, 'tasks-comment', parents_query_lookups=['object_id']) 16 | ) 17 | # books route 18 | ( 19 | router.register(r'books', BookViewSet) 20 | .register(r'comments', BookCommentViewSet, 'books-comment', parents_query_lookups=['object_id']) 21 | ) 22 | 23 | urlpatterns = 
router.urls 24 | -------------------------------------------------------------------------------- /tests_app/tests/functional/routers/nested_router_mixin/urls_parent_viewset_lookup.py: -------------------------------------------------------------------------------- 1 | from rest_framework_extensions.routers import ExtendedSimpleRouter 2 | 3 | from .views import ( 4 | UserViewSetWithEmailLookup, 5 | UserViewSetWithUUIDLookup, 6 | GroupViewSet, 7 | ) 8 | 9 | 10 | router = ExtendedSimpleRouter() 11 | 12 | # main routes 13 | ( 14 | router.register(r'users', UserViewSetWithEmailLookup, basename='users-by-uuid') 15 | .register(r'groups', GroupViewSet, 'users-group', parents_query_lookups=['user_groups__email']) 16 | ) 17 | 18 | # uuid routes 19 | ( 20 | router.register(r'users-by-uuid', UserViewSetWithUUIDLookup) 21 | .register(r'groups', GroupViewSet, 'users-group-uuid', parents_query_lookups=['user_groups__code']) 22 | ) 23 | 24 | urlpatterns = router.urls 25 | -------------------------------------------------------------------------------- /tests_app/tests/functional/routers/nested_router_mixin/views.py: -------------------------------------------------------------------------------- 1 | from django.contrib.contenttypes.models import ContentType 2 | from django.core.exceptions import ValidationError 3 | from django.http import Http404 4 | 5 | from rest_framework.decorators import action 6 | from rest_framework.response import Response 7 | from rest_framework.viewsets import ModelViewSet 8 | import uuid 9 | from rest_framework_extensions.mixins import NestedViewSetMixin 10 | 11 | from .models import ( 12 | NestedRouterMixinUserModel as UserModel, 13 | NestedRouterMixinGroupModel as GroupModel, 14 | NestedRouterMixinPermissionModel as PermissionModel, 15 | NestedRouterMixinTaskModel as TaskModel, 16 | NestedRouterMixinBookModel as BookModel, 17 | NestedRouterMixinCommentModel as CommentModel 18 | ) 19 | from .serializers import ( 20 | UserSerializer, 21 | 
GroupSerializer, 22 | PermissionSerializer, 23 | TaskSerializer, 24 | BookSerializer, 25 | CommentSerializer 26 | ) 27 | 28 | 29 | class UserViewSet(NestedViewSetMixin, ModelViewSet): 30 | queryset = UserModel.objects.all() 31 | serializer_class = UserSerializer 32 | 33 | @action(detail=False, methods=['post'], url_path='users-list-action') 34 | def users_list_action(self, request, *args, **kwargs): 35 | return Response('users list action') 36 | 37 | @action(detail=True, methods=['post'], url_path='users-action') 38 | def users_action(self, request, *args, **kwargs): 39 | return Response('users action') 40 | 41 | 42 | class GroupViewSet(NestedViewSetMixin, ModelViewSet): 43 | queryset = GroupModel.objects.all() 44 | serializer_class = GroupSerializer 45 | 46 | @action(detail=False, url_path='groups-list-link') 47 | def groups_list_link(self, request, *args, **kwargs): 48 | return Response('groups list link') 49 | 50 | @action(detail=True, url_path='groups-link') 51 | def groups_link(self, request, *args, **kwargs): 52 | return Response('groups link') 53 | 54 | 55 | class PermissionViewSet(NestedViewSetMixin, ModelViewSet): 56 | queryset = PermissionModel.objects.all() 57 | serializer_class = PermissionSerializer 58 | 59 | @action(detail=False, methods=['post'], url_path='permissions-list-action') 60 | def permissions_list_action(self, request, *args, **kwargs): 61 | return Response('permissions list action') 62 | 63 | @action(detail=True, methods=['post'], url_path='permissions-action') 64 | def permissions_action(self, request, *args, **kwargs): 65 | return Response('permissions action') 66 | 67 | 68 | class TaskViewSet(NestedViewSetMixin, ModelViewSet): 69 | queryset = TaskModel.objects.all() 70 | serializer_class = TaskSerializer 71 | 72 | 73 | class BookViewSet(NestedViewSetMixin, ModelViewSet): 74 | queryset = BookModel.objects.all() 75 | serializer_class = BookSerializer 76 | 77 | 78 | class CommentViewSet(NestedViewSetMixin, ModelViewSet): 79 | queryset = 
CommentModel.objects.all() 80 | serializer_class = CommentSerializer 81 | 82 | 83 | class TaskCommentViewSet(CommentViewSet): 84 | def get_queryset(self): 85 | return super().get_queryset().filter( 86 | content_type=ContentType.objects.get_for_model(TaskModel) 87 | ) 88 | 89 | 90 | class BookCommentViewSet(CommentViewSet): 91 | def get_queryset(self): 92 | return super().get_queryset().filter( 93 | content_type=ContentType.objects.get_for_model(BookModel) 94 | ) 95 | 96 | 97 | class UserViewSetWithEmailLookup(NestedViewSetMixin, ModelViewSet): 98 | queryset = UserModel.objects.all() 99 | serializer_class = UserSerializer 100 | lookup_field = 'email' 101 | lookup_value_regex = r'[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+' 102 | 103 | 104 | class UserViewSetWithUUIDLookup(NestedViewSetMixin, ModelViewSet): 105 | queryset = UserModel.objects.all() 106 | serializer_class = UserSerializer 107 | lookup_field = 'code' 108 | 109 | def get_object(self): 110 | try: 111 | # Try to validate UUID before getting object 112 | lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field 113 | if lookup_url_kwarg in self.kwargs: 114 | try: 115 | uuid.UUID(str(self.kwargs[lookup_url_kwarg])) 116 | except ValueError: 117 | raise Http404 118 | return super().get_object() 119 | except (ValueError, ValidationError): 120 | raise Http404 121 | -------------------------------------------------------------------------------- /tests_app/tests/functional/routers/tests.py: -------------------------------------------------------------------------------- 1 | from django.test import TestCase 2 | 3 | from rest_framework_extensions.routers import ExtendedSimpleRouter 4 | 5 | from tests_app.testutils import get_url_pattern_by_regex_pattern 6 | from .views import RouterViewSet 7 | 8 | 9 | class TestTrailingSlashIncluded(TestCase): 10 | def test_urls_have_trailing_slash_by_default(self): 11 | router = ExtendedSimpleRouter() 12 | router.register(r'router-viewset', RouterViewSet) 13 | urls = 
router.urls 14 | 15 | for exp in ['^router-viewset/$', 16 | '^router-viewset/(?P[^/.]+)/$', 17 | '^router-viewset/list_controller/$', 18 | '^router-viewset/(?P[^/.]+)/detail_controller/$']: 19 | msg = 'Should find url pattern with regexp %s' % exp 20 | self.assertIsNotNone(get_url_pattern_by_regex_pattern(urls, exp), msg=msg) 21 | 22 | 23 | class TestTrailingSlashRemoved(TestCase): 24 | def test_urls_can_have_trailing_slash_removed(self): 25 | router = ExtendedSimpleRouter(trailing_slash=False) 26 | router.register(r'router-viewset', RouterViewSet) 27 | urls = router.urls 28 | 29 | for exp in ['^router-viewset$', 30 | '^router-viewset/(?P[^/.]+)$', 31 | '^router-viewset/list_controller$', 32 | '^router-viewset/(?P[^/.]+)/detail_controller$']: 33 | msg = 'Should find url pattern with regexp %s' % exp 34 | self.assertIsNotNone(get_url_pattern_by_regex_pattern(urls, exp), msg=msg) 35 | -------------------------------------------------------------------------------- /tests_app/tests/functional/routers/views.py: -------------------------------------------------------------------------------- 1 | from rest_framework import viewsets 2 | from rest_framework.decorators import action 3 | 4 | from .models import RouterTestModel 5 | 6 | 7 | class RouterViewSet(viewsets.ModelViewSet): 8 | queryset = RouterTestModel.objects.all() 9 | 10 | @action(detail=True) 11 | def detail_controller(self): 12 | pass 13 | 14 | @action(detail=False) 15 | def list_controller(self): 16 | pass 17 | -------------------------------------------------------------------------------- /tests_app/tests/unit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/tests/unit/__init__.py -------------------------------------------------------------------------------- /tests_app/tests/unit/_etag/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/tests/unit/_etag/__init__.py -------------------------------------------------------------------------------- /tests_app/tests/unit/_etag/decorators/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests_app/tests/unit/cache/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/tests/unit/cache/__init__.py -------------------------------------------------------------------------------- /tests_app/tests/unit/cache/decorators/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests_app/tests/unit/decorators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/tests/unit/decorators/__init__.py -------------------------------------------------------------------------------- /tests_app/tests/unit/decorators/tests.py: -------------------------------------------------------------------------------- 1 | from django.test import TestCase 2 | from rest_framework import pagination, viewsets 3 | from rest_framework_extensions.decorators import paginate 4 | 5 | 6 | class TestPaginateDecorator(TestCase): 7 | 8 | def test_empty_pagination_class(self): 9 | msg = "@paginate missing required argument: 'pagination_class'" 10 | with self.assertRaisesMessage(AssertionError, msg): 11 | @paginate() 12 | class 
MockGenericViewSet(viewsets.GenericViewSet): 13 | pass 14 | 15 | def test_adding_page_number_pagination(self): 16 | """ 17 | Other default pagination classes' test result will be same as this even if kwargs changed to anything. 18 | """ 19 | 20 | @paginate(pagination_class=pagination.PageNumberPagination, page_size=5, ordering='-created_at') 21 | class MockGenericViewSet(viewsets.GenericViewSet): 22 | pass 23 | 24 | assert hasattr(MockGenericViewSet, 'pagination_class') 25 | assert MockGenericViewSet.pagination_class().page_size == 5 26 | assert MockGenericViewSet.pagination_class().ordering == '-created_at' 27 | 28 | def test_adding_custom_pagination(self): 29 | class CustomPagination(pagination.BasePagination): 30 | pass 31 | 32 | @paginate(pagination_class=CustomPagination, kwarg1='kwarg1', kwarg2='kwarg2') 33 | class MockGenericViewSet(viewsets.GenericViewSet): 34 | pass 35 | 36 | assert hasattr(MockGenericViewSet, 'pagination_class') 37 | assert MockGenericViewSet.pagination_class().kwarg1 == 'kwarg1' 38 | assert MockGenericViewSet.pagination_class().kwarg2 == 'kwarg2' 39 | -------------------------------------------------------------------------------- /tests_app/tests/unit/key_constructor/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests_app/tests/unit/key_constructor/bits/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests_app/tests/unit/key_constructor/bits/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | 3 | 4 | class BitTestModel(models.Model): 5 | is_active = models.BooleanField(default=False) 6 | -------------------------------------------------------------------------------- 
/tests_app/tests/unit/key_constructor/constructor/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests_app/tests/unit/migrations/0001_initial.py: -------------------------------------------------------------------------------- 1 | # Generated by Django 2.2 on 2019-04-16 11:51 2 | 3 | from django.db import migrations, models 4 | import django.db.models.deletion 5 | 6 | 7 | class Migration(migrations.Migration): 8 | 9 | initial = True 10 | 11 | dependencies = [ 12 | ] 13 | 14 | operations = [ 15 | migrations.CreateModel( 16 | name='BitTestModel', 17 | fields=[ 18 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 19 | ('is_active', models.BooleanField(default=False)), 20 | ], 21 | ), 22 | migrations.CreateModel( 23 | name='NestedRouterMixinGroupModel', 24 | fields=[ 25 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 26 | ('name', models.CharField(max_length=10)), 27 | ], 28 | ), 29 | migrations.CreateModel( 30 | name='NestedRouterMixinPermissionModel', 31 | fields=[ 32 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 33 | ('name', models.CharField(max_length=10)), 34 | ], 35 | ), 36 | migrations.CreateModel( 37 | name='UserModel', 38 | fields=[ 39 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 40 | ('name', models.CharField(max_length=20)), 41 | ], 42 | ), 43 | migrations.CreateModel( 44 | name='NestedRouterMixinUserModel', 45 | fields=[ 46 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 47 | ('name', models.CharField(max_length=10)), 48 | ('groups', models.ManyToManyField(related_name='user_groups', to='unit.NestedRouterMixinGroupModel')), 49 | ], 50 | ), 51 | 
migrations.AddField( 52 | model_name='nestedroutermixingroupmodel', 53 | name='permissions', 54 | field=models.ManyToManyField(to='unit.NestedRouterMixinPermissionModel'), 55 | ), 56 | migrations.CreateModel( 57 | name='CommentModel', 58 | fields=[ 59 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 60 | ('title', models.CharField(max_length=20)), 61 | ('text', models.CharField(max_length=200)), 62 | ('attachment', models.FileField(blank=True, max_length=500, null=True, upload_to='test_serializers')), 63 | ('hidden_text', models.CharField(blank=True, max_length=200, null=True)), 64 | ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='comments', to='unit.UserModel')), 65 | ('users_liked', models.ManyToManyField(blank=True, to='unit.UserModel')), 66 | ], 67 | ), 68 | ] 69 | -------------------------------------------------------------------------------- /tests_app/tests/unit/migrations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/tests/unit/migrations/__init__.py -------------------------------------------------------------------------------- /tests_app/tests/unit/models.py: -------------------------------------------------------------------------------- 1 | from .key_constructor.bits.models import * 2 | from .routers.nested_router_mixin.models import * 3 | from .serializers.models import * 4 | -------------------------------------------------------------------------------- /tests_app/tests/unit/routers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/tests/unit/routers/__init__.py -------------------------------------------------------------------------------- 
/tests_app/tests/unit/routers/nested_router_mixin/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests_app/tests/unit/routers/nested_router_mixin/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | 3 | 4 | class NestedRouterMixinPermissionModel(models.Model): 5 | name = models.CharField(max_length=10) 6 | 7 | 8 | class NestedRouterMixinGroupModel(models.Model): 9 | name = models.CharField(max_length=10) 10 | permissions = models.ManyToManyField( 11 | 'NestedRouterMixinPermissionModel') 12 | 13 | 14 | class NestedRouterMixinUserModel(models.Model): 15 | name = models.CharField(max_length=10) 16 | groups = models.ManyToManyField( 17 | 'NestedRouterMixinGroupModel', related_name='user_groups') 18 | -------------------------------------------------------------------------------- /tests_app/tests/unit/routers/nested_router_mixin/tests.py: -------------------------------------------------------------------------------- 1 | from rest_framework.test import APITestCase 2 | from rest_framework_extensions.routers import ExtendedSimpleRouter 3 | from rest_framework_extensions.utils import compose_parent_pk_kwarg_name 4 | from .views import ( 5 | UserViewSet, 6 | GroupViewSet, 7 | PermissionViewSet, 8 | CustomRegexUserViewSet, 9 | CustomRegexGroupViewSet, 10 | CustomRegexPermissionViewSet, 11 | ) 12 | 13 | 14 | def get_regex_pattern(urlpattern): 15 | return urlpattern.pattern.regex.pattern 16 | 17 | 18 | class NestedRouterMixinTest(APITestCase): 19 | def get_lookup_regex(self, value): 20 | return '(?P<{0}>[^/.]+)'.format(value) 21 | 22 | def get_parent_lookup_regex(self, value): 23 | return '(?P<{0}>[^/.]+)'.format(compose_parent_pk_kwarg_name(value)) 24 | 25 | def get_custom_regex_lookup(self, pk_kwarg_name, lookup_value_regex): 26 | """ Build lookup regex with 
custom regular expression. """ 27 | return '(?P<{pk_kwarg_name}>{lookup_value_regex})'.format( 28 | pk_kwarg_name=pk_kwarg_name, 29 | lookup_value_regex=lookup_value_regex 30 | ) 31 | 32 | def get_custom_regex_parent_lookup(self, parent_pk_kwarg_name, 33 | parent_lookup_value_regex): 34 | """ Build parent lookup regex with custom regular expression. """ 35 | return self.get_custom_regex_lookup( 36 | compose_parent_pk_kwarg_name(parent_pk_kwarg_name), 37 | parent_lookup_value_regex 38 | ) 39 | 40 | def test_one_route(self): 41 | router = ExtendedSimpleRouter() 42 | router.register(r'users', UserViewSet, 'user') 43 | 44 | # test user list 45 | self.assertEqual(router.urls[0].name, 'user-list') 46 | self.assertEqual(get_regex_pattern(router.urls[0]), r'^users/$') 47 | 48 | # test user detail 49 | self.assertEqual(router.urls[1].name, 'user-detail') 50 | self.assertEqual(get_regex_pattern(router.urls[1]), r'^users/{0}/$'.format(self.get_lookup_regex('pk'))) 51 | 52 | def test_nested_route(self): 53 | router = ExtendedSimpleRouter() 54 | ( 55 | router.register(r'users', UserViewSet, 'user') 56 | .register(r'groups', GroupViewSet, 'users-group', parents_query_lookups=['user']) 57 | ) 58 | 59 | # test user list 60 | self.assertEqual(router.urls[0].name, 'user-list') 61 | self.assertEqual(get_regex_pattern(router.urls[0]), r'^users/$') 62 | 63 | # test user detail 64 | self.assertEqual(router.urls[1].name, 'user-detail') 65 | self.assertEqual(get_regex_pattern(router.urls[1]), r'^users/{0}/$'.format(self.get_lookup_regex('pk'))) 66 | 67 | # test users group list 68 | self.assertEqual(router.urls[2].name, 'users-group-list') 69 | self.assertEqual(get_regex_pattern(router.urls[2]), r'^users/{0}/groups/$'.format( 70 | self.get_parent_lookup_regex('user') 71 | ) 72 | ) 73 | 74 | # test users group detail 75 | self.assertEqual(router.urls[3].name, 'users-group-detail') 76 | self.assertEqual(get_regex_pattern(router.urls[3]), r'^users/{0}/groups/{1}/$'.format( 77 | 
self.get_parent_lookup_regex('user'), 78 | self.get_lookup_regex('pk') 79 | ), 80 | ) 81 | 82 | def test_nested_route_depth_3(self): 83 | router = ExtendedSimpleRouter() 84 | ( 85 | router.register(r'users', UserViewSet, 'user') 86 | .register(r'groups', GroupViewSet, 'users-group', parents_query_lookups=['user']) 87 | .register(r'permissions', PermissionViewSet, 'users-groups-permission', parents_query_lookups=[ 88 | 'group__user', 89 | 'group', 90 | ] 91 | ) 92 | ) 93 | 94 | # test user list 95 | self.assertEqual(router.urls[0].name, 'user-list') 96 | self.assertEqual(get_regex_pattern(router.urls[0]), r'^users/$') 97 | 98 | # test user detail 99 | self.assertEqual(router.urls[1].name, 'user-detail') 100 | self.assertEqual(get_regex_pattern(router.urls[1]), r'^users/{0}/$'.format(self.get_lookup_regex('pk'))) 101 | 102 | # test users group list 103 | self.assertEqual(router.urls[2].name, 'users-group-list') 104 | self.assertEqual(get_regex_pattern(router.urls[2]), r'^users/{0}/groups/$'.format( 105 | self.get_parent_lookup_regex('user') 106 | ) 107 | ) 108 | 109 | # test users group detail 110 | self.assertEqual(router.urls[3].name, 'users-group-detail') 111 | self.assertEqual(get_regex_pattern(router.urls[3]), r'^users/{0}/groups/{1}/$'.format( 112 | self.get_parent_lookup_regex('user'), 113 | self.get_lookup_regex('pk') 114 | ), 115 | ) 116 | 117 | # test users groups permission list 118 | self.assertEqual(router.urls[4].name, 'users-groups-permission-list') 119 | self.assertEqual(get_regex_pattern(router.urls[4]), r'^users/{0}/groups/{1}/permissions/$'.format( 120 | self.get_parent_lookup_regex('group__user'), 121 | self.get_parent_lookup_regex('group'), 122 | ) 123 | ) 124 | 125 | # test users groups permission detail 126 | self.assertEqual(router.urls[5].name, 'users-groups-permission-detail') 127 | self.assertEqual(get_regex_pattern(router.urls[5]), r'^users/{0}/groups/{1}/permissions/{2}/$'.format( 128 | self.get_parent_lookup_regex('group__user'), 129 | 
self.get_parent_lookup_regex('group'), 130 | self.get_lookup_regex('pk') 131 | ), 132 | ) 133 | 134 | def test_nested_route_depth_3_custom_regex(self): 135 | """ 136 | Nested routes with over two level of depth should respect all parents' 137 | `lookup_value_regex` attribute. 138 | """ 139 | router = ExtendedSimpleRouter() 140 | ( 141 | router.register(r'users', CustomRegexUserViewSet, 'user') 142 | .register(r'groups', CustomRegexGroupViewSet, 'users-group', 143 | parents_query_lookups=['user']) 144 | .register(r'permissions', CustomRegexPermissionViewSet, 145 | 'users-groups-permission', parents_query_lookups=[ 146 | 'group__user', 147 | 'group', 148 | ] 149 | ) 150 | ) 151 | 152 | # custom regex configuration 153 | user_viewset_regex = CustomRegexUserViewSet.lookup_value_regex 154 | group_viewset_regex = CustomRegexGroupViewSet.lookup_value_regex 155 | perm_viewset_regex = CustomRegexPermissionViewSet.lookup_value_regex 156 | 157 | # test user list 158 | self.assertEqual(router.urls[0].name, 'user-list') 159 | self.assertEqual(get_regex_pattern(router.urls[0]), r'^users/$') 160 | 161 | # test user detail 162 | self.assertEqual(router.urls[1].name, 'user-detail') 163 | self.assertEqual(get_regex_pattern(router.urls[1]), r'^users/{0}/$'.format( 164 | self.get_custom_regex_lookup('pk', user_viewset_regex)) 165 | ) 166 | 167 | # test users group list 168 | self.assertEqual(router.urls[2].name, 'users-group-list') 169 | self.assertEqual(get_regex_pattern(router.urls[2]), r'^users/{0}/groups/$'.format( 170 | self.get_custom_regex_parent_lookup('user', user_viewset_regex) 171 | ) 172 | ) 173 | # test users group detail 174 | self.assertEqual(router.urls[3].name, 'users-group-detail') 175 | self.assertEqual(get_regex_pattern(router.urls[3]), r'^users/{0}/groups/{1}/$'.format( 176 | self.get_custom_regex_parent_lookup('user', user_viewset_regex), 177 | self.get_custom_regex_lookup('pk', group_viewset_regex) 178 | ), 179 | ) 180 | # test users groups permission list 181 | 
self.assertEqual(router.urls[4].name, 'users-groups-permission-list') 182 | self.assertEqual(get_regex_pattern(router.urls[4]), r'^users/{0}/groups/{1}/permissions/$'.format( 183 | self.get_custom_regex_parent_lookup('group__user', user_viewset_regex), 184 | self.get_custom_regex_parent_lookup('group', group_viewset_regex), 185 | ) 186 | ) 187 | 188 | # test users groups permission detail 189 | self.assertEqual(router.urls[5].name, 'users-groups-permission-detail') 190 | self.assertEqual(get_regex_pattern(router.urls[5]), r'^users/{0}/groups/{1}/permissions/{2}/$'.format( 191 | self.get_custom_regex_parent_lookup('group__user', user_viewset_regex), 192 | self.get_custom_regex_parent_lookup('group', group_viewset_regex), 193 | self.get_custom_regex_lookup('pk', perm_viewset_regex) 194 | ), 195 | ) 196 | -------------------------------------------------------------------------------- /tests_app/tests/unit/routers/nested_router_mixin/views.py: -------------------------------------------------------------------------------- 1 | from rest_framework.viewsets import ModelViewSet 2 | 3 | from .models import ( 4 | NestedRouterMixinUserModel as UserModel, 5 | NestedRouterMixinGroupModel as GroupModel, 6 | NestedRouterMixinPermissionModel as PermissionModel, 7 | ) 8 | 9 | 10 | class UserViewSet(ModelViewSet): 11 | model = UserModel 12 | 13 | 14 | class GroupViewSet(ModelViewSet): 15 | model = GroupModel 16 | 17 | 18 | class PermissionViewSet(ModelViewSet): 19 | model = PermissionModel 20 | 21 | 22 | class CustomRegexUserViewSet(ModelViewSet): 23 | lookup_value_regex = 'a' 24 | model = UserModel 25 | 26 | 27 | class CustomRegexGroupViewSet(ModelViewSet): 28 | lookup_value_regex = 'b' 29 | model = GroupModel 30 | 31 | 32 | class CustomRegexPermissionViewSet(ModelViewSet): 33 | lookup_value_regex = 'c' 34 | model = PermissionModel 35 | -------------------------------------------------------------------------------- /tests_app/tests/unit/routers/tests.py: 
-------------------------------------------------------------------------------- 1 | from django.test import TestCase 2 | 3 | from rest_framework import viewsets 4 | from rest_framework.decorators import action 5 | from rest_framework.response import Response 6 | from rest_framework_extensions.routers import ExtendedDefaultRouter 7 | 8 | 9 | class ExtendedDefaultRouterTest(TestCase): 10 | def setUp(self): 11 | self.router = ExtendedDefaultRouter() 12 | 13 | def get_routes_names(self, routes): 14 | return [i.name for i in routes] 15 | 16 | def get_dynamic_route_by_def_name(self, def_name, routes): 17 | try: 18 | return [i for i in routes if def_name in i.mapping.values()][0] 19 | except IndexError: 20 | return None 21 | 22 | def test_dynamic_list_route_should_come_before_detail_route(self): 23 | class BasicViewSet(viewsets.ViewSet): 24 | def list(self, request, *args, **kwargs): 25 | return Response({'method': 'list'}) 26 | 27 | @action(detail=False) 28 | def detail1(self, request, *args, **kwargs): 29 | return Response({'method': 'detail1'}) 30 | 31 | routes = self.router.get_routes(BasicViewSet) 32 | expected = [ 33 | '{basename}-list', 34 | '{basename}-detail1', 35 | '{basename}-detail' 36 | ] 37 | msg = '@detail_route methods should come first in routes order' 38 | self.assertEqual(self.get_routes_names(routes), expected, msg) 39 | 40 | def test_detail_route(self): 41 | class BasicViewSet(viewsets.ViewSet): 42 | @action(detail=True) 43 | def action1(self, request, *args, **kwargs): 44 | pass 45 | 46 | routes = self.router.get_routes(BasicViewSet) 47 | action1_route = self.get_dynamic_route_by_def_name('action1', routes) 48 | 49 | msg = '@detail_route should map methods to def name' 50 | self.assertEqual(action1_route.mapping, {'get': 'action1'}, msg) 51 | 52 | msg = '@detail_route should use url with detail lookup' 53 | self.assertEqual(action1_route.url, u'^{prefix}/{lookup}/action1{trailing_slash}$', msg) 54 | 55 | def test_detail_route__with_methods(self): 56 
| class BasicViewSet(viewsets.ViewSet): 57 | @action(detail=True, methods=['post']) 58 | def action1(self, request, *args, **kwargs): 59 | pass 60 | 61 | routes = self.router.get_routes(BasicViewSet) 62 | action1_route = self.get_dynamic_route_by_def_name('action1', routes) 63 | 64 | msg = '@detail_route should map methods to def name' 65 | self.assertEqual(action1_route.mapping, {'post': 'action1'}, msg) 66 | 67 | msg = '@detail_route should use url with detail lookup' 68 | self.assertEqual(action1_route.url, u'^{prefix}/{lookup}/action1{trailing_slash}$', msg) 69 | 70 | def test_detail_route__with_methods__and__with_url_path(self): 71 | class BasicViewSet(viewsets.ViewSet): 72 | @action(detail=True, methods=['post'], url_path='action-one') 73 | def action1(self, request, *args, **kwargs): 74 | pass 75 | 76 | routes = self.router.get_routes(BasicViewSet) 77 | action1_route = self.get_dynamic_route_by_def_name('action1', routes) 78 | 79 | msg = '@detail_route should map methods to "url_path"' 80 | self.assertEqual(action1_route.mapping, {'post': 'action1'}, msg) 81 | 82 | msg = '@detail_route should use url with detail lookup and "url_path" value' 83 | self.assertEqual(action1_route.url, u'^{prefix}/{lookup}/action-one{trailing_slash}$', msg) 84 | 85 | def test_list_route(self): 86 | class BasicViewSet(viewsets.ViewSet): 87 | @action(detail=False) 88 | def action1(self, request, *args, **kwargs): 89 | pass 90 | 91 | routes = self.router.get_routes(BasicViewSet) 92 | action1_route = self.get_dynamic_route_by_def_name('action1', routes) 93 | 94 | msg = '@list_route should map methods to def name' 95 | self.assertEqual(action1_route.mapping, {'get': 'action1'}, msg) 96 | 97 | msg = '@list_route should use url in list scope' 98 | self.assertEqual(action1_route.url, u'^{prefix}/action1{trailing_slash}$', msg) 99 | 100 | def test_list_route__with_methods(self): 101 | class BasicViewSet(viewsets.ViewSet): 102 | @action(detail=False, methods=['post']) 103 | def 
action1(self, request, *args, **kwargs): 104 | pass 105 | 106 | routes = self.router.get_routes(BasicViewSet) 107 | action1_route = self.get_dynamic_route_by_def_name('action1', routes) 108 | 109 | msg = '@list_route should map methods to def name' 110 | self.assertEqual(action1_route.mapping, {'post': 'action1'}, msg) 111 | 112 | msg = '@list_route should use url in list scope' 113 | self.assertEqual(action1_route.url, u'^{prefix}/action1{trailing_slash}$', msg) 114 | 115 | def test_list_route__with_methods__and__with_url_path(self): 116 | class BasicViewSet(viewsets.ViewSet): 117 | @action(detail=False, methods=['post'], url_path='action-one') 118 | def action1(self, request, *args, **kwargs): 119 | pass 120 | 121 | routes = self.router.get_routes(BasicViewSet) 122 | action1_route = self.get_dynamic_route_by_def_name('action1', routes) 123 | 124 | msg = '@list_route should map methods to "url_path"' 125 | self.assertEqual(action1_route.mapping, {'post': 'action1'}, msg) 126 | 127 | msg = '@list_route should use url in list scope with "url_path" value' 128 | self.assertEqual(action1_route.url, u'^{prefix}/action-one{trailing_slash}$', msg) 129 | 130 | def test_list_route_and_detail_route_with_exact_names(self): 131 | class BasicViewSet(viewsets.ViewSet): 132 | @action(detail=False, url_path='action-one') 133 | def action1(self, request, *args, **kwargs): 134 | pass 135 | 136 | @action(detail=True, url_path='action-one') 137 | def action1_detail(self, request, *args, **kwargs): 138 | pass 139 | 140 | routes = self.router.get_routes(BasicViewSet) 141 | action1_list_route = self.get_dynamic_route_by_def_name('action1', routes) 142 | action1_detail_route = self.get_dynamic_route_by_def_name('action1_detail', routes) 143 | 144 | self.assertEqual(action1_list_route.mapping, {'get': 'action1'}) 145 | self.assertEqual(action1_list_route.url, u'^{prefix}/action-one{trailing_slash}$') 146 | 147 | self.assertEqual(action1_detail_route.mapping, {'get': 'action1_detail'}) 148 
| self.assertEqual(action1_detail_route.url, u'^{prefix}/{lookup}/action-one{trailing_slash}$') 149 | 150 | def test_list_route_and_detail_route_names(self): 151 | class BasicViewSet(viewsets.ViewSet): 152 | @action(detail=False) 153 | def action1(self, request, *args, **kwargs): 154 | pass 155 | 156 | @action(detail=True) 157 | def action2(self, request, *args, **kwargs): 158 | pass 159 | 160 | routes = self.router.get_routes(BasicViewSet) 161 | action1_list_route = self.get_dynamic_route_by_def_name('action1', routes) 162 | action2_detail_route = self.get_dynamic_route_by_def_name('action2', routes) 163 | 164 | self.assertEqual(action1_list_route.name, u'{basename}-action1') 165 | self.assertEqual(action2_detail_route.name, u'{basename}-action2') 166 | 167 | def test_list_route_and_detail_route_default_names__with_endpoints(self): 168 | class BasicViewSet(viewsets.ViewSet): 169 | @action(detail=False, url_path='action_one') 170 | def action1(self, request, *args, **kwargs): 171 | pass 172 | 173 | @action(detail=True, url_path='action-two') 174 | def action2(self, request, *args, **kwargs): 175 | pass 176 | 177 | routes = self.router.get_routes(BasicViewSet) 178 | action1_list_route = self.get_dynamic_route_by_def_name('action1', routes) 179 | action2_detail_route = self.get_dynamic_route_by_def_name('action2', routes) 180 | 181 | self.assertEqual(action1_list_route.name, u'{basename}-action1') 182 | self.assertEqual(action2_detail_route.name, u'{basename}-action2') 183 | 184 | def test_list_route_and_detail_route_names__with_endpoints(self): 185 | class BasicViewSet(viewsets.ViewSet): 186 | @action(detail=False, url_path='action_one', url_name='action_one') 187 | def action1(self, request, *args, **kwargs): 188 | pass 189 | 190 | @action(detail=True, url_path='action-two', url_name='action-two') 191 | def action2(self, request, *args, **kwargs): 192 | pass 193 | 194 | routes = self.router.get_routes(BasicViewSet) 195 | action1_list_route = 
self.get_dynamic_route_by_def_name('action1', routes) 196 | action2_detail_route = self.get_dynamic_route_by_def_name('action2', routes) 197 | 198 | self.assertEqual(action1_list_route.name, u'{basename}-action_one') 199 | self.assertEqual(action2_detail_route.name, u'{basename}-action-two') 200 | -------------------------------------------------------------------------------- /tests_app/tests/unit/serializers/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests_app/tests/unit/serializers/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | 3 | 4 | class UserModel(models.Model): 5 | name = models.CharField(max_length=20) 6 | 7 | 8 | class CommentModel(models.Model): 9 | user = models.ForeignKey( 10 | UserModel, 11 | related_name='comments', 12 | on_delete=models.CASCADE, 13 | ) 14 | users_liked = models.ManyToManyField(UserModel, blank=True) 15 | title = models.CharField(max_length=20) 16 | text = models.CharField(max_length=200) 17 | attachment = models.FileField( 18 | upload_to='test_serializers', blank=True, null=True, max_length=500) 19 | hidden_text = models.CharField(max_length=200, blank=True, null=True) 20 | -------------------------------------------------------------------------------- /tests_app/tests/unit/serializers/serializers.py: -------------------------------------------------------------------------------- 1 | from rest_framework import serializers 2 | from rest_framework_extensions import serializers as drf_serializers 3 | 4 | from .models import CommentModel, UserModel 5 | 6 | 7 | class UserSerializer(drf_serializers.PartialUpdateSerializerMixin, 8 | serializers.ModelSerializer): 9 | class Meta: 10 | model = UserModel 11 | fields = ( 12 | 'name', 13 | 'comments' 14 | ) 15 | 16 | 17 | class 
CommentSerializer(drf_serializers.PartialUpdateSerializerMixin, 18 | serializers.ModelSerializer): 19 | title_from_source = serializers.CharField(source='title', required=False) 20 | 21 | class Meta: 22 | model = CommentModel 23 | fields = ( 24 | 'id', 25 | 'user', 26 | 'users_liked', 27 | 'title', 28 | 'text', 29 | 'attachment', 30 | 'title_from_source' 31 | ) 32 | 33 | 34 | class CommentTextSerializer(drf_serializers.PartialUpdateSerializerMixin, 35 | serializers.ModelSerializer): 36 | 37 | class Meta: 38 | model = CommentModel 39 | fields = ( 40 | 'title', 41 | 'text' 42 | ) 43 | 44 | 45 | class CommentSerializerWithGroupedFields(CommentSerializer): 46 | text_content = CommentTextSerializer(source='*') 47 | 48 | class Meta(CommentSerializer.Meta): 49 | fields = ( 50 | 'id', 51 | 'user', 52 | 'users_liked', 53 | 'attachment', 54 | 'title_from_source', 55 | 'text_content' 56 | ) 57 | 58 | 59 | class CommentSerializerWithAllowedUserId(CommentSerializer): 60 | user_id = serializers.IntegerField() 61 | 62 | class Meta(CommentSerializer.Meta): 63 | fields = ('user_id',) + CommentSerializer.Meta.fields 64 | 65 | 66 | class CommentSerializerWithExpandedUsersLiked(drf_serializers.PartialUpdateSerializerMixin, 67 | serializers.ModelSerializer): 68 | user = UserSerializer() 69 | 70 | class Meta: 71 | model = CommentModel 72 | fields = ( 73 | 'title', 74 | 'user' 75 | ) 76 | -------------------------------------------------------------------------------- /tests_app/tests/unit/serializers/tests.py: -------------------------------------------------------------------------------- 1 | from django.test import TestCase 2 | from django.core.files.base import ContentFile 3 | from rest_framework_extensions.serializers import get_fields_for_partial_update 4 | 5 | from .serializers import CommentSerializer, UserSerializer, \ 6 | CommentSerializerWithGroupedFields, CommentSerializerWithAllowedUserId 7 | from .models import UserModel, CommentModel 8 | 9 | 10 | class 
PartialUpdateSerializerMixinTest(TestCase): 11 | def setUp(self): 12 | self.files = [ 13 | ContentFile(u'file one'.encode('utf-8'), name='file1.txt'), 14 | ContentFile(u'file two'.encode('utf-8'), name='file2.txt'), 15 | ] 16 | self.files[0].size = 8 17 | self.files[1].size = 8 18 | self.user = UserModel.objects.create(name='gena') 19 | self.comment = CommentModel.objects.create( 20 | user=self.user, 21 | title='hello', 22 | text='world', 23 | attachment=self.files[0] 24 | ) 25 | 26 | def get_comment(self): 27 | return CommentModel.objects.get(pk=self.comment.pk) 28 | 29 | def test_should_use_default_saving_without_partial(self): 30 | serializer = CommentSerializer(data={ 31 | 'user': self.user.id, 32 | 'title': 'hola', 33 | 'text': 'amigos', 34 | }) 35 | 36 | self.assertTrue(serializer.is_valid()) # bug for python3 comes from here 37 | 38 | saved_object = serializer.save() 39 | self.assertEqual(saved_object.user, self.user) 40 | self.assertEqual(saved_object.title, 'hola') 41 | self.assertEqual(saved_object.text, 'amigos') 42 | 43 | def test_should_save_partial(self): 44 | serializer = CommentSerializer( 45 | instance=self.comment, data={'title': 'hola'}, partial=True) 46 | self.assertTrue(serializer.is_valid()) 47 | saved_object = serializer.save() 48 | self.assertEqual(saved_object.user, self.user) 49 | self.assertEqual(saved_object.title, 'hola') 50 | self.assertEqual(saved_object.text, 'world') 51 | 52 | def test_update_fields_correctly_determined(self): 53 | serializer = CommentSerializerWithGroupedFields( 54 | instance=self.comment, 55 | data={'text_content': {'title': 'hola', 'text': 'group'}, 56 | 'title_from_source': 'hola', 'attachment': self.files[1]}, 57 | partial=True) 58 | update_fields = get_fields_for_partial_update( 59 | serializer.Meta, 60 | serializer.get_initial(), 61 | serializer.fields.fields) 62 | self.assertListEqual(update_fields, ['attachment', 'text', 'title']) 63 | 64 | def test_should_save_grouped_partial(self): 65 | serializer = 
CommentSerializerWithGroupedFields( 66 | instance=self.comment, 67 | data={'text_content': {'title': 'hola', 'text': 'group'}}, 68 | partial=True) 69 | self.assertTrue(serializer.is_valid()) 70 | serializer.save() 71 | self.comment.refresh_from_db() 72 | self.assertEqual(self.comment.user, self.user) 73 | self.assertEqual(self.comment.title, 'hola') 74 | self.assertEqual(self.comment.text, 'group') 75 | 76 | def test_should_save_only_fields_from_data_for_partial_update(self): 77 | # it's important to use different instances for Comment, 78 | # because serializer's save method affects instance from arguments 79 | serializer_one = CommentSerializer( 80 | instance=self.get_comment(), 81 | data={'title': 'goodbye'}, partial=True) 82 | serializer_two = CommentSerializer( 83 | instance=self.get_comment(), data={'text': 'moon'}, partial=True) 84 | serializer_three_kwargs = { 85 | 'instance': self.get_comment(), 86 | 'partial': True 87 | } 88 | serializer_three_kwargs['data'] = {'attachment': self.files[1]} 89 | serializer_three = CommentSerializer(**serializer_three_kwargs) 90 | self.assertTrue(serializer_one.is_valid()) 91 | self.assertTrue(serializer_two.is_valid()) 92 | self.assertTrue(serializer_three.is_valid()) 93 | 94 | # saving three serializers expecting they don't affect each other's saving 95 | serializer_one.save() 96 | serializer_two.save() 97 | serializer_three.save() 98 | 99 | fresh_instance = self.get_comment() 100 | 101 | self.assertEqual(fresh_instance.attachment.read(), u'file two'.encode('utf-8')) 102 | fresh_instance.attachment.close() 103 | 104 | self.assertEqual(fresh_instance.text, 'moon') 105 | self.assertEqual(fresh_instance.title, 'goodbye') 106 | 107 | def test_should_use_related_field_name_for_update_field_list(self): 108 | another_user = UserModel.objects.create(name='vova') 109 | data = { 110 | 'title': 'goodbye', 111 | 'user': another_user.pk 112 | } 113 | serializer = CommentSerializer( 114 | instance=self.get_comment(), data=data, 
partial=True) 115 | self.assertTrue(serializer.is_valid()) 116 | serializer.save() 117 | fresh_instance = self.get_comment() 118 | self.assertEqual(fresh_instance.title, 'goodbye') 119 | self.assertEqual(fresh_instance.user, another_user) 120 | 121 | def test_should_use_field_source_value_for_searching_model_concrete_fields(self): 122 | data = { 123 | 'title_from_source': 'goodbye' 124 | } 125 | serializer = CommentSerializer( 126 | instance=self.get_comment(), data=data, partial=True) 127 | self.assertTrue(serializer.is_valid()) 128 | serializer.save() 129 | fresh_instance = self.get_comment() 130 | self.assertEqual(fresh_instance.title, 'goodbye') 131 | 132 | def test_should_not_use_m2m_field_name_for_update_field_list(self): 133 | another_user = UserModel.objects.create(name='vova') 134 | data = { 135 | 'title': 'goodbye', 136 | 'users_liked': [self.user.pk, another_user.pk] 137 | } 138 | serializer = CommentSerializer( 139 | instance=self.get_comment(), data=data, partial=True) 140 | self.assertTrue(serializer.is_valid()) 141 | try: 142 | serializer.save() 143 | except ValueError: 144 | self.fail( 145 | 'If m2m field used in partial update then it should not be used in update_fields list') 146 | fresh_instance = self.get_comment() 147 | self.assertEqual(fresh_instance.title, 'goodbye') 148 | users_liked = set( 149 | fresh_instance.users_liked.all().values_list('pk', flat=True)) 150 | self.assertEqual( 151 | users_liked, set([self.user.pk, another_user.pk])) 152 | 153 | def test_should_not_use_related_set_field_name_for_update_field_list(self): 154 | another_user = UserModel.objects.create(name='vova') 155 | another_comment = CommentModel.objects.create( 156 | user=another_user, 157 | title='goodbye', 158 | text='moon', 159 | ) 160 | data = { 161 | 'name': 'vova', 162 | 'comments': [another_comment.pk] 163 | } 164 | serializer = UserSerializer(instance=another_user, data=data, partial=True) 165 | self.assertTrue(serializer.is_valid()) 166 | serializer.save() 167 
| try: 168 | serializer.save() 169 | except ValueError: 170 | self.fail('If related set field used in partial update then it should not be used in update_fields list') 171 | fresh_comment = CommentModel.objects.get(pk=another_comment.pk) 172 | fresh_user = UserModel.objects.get(pk=another_user.pk) 173 | self.assertEqual(fresh_comment.user, another_user) 174 | self.assertEqual(fresh_user.name, 'vova') 175 | 176 | def test_should_not_try_to_update_fields_that_are_not_in_model(self): 177 | data = { 178 | 'title': 'goodbye', 179 | 'not_existing_field': 'moon' 180 | } 181 | serializer = CommentSerializer(instance=self.get_comment(), data=data, partial=True) 182 | self.assertTrue(serializer.is_valid()) 183 | try: 184 | serializer.save() 185 | except ValueError: 186 | msg = 'Should not pass values to update_fields from data, if they are not in model' 187 | self.fail(msg) 188 | fresh_instance = self.get_comment() 189 | self.assertEqual(fresh_instance.title, 'goodbye') 190 | self.assertEqual(fresh_instance.text, 'world') 191 | 192 | def test_should_not_try_to_update_fields_that_are_not_allowed_from_serializer(self): 193 | data = { 194 | 'title': 'goodbye', 195 | 'hidden_text': 'do not change me' 196 | } 197 | serializer = CommentSerializer(instance=self.get_comment(), data=data, partial=True) 198 | self.assertTrue(serializer.is_valid()) 199 | serializer.save() 200 | fresh_instance = self.get_comment() 201 | self.assertEqual(fresh_instance.title, 'goodbye') 202 | self.assertEqual(fresh_instance.text, 'world') 203 | self.assertEqual(fresh_instance.hidden_text, None) 204 | 205 | def test_should_use_list_of_fields_to_update_from_arguments_if_it_passed(self): 206 | data = { 207 | 'title': 'goodbye', 208 | 'text': 'moon' 209 | } 210 | serializer = CommentSerializer(instance=self.get_comment(), data=data, partial=True) 211 | self.assertTrue(serializer.is_valid()) 212 | serializer.save(**{'update_fields': ['title']}) 213 | fresh_instance = self.get_comment() 214 | 
self.assertEqual(fresh_instance.title, 'goodbye') 215 | self.assertEqual(fresh_instance.text, 'world') 216 | 217 | def test_should_not_use_field_attname_for_update_fields__if_attname_not_allowed_in_serializer_fields(self): 218 | another_user = UserModel.objects.create(name='vova') 219 | data = { 220 | 'title': 'goodbye', 221 | 'user_id': another_user.id 222 | } 223 | serializer = CommentSerializer( 224 | instance=self.get_comment(), data=data, partial=True) 225 | self.assertTrue(serializer.is_valid()) 226 | serializer.save() 227 | fresh_instance = self.get_comment() 228 | self.assertEqual(fresh_instance.user_id, self.user.id) 229 | 230 | def test_should_use_field_attname_for_update_fields__if_attname_allowed_in_serializer_fields(self): 231 | another_user = UserModel.objects.create(name='vova') 232 | data = { 233 | 'title': 'goodbye', 234 | 'user_id': another_user.id 235 | } 236 | serializer = CommentSerializerWithAllowedUserId( 237 | instance=self.get_comment(), data=data, partial=True) 238 | self.assertTrue(serializer.is_valid()) 239 | serializer.save() 240 | fresh_instance = self.get_comment() 241 | self.assertEqual(fresh_instance.user_id, another_user.id) 242 | 243 | def test_should_not_use_pk_field_for_update_fields(self): 244 | old_pk = self.get_comment().pk 245 | data = { 246 | 'id': old_pk + 1, 247 | 'title': 'goodbye' 248 | } 249 | serializer = CommentSerializer( 250 | instance=self.get_comment(), data=data, partial=True) 251 | self.assertTrue(serializer.is_valid()) 252 | try: 253 | serializer.save() 254 | except ValueError: 255 | self.fail( 256 | 'Primary key field should be excluded from update_fields list') 257 | fresh_instance = self.get_comment() 258 | self.assertEqual(fresh_instance.pk, old_pk) 259 | self.assertEqual(fresh_instance.title, u'goodbye') 260 | -------------------------------------------------------------------------------- /tests_app/tests/unit/utils/__init__.py: 
-------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests_app/tests/unit/utils/tests.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | try: 3 | from unittest import mock 4 | except ImportError: 5 | import mock 6 | 7 | from django.test import TestCase 8 | 9 | from rest_framework_extensions.utils import prepare_header_name, get_rest_framework_version 10 | 11 | 12 | @contextlib.contextmanager 13 | def parsed_version(version):  # temporarily patch rest_framework.VERSION and yield what get_rest_framework_version() returns for it 14 | with mock.patch('rest_framework.VERSION', version): 15 | yield get_rest_framework_version() 16 | 17 | 18 | class TestPrepareHeaderName(TestCase): 19 | def test_upper(self): 20 | self.assertEqual(prepare_header_name('Accept'), 'HTTP_ACCEPT') 21 | 22 | def test_replace_dash_with_underscores(self): 23 | self.assertEqual( 24 | prepare_header_name('Accept-Language'), 'HTTP_ACCEPT_LANGUAGE') 25 | 26 | def test_strips_whitespaces(self): 27 | self.assertEqual( 28 | prepare_header_name(' Accept-Language '), 'HTTP_ACCEPT_LANGUAGE') 29 | 30 | def test_adds_http_prefix(self):  # NOTE(review): same input and expectation as test_replace_dash_with_underscores 31 | self.assertEqual( 32 | prepare_header_name('Accept-Language'), 'HTTP_ACCEPT_LANGUAGE') 33 | 34 | def test_get_rest_framework_version_exotic_version(self): 35 | """A version string with a non-numeric suffix such as '1.2a2' parses to (1, 2).""" 36 | with parsed_version('1.2a2') as version: 37 | self.assertEqual(version, (1, 2)) 38 | 39 | def test_get_rest_framework_version_normal_version(self): 40 | """A plain dotted version string such as '3.14.16' parses to (3, 14, 16).""" 41 | with parsed_version('3.14.16') as version: 42 | self.assertEqual(version, (3, 14, 16)) 43 | -------------------------------------------------------------------------------- /tests_app/testutils.py: -------------------------------------------------------------------------------- 1 | import base64 2 | try: 3 | from unittest.mock import patch 4 | except ImportError: 5 | from mock import patch 6 | 7 | from rest_framework import HTTP_HEADER_ENCODING 8 | 9 | from
rest_framework_extensions.key_constructor import bits 10 | from rest_framework_extensions.key_constructor.constructors import KeyConstructor 11 | 12 | 13 | def get_url_pattern_by_regex_pattern(patterns, pattern_string): 14 | # todo: test me. Returns the first entry whose pattern.pattern.regex.pattern equals pattern_string, or None when nothing matches. 15 | for pattern in patterns: 16 | if pattern.pattern.regex.pattern == pattern_string: 17 | return pattern 18 | 19 | 20 | def override_extensions_api_settings(**kwargs):  # returns a mock patch.multiple() patcher overriding the given attributes of extensions_api_settings 21 | return patch.multiple( 22 | 'rest_framework_extensions.settings.extensions_api_settings', 23 | **kwargs 24 | ) 25 | 26 | 27 | def basic_auth_header(username, password):  # builds the 'Basic <base64 of username:password>' Authorization header value 28 | credentials = ('%s:%s' % (username, password)) 29 | base64_credentials = base64.b64encode( 30 | credentials.encode(HTTP_HEADER_ENCODING) 31 | ).decode(HTTP_HEADER_ENCODING) 32 | return 'Basic %s' % base64_credentials 33 | 34 | 35 | class TestFormatKeyBit(bits.KeyBitBase): 36 | def get_data(self, **kwargs): 37 | return u'json' 38 | 39 | 40 | class TestLanguageKeyBit(bits.KeyBitBase): 41 | def get_data(self, **kwargs): 42 | return u'ru' 43 | 44 | 45 | class TestUsedKwargsKeyBit(bits.KeyBitBase): 46 | def get_data(self, **kwargs): 47 | return kwargs 48 | 49 | 50 | class TestKeyConstructor(KeyConstructor): 51 | format = TestFormatKeyBit() 52 | language = TestLanguageKeyBit() 53 | -------------------------------------------------------------------------------- /tests_app/urls.py: -------------------------------------------------------------------------------- 1 | urlpatterns = [] 2 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py{38,39}-django{22}-drf{311,312} 3 | py{38,39,310,311,312}-django{32}-drf{312,313,314} 4 | py{38,39,310,311,312}-django{42}-drf{314,315} 5 | py{310,311,312}-django{52}-drf{314,315} 6 | 7 | 8 | [testenv] 9 | deps= 10 | -rtests_app/requirements.txt 11 | django-guardian>=1.4.4 12 | drf311: djangorestframework>=3.11,<3.12
13 | djangorestframework-guardian 14 | drf312: djangorestframework>=3.12,<3.13 15 | djangorestframework-guardian 16 | drf313: djangorestframework>=3.13,<3.14 17 | djangorestframework-guardian 18 | drf314: djangorestframework>=3.14,<3.15 19 | djangorestframework-guardian 20 | drf315: djangorestframework>=3.15,<3.16 21 | djangorestframework-guardian 22 | django22: Django>=2.2,<3.0 23 | django32: Django>=3.2,<4.0 24 | django42: Django>=4.2,<5.0 25 | django52: Django>=5.2,<6.0 26 | 27 | 28 | setenv = 29 | PYTHONPATH = {toxinidir}:{toxinidir}/tests_app 30 | commands = 31 | python --version 32 | pip freeze 33 | python -Wd {envbindir}/django-admin test --settings=settings {posargs} 34 | --------------------------------------------------------------------------------