├── .github
└── workflows
│ ├── ci.yml
│ └── codeql-analysis.yml
├── .gitignore
├── .travis.yml
├── AUTHORS.md
├── GNUmakefile
├── LICENSE
├── MANIFEST.in
├── README.md
├── SECURITY.md
├── docs
├── backdoc.py
├── index.html
└── index.md
├── rest_framework_extensions
├── __init__.py
├── bulk_operations
│ ├── __init__.py
│ └── mixins.py
├── cache
│ ├── __init__.py
│ ├── decorators.py
│ └── mixins.py
├── compat.py
├── decorators.py
├── etag
│ ├── __init__.py
│ ├── decorators.py
│ └── mixins.py
├── exceptions.py
├── fields.py
├── key_constructor
│ ├── __init__.py
│ ├── bits.py
│ └── constructors.py
├── mixins.py
├── permissions.py
├── routers.py
├── serializers.py
├── settings.py
├── test.py
└── utils.py
├── setup.cfg
├── setup.py
├── tests_app
├── __init__.py
├── plugins.py
├── requirements.txt
├── settings.py
├── tests
│ ├── __init__.py
│ ├── functional
│ │ ├── __init__.py
│ │ ├── _concurrency
│ │ │ ├── __init__.py
│ │ │ └── conditional_request
│ │ │ │ ├── __init__.py
│ │ │ │ ├── models.py
│ │ │ │ ├── serializers.py
│ │ │ │ ├── tests.py
│ │ │ │ ├── urls.py
│ │ │ │ └── views.py
│ │ ├── _examples
│ │ │ ├── __init__.py
│ │ │ └── etags
│ │ │ │ ├── __init__.py
│ │ │ │ └── remove_etag_gzip_postfix
│ │ │ │ ├── __init__.py
│ │ │ │ ├── middleware.py
│ │ │ │ ├── tests.py
│ │ │ │ ├── urls.py
│ │ │ │ └── views.py
│ │ ├── cache
│ │ │ ├── __init__.py
│ │ │ └── decorators
│ │ │ │ ├── __init__.py
│ │ │ │ ├── tests.py
│ │ │ │ ├── urls.py
│ │ │ │ └── views.py
│ │ ├── key_constructor
│ │ │ ├── __init__.py
│ │ │ └── bits
│ │ │ │ ├── __init__.py
│ │ │ │ ├── models.py
│ │ │ │ ├── serializers.py
│ │ │ │ ├── tests.py
│ │ │ │ ├── urls.py
│ │ │ │ └── views.py
│ │ ├── migrations
│ │ │ ├── 0001_initial.py
│ │ │ ├── 0002_nestedroutermixinusermodel_code.py
│ │ │ └── __init__.py
│ │ ├── mixins
│ │ │ ├── __init__.py
│ │ │ ├── detail_serializer_mixin
│ │ │ │ ├── __init__.py
│ │ │ │ ├── models.py
│ │ │ │ ├── serializers.py
│ │ │ │ ├── tests.py
│ │ │ │ ├── urls.py
│ │ │ │ └── views.py
│ │ │ ├── list_destroy_model_mixin
│ │ │ │ ├── __init__.py
│ │ │ │ ├── models.py
│ │ │ │ ├── tests.py
│ │ │ │ ├── urls.py
│ │ │ │ └── views.py
│ │ │ ├── list_update_model_mixin
│ │ │ │ ├── __init__.py
│ │ │ │ ├── models.py
│ │ │ │ ├── serializers.py
│ │ │ │ ├── tests.py
│ │ │ │ ├── urls.py
│ │ │ │ └── views.py
│ │ │ └── paginate_by_max_mixin
│ │ │ │ ├── __init__.py
│ │ │ │ ├── models.py
│ │ │ │ ├── pagination.py
│ │ │ │ ├── serializers.py
│ │ │ │ ├── tests.py
│ │ │ │ ├── urls.py
│ │ │ │ └── views.py
│ │ ├── models.py
│ │ ├── permissions
│ │ │ ├── __init__.py
│ │ │ └── extended_django_object_permissions
│ │ │ │ ├── __init__.py
│ │ │ │ ├── models.py
│ │ │ │ ├── tests.py
│ │ │ │ ├── urls.py
│ │ │ │ └── views.py
│ │ └── routers
│ │ │ ├── __init__.py
│ │ │ ├── extended_default_router
│ │ │ ├── __init__.py
│ │ │ ├── models.py
│ │ │ ├── tests.py
│ │ │ ├── urls.py
│ │ │ └── views.py
│ │ │ ├── models.py
│ │ │ ├── nested_router_mixin
│ │ │ ├── __init__.py
│ │ │ ├── models.py
│ │ │ ├── serializers.py
│ │ │ ├── tests.py
│ │ │ ├── urls.py
│ │ │ ├── urls_generic_relations.py
│ │ │ ├── urls_parent_viewset_lookup.py
│ │ │ └── views.py
│ │ │ ├── tests.py
│ │ │ └── views.py
│ └── unit
│ │ ├── __init__.py
│ │ ├── _etag
│ │ ├── __init__.py
│ │ └── decorators
│ │ │ ├── __init__.py
│ │ │ └── tests.py
│ │ ├── cache
│ │ ├── __init__.py
│ │ └── decorators
│ │ │ ├── __init__.py
│ │ │ └── tests.py
│ │ ├── decorators
│ │ ├── __init__.py
│ │ └── tests.py
│ │ ├── key_constructor
│ │ ├── __init__.py
│ │ ├── bits
│ │ │ ├── __init__.py
│ │ │ ├── models.py
│ │ │ └── tests.py
│ │ └── constructor
│ │ ├── migrations
│ │ ├── 0001_initial.py
│ │ └── __init__.py
│ │ ├── models.py
│ │ ├── routers
│ │ ├── __init__.py
│ │ ├── nested_router_mixin
│ │ │ ├── __init__.py
│ │ │ ├── models.py
│ │ │ ├── tests.py
│ │ │ └── views.py
│ │ └── tests.py
│ │ ├── serializers
│ │ ├── __init__.py
│ │ ├── models.py
│ │ ├── serializers.py
│ │ └── tests.py
│ │ └── utils
│ │ ├── __init__.py
│ │ └── tests.py
├── testutils.py
└── urls.py
└── tox.ini
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: Django CI
2 |
3 | on:
4 | push:
5 | branches: [ master ]
6 | pull_request:
7 | branches: [ master ]
8 | workflow_dispatch:
9 |
10 | jobs:
11 | build:
12 |
13 | runs-on: ubuntu-latest
14 | strategy:
15 | max-parallel: 4
16 | matrix:
17 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
18 | django-version: ["2.2","3.2", "4.2", "5.2"]
19 |
20 | steps:
21 | - uses: actions/checkout@v3
22 | - name: Set up Python ${{ matrix.python-version }}
23 | uses: actions/setup-python@v4
24 | with:
25 | python-version: ${{ matrix.python-version }}
26 | - name: Install Dependencies
27 | run: |
28 | python -m pip install --upgrade pip tox
29 | pip install -r tests_app/requirements.txt
30 | - name: Run Tests
31 | run: |
32 | tox -- tests_app
33 |
--------------------------------------------------------------------------------
/.github/workflows/codeql-analysis.yml:
--------------------------------------------------------------------------------
1 | # For most projects, this workflow file will not need changing; you simply need
2 | # to commit it to your repository.
3 | #
4 | # You may wish to alter this file to override the set of languages analyzed,
5 | # or to provide custom queries or build logic.
6 | #
7 | # ******** NOTE ********
8 | # We have attempted to detect the languages in your repository. Please check
9 | # the `language` matrix defined below to confirm you have the correct set of
10 | # supported CodeQL languages.
11 | #
12 | name: "CodeQL"
13 |
14 | on:
15 | push:
16 | branches: [ master ]
17 | pull_request:
18 | # The branches below must be a subset of the branches above
19 | branches: [ master ]
20 |
21 | jobs:
22 | analyze:
23 | name: Analyze
24 | runs-on: ubuntu-latest
25 | permissions:
26 | actions: read
27 | contents: read
28 | security-events: write
29 |
30 | strategy:
31 | fail-fast: false
32 | matrix:
33 | language: [ 'python' ]
34 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ]
35 | # Learn more about CodeQL language support at https://git.io/codeql-language-support
36 |
37 | steps:
38 | - name: Checkout repository
39 | uses: actions/checkout@v2
40 |
41 | # Initializes the CodeQL tools for scanning.
42 | - name: Initialize CodeQL
43 | uses: github/codeql-action/init@v1
44 | with:
45 | languages: ${{ matrix.language }}
46 | # If you wish to specify custom queries, you can do so here or in a config file.
47 | # By default, queries listed here will override any specified in a config file.
48 | # Prefix the list here with "+" to use these queries and those in the config file.
49 | # queries: ./path/to/local/query, your-org/your-repo/queries@main
50 |
51 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
52 | # If this step fails, then you should remove it and run the build manually (see below)
53 | - name: Autobuild
54 | uses: github/codeql-action/autobuild@v1
55 |
56 | # ℹ️ Command-line programs to run using the OS shell.
57 | # 📚 https://git.io/JvXDl
58 |
59 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines
60 | # and modify them (or add more) to build your code if your project
61 | # uses a compiled language
62 |
63 | #- run: |
64 | # make bootstrap
65 | # make release
66 |
67 | - name: Perform CodeQL Analysis
68 | uses: github/codeql-action/analyze@v1
69 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.pyc
3 | *.egg-info
4 | .tox
5 | *.egg
6 | .idea
7 | env
8 | build
9 | dist
10 | .DS_Store
11 | venv
12 | tests_app/tests/files
13 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | cache: pip
3 | dist: bionic
4 | sudo: false
5 | arch:
6 | - amd64
7 | - ppc64le
8 | python:
9 | - 3.6
10 | - 3.7
11 | - 3.8
12 |
13 |
14 | install:
15 | - pip install tox tox-travis
16 |
17 | script:
18 | - tox -r
19 |
--------------------------------------------------------------------------------
/AUTHORS.md:
--------------------------------------------------------------------------------
1 | ## Original Author
2 | ---------------
3 | Gennady Chibisov https://github.com/chibisov
4 |
5 | ## Core maintainer
6 | Asif Saif Uddin https://github.com/auvipy
7 |
8 |
9 | ## Contributors
10 | ------------
11 | Luke Murphy https://github.com/lwm
12 |
--------------------------------------------------------------------------------
/GNUmakefile:
--------------------------------------------------------------------------------
1 | build_docs:
2 | PYTHONIOENCODING=utf-8 python docs/backdoc.py --title "Django Rest Framework extensions documentation" < docs/index.md > docs/index.html
3 |
4 | watch_docs:
5 | make build_docs
6 | watchmedo shell-command -p "*.md" -R -c "make build_docs" docs/
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2013 Gennady Chibisov.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 | include README.md
3 | include tox.ini
4 | recursive-include docs *.md *.html *.txt *.py
5 | recursive-include tests_app requirements.txt *.py
6 | recursive-exclude * __pycache__
7 | global-exclude *pyc
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Django REST Framework extensions
2 |
3 | DRF-extensions is a collection of custom extensions for [Django REST Framework](https://github.com/tomchristie/django-rest-framework)
4 |
5 | Full documentation for the project is available at [http://chibisov.github.io/drf-extensions/docs](http://chibisov.github.io/drf-extensions/docs)
6 |
7 | [](#backers) [](#sponsors) [](https://pypi.python.org/pypi/drf-extensions)
8 |
9 | ### Sponsor
10 |
11 | [Tidelift gives software development teams a single source for purchasing and maintaining their software, with professional grade assurances from the experts who know it best, while seamlessly integrating with existing tools.](https://tidelift.com/subscription/pkg/pypi-drf-extensions?utm_source=pypi-drf-extensions&utm_medium=referral&utm_campaign=readme)
12 |
13 |
14 | ## Requirements
15 |
16 | * Tested for Python 3.8, 3.9, 3.10, 3.11 and 3.12
17 | * Tested for Django Rest Framework 3.12, 3.13, 3.14 and 3.15
18 | * Tested for Django 2.2 to 5.2
19 | * Tested for django-filter 2.1.0+
20 |
21 | ## Installation:
22 |
23 | pip3 install drf-extensions
24 |
25 | or from github
26 |
27 | pip3 install https://github.com/chibisov/drf-extensions/archive/master.zip
28 |
29 | ## Some features
30 |
31 | * DetailSerializerMixin
32 | * Caching
33 | * Conditional requests
34 | * Customizable key construction for caching and conditional requests
35 | * Nested routes
36 | * Bulk operations
37 |
38 | Read more in [documentation](http://chibisov.github.io/drf-extensions/docs)
39 |
40 | ## Development
41 |
42 | Running the tests:
43 |
44 | $ pip3 install tox
45 | $ tox -- tests_app
46 |
47 | Run tests for a specific environment:
48 |
49 | $ tox -e py38 -- tests_app
50 |
51 | Recreate envs before running tests:
52 |
53 | $ tox --recreate -- tests_app
54 |
55 | Pass custom arguments:
56 |
57 | $ tox -- tests_app --verbosity=3
58 |
59 | Run with pdb support:
60 |
61 | $ tox -- tests_app --processes=0 --nocapture
62 |
63 | Run a specific TestCase:
64 |
65 | $ tox -- tests_app.tests.unit.mixins.tests:DetailSerializerMixinTest_serializer_detail_class
66 |
67 | Run tests from a specific module:
68 |
69 | $ tox -- tests_app.tests.unit.mixins.tests
70 |
71 | Build docs:
72 |
73 | $ make build_docs
74 |
75 | Automatically build docs by watching changes:
76 |
77 | $ pip install watchdog
78 | $ make watch_docs
79 |
80 | ## Developing new features
81 |
82 | Every new feature should be:
83 |
84 | * Documented
85 | * Tested
86 | * Implemented
87 | * Pushed to main repository
88 |
89 | ### How to write documentation
90 |
91 | When new feature implementation starts you should place it into `development version` pull. Add `Development version`
92 | section to `Release notes` and describe every new feature in it. Use `#anchors` to facilitate navigation.
93 |
94 | Every feature should have a title and a note stating that it was implemented in the current development version.
95 |
96 | For example if we've just implemented `Usage of the specific cache`:
97 |
98 | ...
99 |
100 | #### Usage of the specific cache
101 |
102 | *New in DRF-extensions development version*
103 |
104 | `@cache_response` can also take...
105 |
106 | ...
107 |
108 | ### Release notes
109 |
110 | ...
111 |
112 | #### Development version
113 |
114 | * Added ability to [use a specific cache](#usage-of-the-specific-cache) for `@cache_response` decorator
115 |
116 | ## Publishing new releases
117 |
118 | Increment version in `rest_framework_extensions/__init__.py`. For example:
119 |
120 | __version__ = '0.2.2' # from 0.2.1
121 |
122 | Move all release notes in the documentation into the new version's section.
123 |
124 | Add the release date to the release notes section.
125 |
126 | In the documentation, replace all `New in DRF-extensions development version` notes with `New in DRF-extensions 0.2.2`.
127 |
128 | Rebuild documentation.
129 |
130 | Run tests.
131 |
132 | Commit changes with message "Version 0.2.2"
133 |
134 | Add new tag version for commit:
135 |
136 | $ git tag 0.2.2
137 |
138 | Push to master with tags:
139 |
140 | $ git push origin master --tags
141 |
142 | Don't forget to merge `master` to `gh-pages` branch and push to origin:
143 |
144 | $ git checkout gh-pages
145 | $ git merge --no-ff master
146 | $ git push origin gh-pages
147 |
148 | Publish to pypi:
149 |
150 | $ python setup.py publish
151 |
152 | ## Contributors
153 |
154 | This project exists thanks to all the people who contribute.
155 |
156 |
157 | ## Backers
158 |
159 | Thank you to all our backers! 🙏 [[Become a backer](https://opencollective.com/drf-extensions#backer)]
160 |
161 |
162 |
163 |
164 | ## Sponsors
165 |
166 | Support this project by becoming a sponsor. Your logo will show up here with a link to your website. [[Become a sponsor](https://opencollective.com/drf-extensions#sponsor)]
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | # Security Policy
2 |
3 | ## Supported Versions
4 |
5 |
6 | | Version | Supported |
7 | | ------- | ------------------ |
8 | | 0.7.x | :white_check_mark: |
9 | | 0.6.x | :x: |
10 | | < 0.7 | :x: |
11 |
12 | ## Reporting a Vulnerability
13 |
14 | Please report vulnerabilities by email to auvipy@gmail.com.
15 |
--------------------------------------------------------------------------------
/rest_framework_extensions/__init__.py:
--------------------------------------------------------------------------------
# Current package version string; incremented as part of each release.
__version__ = '0.8.0' # from 0.7.1

# Alias of ``__version__`` for callers that read ``VERSION``.
VERSION = __version__
4 |
--------------------------------------------------------------------------------
/rest_framework_extensions/bulk_operations/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/rest_framework_extensions/bulk_operations/__init__.py
--------------------------------------------------------------------------------
/rest_framework_extensions/bulk_operations/mixins.py:
--------------------------------------------------------------------------------
1 | from django.utils.encoding import force_str
2 |
3 | from rest_framework import status
4 | from rest_framework.response import Response
5 | from rest_framework_extensions.settings import extensions_api_settings
6 | from rest_framework_extensions import utils
7 |
8 |
class BulkOperationBaseMixin:
    """
    Helpers shared by the bulk-operation viewset mixins.

    A request is treated as a single-object operation when its URL kwargs
    carry a lookup value; otherwise it is a bulk (list-level) operation.
    """

    def is_object_operation(self):
        """Return True when the request targets one object (a truthy lookup value is present)."""
        lookup_value = self.get_object_lookup_value()
        return bool(lookup_value)

    def get_object_lookup_value(self):
        """
        Return the lookup value from ``self.kwargs``, or None if absent.

        The URL kwarg name is ``lookup_url_kwarg`` when set (and truthy),
        falling back to ``lookup_field``.
        """
        url_kwarg = getattr(self, 'lookup_url_kwarg', None) or self.lookup_field
        return self.kwargs.get(url_kwarg, None)

    def is_valid_bulk_operation(self):
        """
        Check whether the request is allowed to perform a bulk operation.

        When ``DEFAULT_BULK_OPERATION_HEADER_NAME`` is configured, the request
        must carry a truthy value for that header.

        :returns: ``(is_valid, errors)`` where ``errors`` is a response body
            describing the missing header (empty when no header is required).
        """
        configured_name = extensions_api_settings.DEFAULT_BULK_OPERATION_HEADER_NAME
        if not configured_name:
            return True, {}

        meta_key = utils.prepare_header_name(configured_name)
        errors = {
            'detail': 'Header \'{0}\' should be provided for bulk operation.'.format(
                configured_name
            )
        }
        header_present = bool(self.request.META.get(meta_key, None))
        return header_present, errors
27 |
28 |
class ListDestroyModelMixin(BulkOperationBaseMixin):
    """
    Adds bulk deletion (``DELETE`` without a lookup value) to a viewset.

    A ``DELETE`` whose URL carries a lookup value is delegated to the regular
    ``destroy``; otherwise the whole filtered queryset is deleted.
    """

    def delete(self, request, *args, **kwargs):
        """Dispatch ``DELETE`` to the bulk handler or the single-object handler."""
        if not self.is_object_operation():
            return self.destroy_bulk(request, *args, **kwargs)
        return super().destroy(request, *args, **kwargs)

    def destroy_bulk(self, request, *args, **kwargs):
        """
        Delete every object in the filtered queryset.

        Returns ``204 No Content`` on success, or ``400 Bad Request`` with the
        validation errors when the bulk-operation check fails.
        """
        is_valid, errors = self.is_valid_bulk_operation()
        if not is_valid:
            return Response(errors, status=status.HTTP_400_BAD_REQUEST)

        queryset = self.filter_queryset(self.get_queryset())
        self.pre_delete_bulk(queryset)  # todo: test and document me
        queryset.delete()
        self.post_delete_bulk(queryset)  # todo: test and document me
        return Response(status=status.HTTP_204_NO_CONTENT)

    def pre_delete_bulk(self, queryset):
        """Hook called just before ``queryset`` is bulk-deleted; no-op by default."""

    def post_delete_bulk(self, queryset):
        """Hook called just after ``queryset`` was bulk-deleted; no-op by default."""
58 |
59 |
class ListUpdateModelMixin(BulkOperationBaseMixin):
    """
    Adds bulk partial update (``PATCH`` without a lookup value) to a viewset.

    A ``PATCH`` whose URL carries a lookup value is delegated to the regular
    ``partial_update``; otherwise the filtered queryset is updated with a
    single ``QuerySet.update()`` call.
    """

    def patch(self, request, *args, **kwargs):
        """Dispatch ``PATCH`` to the single-object or the bulk handler."""
        if self.is_object_operation():
            return super().partial_update(request, *args, **kwargs)
        else:
            return self.partial_update_bulk(request, *args, **kwargs)

    def partial_update_bulk(self, request, *args, **kwargs):
        """
        Apply ``request.data`` to every object in the filtered queryset.

        Returns ``204 No Content`` on success, or ``400 Bad Request`` when the
        bulk-operation check fails or the ORM rejects a submitted value.
        """
        # Local import keeps the module-level imports unchanged.
        from django.core.exceptions import ValidationError as DjangoValidationError

        is_valid, errors = self.is_valid_bulk_operation()
        if not is_valid:
            return Response(errors, status=status.HTTP_400_BAD_REQUEST)

        queryset = self.filter_queryset(self.get_queryset())
        update_bulk_dict = self.get_update_bulk_dict(
            serializer=self.get_serializer_class()(), data=request.data)
        # todo: test and document me
        self.pre_save_bulk(queryset, update_bulk_dict)
        try:
            queryset.update(**update_bulk_dict)
        except (ValueError, TypeError, DjangoValidationError) as e:
            # The raw request values are handed straight to the ORM without
            # serializer validation. Django's field value preparation can raise
            # ValueError/TypeError (e.g. a list sent for an integer field) or
            # ValidationError (e.g. a malformed date) for such input; report
            # these as a client error instead of an unhandled 500.
            errors = {
                'detail': force_str(e)
            }
            return Response(errors, status=status.HTTP_400_BAD_REQUEST)
        # todo: test and document me
        self.post_save_bulk(queryset, update_bulk_dict)
        return Response(status=status.HTTP_204_NO_CONTENT)

    def get_update_bulk_dict(self, serializer, data):
        """
        Build the keyword arguments for ``QuerySet.update()``.

        Every non-read-only serializer field whose name appears in ``data`` is
        mapped from its ``source`` (falling back to the field name) to the
        submitted value.
        """
        return {
            field.source or field_name: data[field_name]
            for field_name, field in serializer.fields.items()
            if field_name in data and not field.read_only
        }

    def pre_save_bulk(self, queryset, update_bulk_dict):
        """
        Hook called just before ``queryset`` is bulk-updated with
        ``update_bulk_dict``; no-op by default.
        """

    def post_save_bulk(self, queryset, update_bulk_dict):
        """
        Hook called just after ``queryset`` was bulk-updated with
        ``update_bulk_dict``; no-op by default.
        """
106 |
--------------------------------------------------------------------------------
/rest_framework_extensions/cache/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/rest_framework_extensions/cache/__init__.py
--------------------------------------------------------------------------------
/rest_framework_extensions/cache/decorators.py:
--------------------------------------------------------------------------------
1 | from functools import wraps, WRAPPER_ASSIGNMENTS
2 |
3 | from django.http.response import HttpResponse
4 |
5 |
6 | from rest_framework_extensions.settings import extensions_api_settings
7 |
8 |
def get_cache(alias):
    """Return the Django cache configured under ``alias``."""
    # Deferred import: Django's cache registry is only consulted at call time.
    from django.core.cache import caches as django_caches
    selected_cache = django_caches[alias]
    return selected_cache
12 |
13 |
class CacheResponse:
    """
    Store/Receive and return cached `HttpResponse` based on DRF response.

    .. note::
        This decorator will render and discard the original DRF response in
        favor of Django's `HttpResponse`. This allows the cache to retain a
        smaller memory footprint and eliminates the need to re-render
        responses on each request. Furthermore it eliminates the risk for users
        to unknowingly cache whole Serializers and QuerySets.

    :param timeout: Cache timeout, or the name of a view attribute holding it.
        Defaults to ``DEFAULT_CACHE_RESPONSE_TIMEOUT``.
    :param key_func: Callable computing the cache key, or the name of a view
        method doing so. Defaults to ``DEFAULT_CACHE_KEY_FUNC``.
    :param cache: Alias of the Django cache to use. Defaults to
        ``DEFAULT_USE_CACHE``.
    :param cache_errors: Whether responses with status >= 400 are cached.
        Defaults to ``DEFAULT_CACHE_ERRORS``.
    """
    def __init__(self,
                 timeout=None,
                 key_func=None,
                 cache=None,
                 cache_errors=None):
        if timeout is None:
            self.timeout = extensions_api_settings.DEFAULT_CACHE_RESPONSE_TIMEOUT
        else:
            self.timeout = timeout

        if key_func is None:
            self.key_func = extensions_api_settings.DEFAULT_CACHE_KEY_FUNC
        else:
            self.key_func = key_func

        if cache_errors is None:
            self.cache_errors = extensions_api_settings.DEFAULT_CACHE_ERRORS
        else:
            self.cache_errors = cache_errors

        self.cache = get_cache(cache or extensions_api_settings.DEFAULT_USE_CACHE)

    def __call__(self, func):
        """Wrap a view method so its responses go through the cache."""
        this = self

        @wraps(func, assigned=WRAPPER_ASSIGNMENTS)
        def inner(self, request, *args, **kwargs):
            return this.process_cache_response(
                view_instance=self,
                view_method=func,
                request=request,
                args=args,
                kwargs=kwargs,
            )
        return inner

    def process_cache_response(self,
                               view_instance,
                               view_method,
                               request,
                               args,
                               kwargs):
        """
        Return the cached response for this request, or compute and cache it.

        On a miss the view method is called, finalized and rendered; the
        rendered content, status code and headers are stored unless the
        response is an error and ``cache_errors`` is false. On a hit a plain
        Django ``HttpResponse`` is rebuilt from the stored triple.
        """
        key = self.calculate_key(
            view_instance=view_instance,
            view_method=view_method,
            request=request,
            args=args,
            kwargs=kwargs
        )

        timeout = self.calculate_timeout(view_instance=view_instance)

        response_triple = self.cache.get(key)
        if not response_triple:
            # render response to create and cache the content byte string
            response = view_method(view_instance, request, *args, **kwargs)
            response = view_instance.finalize_response(request, response, *args, **kwargs)
            response.render()

            if response.status_code < 400 or self.cache_errors:
                # Use ``_headers`` where this Django version still provides it;
                # otherwise rebuild a mapping whose values are (name, value)
                # pairs, which is the shape the cache-hit branch consumes.
                if hasattr(response, '_headers'):
                    headers = response._headers.copy()
                else:
                    headers = {k: (k, v) for k, v in response.items()}
                response_triple = (
                    response.rendered_content,
                    response.status_code,
                    headers
                )
                self.cache.set(key, response_triple, timeout)
        else:
            # build smaller Django HttpResponse
            content, status, headers = response_triple
            response = HttpResponse(content=content, status=status)
            for k, v in headers.values():
                response[k] = v
        if not hasattr(response, '_closable_objects'):
            response._closable_objects = []

        return response

    def calculate_key(self,
                      view_instance,
                      view_method,
                      request,
                      args,
                      kwargs):
        """
        Compute the cache key with ``key_func``.

        When ``key_func`` is a string it names a method looked up on the view.
        """
        if isinstance(self.key_func, str):
            key_func = getattr(view_instance, self.key_func)
        else:
            key_func = self.key_func
        return key_func(
            view_instance=view_instance,
            view_method=view_method,
            request=request,
            args=args,
            kwargs=kwargs,
        )

    def calculate_timeout(self, view_instance, **_):
        """
        Return the cache timeout for ``view_instance``.

        When ``timeout`` is a string it names an attribute read from the view.
        """
        if isinstance(self.timeout, str):
            # Resolve per call and do NOT overwrite ``self.timeout``: one
            # decorator instance is shared by every view class inheriting the
            # decorated method (e.g. via the caching mixins), and each of those
            # classes may define a different value for the attribute.
            return getattr(view_instance, self.timeout)
        return self.timeout
132 |
133 |
# Lower-case alias so the class can be applied as ``@cache_response(...)``.
cache_response = CacheResponse
135 |
--------------------------------------------------------------------------------
/rest_framework_extensions/cache/mixins.py:
--------------------------------------------------------------------------------
1 | from rest_framework_extensions.cache.decorators import cache_response
2 | from rest_framework_extensions.settings import extensions_api_settings
3 |
4 |
class BaseCacheResponseMixin:
    """
    Default cache key functions and timeouts for the caching mixins.

    Subclasses override these class attributes to customise caching; the
    ``cache_response`` decorators in the mixins look them up by name.
    """
    # todo: test me. Create generic test like
    # test_cache_response(view_instance, method, should_rebuild_after_method_evaluation)

    # Detail (retrieve) responses.
    object_cache_key_func = extensions_api_settings.DEFAULT_OBJECT_CACHE_KEY_FUNC
    object_cache_timeout = extensions_api_settings.DEFAULT_CACHE_RESPONSE_TIMEOUT

    # List responses.
    list_cache_key_func = extensions_api_settings.DEFAULT_LIST_CACHE_KEY_FUNC
    list_cache_timeout = extensions_api_settings.DEFAULT_CACHE_RESPONSE_TIMEOUT
12 |
13 |
class ListCacheResponseMixin(BaseCacheResponseMixin):
    """Cache ``list`` responses using ``list_cache_key_func`` / ``list_cache_timeout``."""

    @cache_response(timeout='list_cache_timeout', key_func='list_cache_key_func')
    def list(self, request, *args, **kwargs):
        response = super().list(request, *args, **kwargs)
        return response
18 |
19 |
class RetrieveCacheResponseMixin(BaseCacheResponseMixin):
    """Cache ``retrieve`` responses using ``object_cache_key_func`` / ``object_cache_timeout``."""

    @cache_response(timeout='object_cache_timeout', key_func='object_cache_key_func')
    def retrieve(self, request, *args, **kwargs):
        response = super().retrieve(request, *args, **kwargs)
        return response
24 |
25 |
class CacheResponseMixin(RetrieveCacheResponseMixin, ListCacheResponseMixin):
    """Cache both ``retrieve`` and ``list`` responses."""
29 |
--------------------------------------------------------------------------------
/rest_framework_extensions/compat.py:
--------------------------------------------------------------------------------
1 | """
2 | The `compat` module provides support for backwards compatibility with older
3 | versions of django/python, and compatibility wrappers around optional packages.
4 | """
5 |
6 |
7 | # handle different QuerySet representations
8 | def queryset_to_value_list(queryset):
9 | assert isinstance(queryset, str)
10 |
11 | # django 1.10 introduces syntax ""
12 | # we extract only the list of tuples from the string
13 | idx_bracket_open = queryset.find(u'[')
14 | idx_bracket_close = queryset.rfind(u']')
15 |
16 | return queryset[idx_bracket_open:idx_bracket_close + 1]
17 |
--------------------------------------------------------------------------------
/rest_framework_extensions/decorators.py:
--------------------------------------------------------------------------------
def paginate(pagination_class=None, **kwargs):
    """
    Decorator that adds a pagination_class to GenericViewSet class.
    Custom pagination class also available.

    Any extra keyword arguments are set as attributes on every pagination
    instance, overriding the pagination class defaults.

    Usage :
        from rest_framework.pagination import CursorPagination

        @paginate(pagination_class=CursorPagination, page_size=5, ordering='-created_at')
        class FooViewSet(viewsets.GenericViewSet):
            ...

    :raises AssertionError: if ``pagination_class`` is not given.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # AssertionError is kept for backwards compatibility.
    if pagination_class is None:
        raise AssertionError(
            "@paginate missing required argument: 'pagination_class'"
        )

    class _Pagination(pagination_class):
        def __init__(self, *args, **init_kwargs):
            # Run the wrapped class's own initialisation first (previously
            # skipped), then apply the decorator's overrides on top of it.
            super().__init__(*args, **init_kwargs)
            self.__dict__.update(kwargs)

    def decorator(_class):
        _class.pagination_class = _Pagination
        return _class

    return decorator
27 |
--------------------------------------------------------------------------------
/rest_framework_extensions/etag/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/rest_framework_extensions/etag/decorators.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from functools import wraps, WRAPPER_ASSIGNMENTS
3 |
4 | from django.utils.http import parse_etags, quote_etag
5 |
6 | from rest_framework import status
7 | from rest_framework.permissions import SAFE_METHODS
8 | from rest_framework.response import Response
9 | from rest_framework_extensions.exceptions import PreconditionRequiredException
10 |
11 | from rest_framework_extensions.utils import prepare_header_name
12 | from rest_framework_extensions.settings import extensions_api_settings
13 |
# Module logger; precondition failures are reported on Django's
# 'django.request' logger.
logger = logging.getLogger('django.request')
15 |
16 |
17 | class ETAGProcessor:
18 | """Based on https://github.com/django/django/blob/master/django/views/decorators/http.py"""
19 |
20 | def __init__(self, etag_func=None, rebuild_after_method_evaluation=False):
21 | if not etag_func:
22 | etag_func = extensions_api_settings.DEFAULT_ETAG_FUNC
23 | self.etag_func = etag_func
24 | self.rebuild_after_method_evaluation = rebuild_after_method_evaluation
25 |
26 | def __call__(self, func):
27 | this = self
28 |
29 | @wraps(func, assigned=WRAPPER_ASSIGNMENTS)
30 | def inner(self, request, *args, **kwargs):
31 | return this.process_conditional_request(
32 | view_instance=self,
33 | view_method=func,
34 | request=request,
35 | args=args,
36 | kwargs=kwargs,
37 | )
38 |
39 | return inner
40 |
41 | def process_conditional_request(self,
42 | view_instance,
43 | view_method,
44 | request,
45 | args,
46 | kwargs):
47 | etags, if_none_match, if_match = self.get_etags_and_matchers(request)
48 | res_etag = self.calculate_etag(
49 | view_instance=view_instance,
50 | view_method=view_method,
51 | request=request,
52 | args=args,
53 | kwargs=kwargs,
54 | )
55 |
56 | if self.is_if_none_match_failed(res_etag, etags, if_none_match):
57 | if request.method in SAFE_METHODS:
58 | response = Response(status=status.HTTP_304_NOT_MODIFIED)
59 | else:
60 | response = self._get_and_log_precondition_failed_response(
61 | request=request)
62 | elif self.is_if_match_failed(res_etag, etags, if_match):
63 | response = self._get_and_log_precondition_failed_response(
64 | request=request)
65 | else:
66 | response = view_method(view_instance, request, *args, **kwargs)
67 | if self.rebuild_after_method_evaluation:
68 | res_etag = self.calculate_etag(
69 | view_instance=view_instance,
70 | view_method=view_method,
71 | request=request,
72 | args=args,
73 | kwargs=kwargs,
74 | )
75 |
76 | if res_etag and not response.has_header('ETag'):
77 | response['ETag'] = quote_etag(res_etag)
78 |
79 | return response
80 |
def get_etags_and_matchers(self, request):
    """Read the conditional headers of ``request`` and parse their entity-tags.

    Only one header value is parsed: ``If-None-Match`` when present, otherwise
    ``If-Match``. The resulting list is used for both checks.

    Entity-tags may be given quoted (``"abc"``), weak (``W/"abc"``) or, as a
    leniency, bare (``abc``). Bare tokens are wrapped in double quotes before
    being handed to ``parse_etags``. Tokens may be separated by commas and/or
    whitespace.

    :returns: ``(etags, if_none_match, if_match)``. ``etags`` is the list
        produced by ``parse_etags``, or ``None`` when neither header is present.
        The two matchers are the raw header values; both are reset to ``None``
        when parsing raises ``ValueError``.
    """
    import re

    etags = None
    if_none_match = request.META.get(prepare_header_name("if-none-match"))
    if_match = request.META.get(prepare_header_name("if-match"))
    value_to_parse = if_none_match or if_match
    if value_to_parse:
        # There can be more than one ETag in the request, so we
        # consider the list of values.
        try:
            # Tokenise without tearing quoted tags apart: a quoted (optionally
            # weak) entity-tag is taken whole; anything else runs up to the
            # next comma or whitespace. Splitting on spaces alone would turn
            # W/"x" into "W/"x"" (rejected by parse_etags) and would split
            # quoted tags that contain spaces.
            tokens = re.findall(r'(?:W/)?"[^"]*"|[^\s,]+', value_to_parse)
            normalized = [
                token if token.startswith(('"', 'W/"')) else f'"{token}"'
                for token in tokens
            ]
            etags = parse_etags(', '.join(normalized))
        except ValueError:
            # In case of invalid etag ignore all ETag headers.
            # Apparently Opera sends invalidly quoted headers at times
            # (we should be returning a 400 response, but that's a
            # little extreme) -- this is Django bug #10681.
            if_none_match = None
            if_match = None
    return etags, if_none_match, if_match
103 |
def calculate_etag(self,
                   view_instance,
                   view_method,
                   request,
                   args,
                   kwargs):
    """Compute the ETag for the current request.

    ``self.etag_func`` is either a callable or the name of a method on
    ``view_instance``. A name is resolved against the view at call time. The
    function is invoked with keyword arguments only, and its result is
    returned unchanged.
    """
    func = self.etag_func
    if isinstance(func, str):
        # Resolve the configured method name on the view instance.
        func = getattr(view_instance, func)
    return func(view_instance=view_instance,
                view_method=view_method,
                request=request,
                args=args,
                kwargs=kwargs)
121 |
def is_if_none_match_failed(self, res_etag, etags, if_none_match):
    """Return True when ``If-None-Match`` matches the current resource.

    Uses the weak comparison required for ``If-None-Match`` (RFC 7232,
    section 3.2). A ``W/`` prefix and surrounding double quotes are ignored on
    both the request tags and ``res_etag``, so a pre-quoted ETag still
    matches. This mirrors the unquoting already done for ``If-Match``. ``*``
    matches any current representation.

    :param res_etag: ETag of the current resource (falsy means "no ETag").
    :param etags: tags parsed from the request header.
    :param if_none_match: raw ``If-None-Match`` header value, or None.
    """
    if not (res_etag and if_none_match):
        return False

    def opaque_tag(tag):
        # Weak comparison: drop the weakness indicator and the quotes.
        tag = tag.strip()
        if tag.startswith('W/'):
            tag = tag[2:]
        return tag.strip('"')

    candidates = {opaque_tag(etag) for etag in (etags or ())}
    return '*' in candidates or opaque_tag(res_etag) in candidates
128 |
def is_if_match_failed(self, res_etag, etags, if_match):
    """Return True when an ``If-Match`` header is present but does not match.

    Surrounding double quotes are ignored on both sides, and ``*`` matches
    any current representation. Without a current ETag, or without the
    header, the check never fails.
    """
    if not res_etag or not if_match:
        return False
    normalized = [candidate.strip('"') for candidate in etags]
    matched = '*' in normalized or res_etag.strip('"') in normalized
    return not matched
137 |
def _get_and_log_precondition_failed_response(self, request):
    """Log the failed precondition for ``request`` and return an empty 412 response."""
    failure_status = status.HTTP_412_PRECONDITION_FAILED
    log_context = {'status_code': failure_status, 'request': request}
    logger.warning('Precondition Failed: %s', request.path, extra=log_context)
    return Response(status=failure_status)
146 |
147 |
class APIETAGProcessor(ETAGProcessor):
    """
    This class is responsible for calculating the ETag value given (a list of) model instance(s).

    It does not make sense to compute a default ETag here, because the processor would always issue a 304 response,
    even if the response was modified meanwhile.
    Therefore the `APIETAGProcessor` cannot be used without specifying an `etag_func` as keyword argument.

    According to RFC 6585, conditional headers may be enforced for certain services that support conditional
    requests. For optimistic locking, the server should respond status code 428 including a description on how
    to resubmit the request successfully, see https://tools.ietf.org/html/rfc6585#section-3.

    :param etag_func: callable (or name of a view method) computing the ETag; required.
    :param rebuild_after_method_evaluation: recompute the ETag after the view method ran.
    :param precondition_map: optional ``{HTTP method: [required header, ...]}`` replacing
        the class-level default. Keys are compared with the upper-cased request method.
    """

    # Require a conditional header (e.g. If-Match) for unsafe HTTP methods (RFC 6585).
    # Override these defaults via the ``precondition_map`` argument or in a subclass.
    precondition_map = {'PUT': ['If-Match'],
                        'PATCH': ['If-Match'],
                        'DELETE': ['If-Match']}

    def __init__(self, etag_func=None, rebuild_after_method_evaluation=False, precondition_map=None):
        assert etag_func is not None, ('None-type functions are not allowed for processing API ETags. '
                                       'You must specify a proper function to calculate the API ETags '
                                       'using the "etag_func" keyword argument.')

        if precondition_map is not None:
            self.precondition_map = precondition_map
        assert isinstance(self.precondition_map, dict), ('`precondition_map` must be a dict, where '
                                                         'the key is the HTTP verb, and the value is a list of '
                                                         'HTTP headers that must all be present for that request.')

        super().__init__(etag_func=etag_func,
                         rebuild_after_method_evaluation=rebuild_after_method_evaluation)

    def get_etags_and_matchers(self, request):
        """Get the etags from the header and perform a validation against the required preconditions."""
        # Evaluate the preconditions first; this raises a 428 error if they are not met.
        self.evaluate_preconditions(request)
        # All required headers are present: extract the values and match the conditions.
        return super().get_etags_and_matchers(request)

    def evaluate_preconditions(self, request):
        """Check that every header required for ``request.method`` is present.

        :returns: True when all required headers are present (and non-empty).
        :raises PreconditionRequiredException: (HTTP 428) naming the first
            required header that is missing or empty.
        """
        method = request.method.upper()
        # Methods absent from the map, or mapped to an empty value, need nothing.
        required_headers = self.precondition_map.get(method) or ()
        for header in required_headers:
            if request.META.get(prepare_header_name(header)):
                continue
            logger.warning('Precondition required: %s', request.path,
                           extra={
                               'status_code': status.HTTP_428_PRECONDITION_REQUIRED,
                               'request': request
                           }
                           )
            # Raise an RFC 6585 compliant exception on the first missing header.
            raise PreconditionRequiredException(detail='Precondition required. This "%s" request '
                                                       'is required to be conditional. '
                                                       'Try again using "%s".' % (
                                                           request.method, header)
                                                )
        return True
210 |
211 |
# Lower-case aliases so the processors read naturally as decorators,
# e.g. ``@etag(etag_func='list_etag_func')`` or ``@api_etag(etag_func=...)``.
etag, api_etag = ETAGProcessor, APIETAGProcessor
214 |
--------------------------------------------------------------------------------
/rest_framework_extensions/etag/mixins.py:
--------------------------------------------------------------------------------
1 | from rest_framework_extensions.etag.decorators import etag, api_etag
2 | from rest_framework_extensions.settings import extensions_api_settings
3 |
4 |
class BaseETAGMixin:
    """Hold the ETag functions used by the ETag view mixins in this module.

    The ``etag`` decorators below reference these attributes by name
    (``'object_etag_func'`` / ``'list_etag_func'``), so a view can override
    them. Defaults come from the ``DEFAULT_OBJECT_ETAG_FUNC`` and
    ``DEFAULT_LIST_ETAG_FUNC`` extension settings.
    """
    # todo: test me. Create generic test like test_etag(view_instance,
    # method, should_rebuild_after_method_evaluation)
    object_etag_func = extensions_api_settings.DEFAULT_OBJECT_ETAG_FUNC
    list_etag_func = extensions_api_settings.DEFAULT_LIST_ETAG_FUNC


class ListETAGMixin(BaseETAGMixin):
    """Conditional-request (ETag) handling for ``list`` using ``list_etag_func``."""
    @etag(etag_func='list_etag_func')
    def list(self, request, *args, **kwargs):
        return super().list(request, *args, **kwargs)


class RetrieveETAGMixin(BaseETAGMixin):
    """Conditional-request (ETag) handling for ``retrieve`` using ``object_etag_func``."""
    @etag(etag_func='object_etag_func')
    def retrieve(self, request, *args, **kwargs):
        return super().retrieve(request, *args, **kwargs)


class UpdateETAGMixin(BaseETAGMixin):
    """Conditional-request (ETag) handling for ``update``.

    ``rebuild_after_method_evaluation=True`` makes the ETag sent back describe
    the object as it is after the update.
    """
    @etag(etag_func='object_etag_func', rebuild_after_method_evaluation=True)
    def update(self, request, *args, **kwargs):
        return super().update(request, *args, **kwargs)


class DestroyETAGMixin(BaseETAGMixin):
    """Conditional-request (ETag) handling for ``destroy`` using ``object_etag_func``."""
    @etag(etag_func='object_etag_func')
    def destroy(self, request, *args, **kwargs):
        return super().destroy(request, *args, **kwargs)


class ReadOnlyETAGMixin(RetrieveETAGMixin,
                        ListETAGMixin):
    """ETag handling for the read-only actions ``retrieve`` and ``list``."""
    pass


class ETAGMixin(RetrieveETAGMixin,
                UpdateETAGMixin,
                DestroyETAGMixin,
                ListETAGMixin):
    """ETag handling for ``retrieve``, ``update``, ``destroy`` and ``list``."""
    pass
45 | pass
46 |
47 |
class APIBaseETAGMixin:
    """Hold the ETag functions used by the API ETag view mixins in this module.

    The ``api_etag`` decorators below reference these attributes by name
    (``'api_object_etag_func'`` / ``'api_list_etag_func'``). Defaults come from
    the ``DEFAULT_API_OBJECT_ETAG_FUNC`` and ``DEFAULT_API_LIST_ETAG_FUNC``
    extension settings.
    """
    # todo: test me. Create generic test like test_etag(view_instance,
    # method, should_rebuild_after_method_evaluation)
    api_object_etag_func = extensions_api_settings.DEFAULT_API_OBJECT_ETAG_FUNC
    api_list_etag_func = extensions_api_settings.DEFAULT_API_LIST_ETAG_FUNC


class APIListETAGMixin(APIBaseETAGMixin):
    """API ETag handling for ``list`` using ``api_list_etag_func``."""
    @api_etag(etag_func='api_list_etag_func')
    def list(self, request, *args, **kwargs):
        return super().list(request, *args, **kwargs)


class APIRetrieveETAGMixin(APIBaseETAGMixin):
    """API ETag handling for ``retrieve`` using ``api_object_etag_func``."""
    @api_etag(etag_func='api_object_etag_func')
    def retrieve(self, request, *args, **kwargs):
        return super().retrieve(request, *args, **kwargs)


class APIUpdateETAGMixin(APIBaseETAGMixin):
    """API ETag handling for ``update``.

    With the default precondition map, PUT/PATCH requests must carry
    ``If-Match`` (otherwise 428). The ETag is rebuilt after the update.
    NOTE(review): PATCH reaching this method assumes the view routes
    ``partial_update`` through ``update`` -- confirm for custom views.
    """
    @api_etag(etag_func='api_object_etag_func', rebuild_after_method_evaluation=True)
    def update(self, request, *args, **kwargs):
        return super().update(request, *args, **kwargs)


class APIDestroyETAGMixin(APIBaseETAGMixin):
    """API ETag handling for ``destroy``; by default DELETE must carry ``If-Match``."""
    @api_etag(etag_func='api_object_etag_func')
    def destroy(self, request, *args, **kwargs):
        return super().destroy(request, *args, **kwargs)


class APIReadOnlyETAGMixin(APIRetrieveETAGMixin,
                           APIListETAGMixin):
    """API ETag handling for the read-only actions ``retrieve`` and ``list``."""
    pass


class APIETAGMixin(APIRetrieveETAGMixin,
                   APIUpdateETAGMixin,
                   APIDestroyETAGMixin,
                   APIListETAGMixin):
    """API ETag handling for ``retrieve``, ``update``, ``destroy`` and ``list``."""
    pass
89 |
--------------------------------------------------------------------------------
/rest_framework_extensions/exceptions.py:
--------------------------------------------------------------------------------
1 | from django.utils.translation import gettext_lazy as _
2 | from rest_framework import status
3 | from rest_framework.exceptions import APIException
4 |
5 |
class PreconditionRequiredException(APIException):
    """HTTP 428 "Precondition Required" (RFC 6585, section 3).

    ``default_detail`` is a template with a ``{method}`` placeholder. DRF's
    ``APIException`` uses ``default_detail`` verbatim, so pass ``method=`` to
    have the placeholder filled in. Without it (and without ``detail``), the
    template is returned unchanged, exactly as before.

    :param detail: explicit error detail; takes precedence over the template.
    :param code: error code (defaults to ``default_code``).
    :param method: HTTP method used to fill the ``{method}`` placeholder.
    """
    status_code = status.HTTP_428_PRECONDITION_REQUIRED
    default_detail = _('This "{method}" request is required to be conditional.')
    default_code = 'precondition_required'

    def __init__(self, detail=None, code=None, method=None):
        if detail is None and method is not None:
            detail = str(self.default_detail).format(method=method)
        super().__init__(detail=detail, code=code)
10 |
--------------------------------------------------------------------------------
/rest_framework_extensions/fields.py:
--------------------------------------------------------------------------------
1 | from rest_framework.relations import HyperlinkedRelatedField
2 |
3 |
class ResourceUriField(HyperlinkedRelatedField):
    """
    Represents a hyperlinking uri that points to the
    detail view for that object.

    The field links the object itself (``source='*'``) and is read-only by
    default, so no ``queryset`` argument is needed. Supplying a ``queryset``
    keeps DRF's usual writable-field behaviour.

    Example:
        class SurveySerializer(serializers.ModelSerializer):
            resource_uri = ResourceUriField(view_name='survey-detail')

            class Meta:
                model = Survey
                fields = ('id', 'resource_uri')

        ...
        {
            "id": 1,
            "resource_uri": "http://localhost/v1/surveys/1/",
        }
    """
    # todo: test me
    read_only = True

    def __init__(self, *args, **kwargs):
        kwargs.setdefault('source', '*')
        # DRF's Field.__init__ overwrites ``self.read_only`` with the
        # ``read_only`` keyword (default False), so the class attribute above
        # has no effect on its own. RelatedField also asserts that either a
        # queryset or read_only=True is given. Default to read-only unless the
        # caller provides a queryset (DRF rejects queryset + read_only=True).
        if 'queryset' not in kwargs and getattr(self, 'queryset', None) is None:
            kwargs.setdefault('read_only', True)
        super().__init__(*args, **kwargs)
29 |
--------------------------------------------------------------------------------
/rest_framework_extensions/key_constructor/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/rest_framework_extensions/key_constructor/__init__.py
--------------------------------------------------------------------------------
/rest_framework_extensions/key_constructor/bits.py:
--------------------------------------------------------------------------------
1 | from django.utils.translation import get_language
2 | from django.db.models.query import EmptyQuerySet
3 | from django.core.exceptions import EmptyResultSet
4 |
5 | from django.utils.encoding import force_str
6 |
7 | from rest_framework_extensions import compat
8 |
9 |
class AllArgsMixin:
    """Make a key bit's ``params`` default to ``'*'``, i.e. "use every argument"."""

    def __init__(self, params='*'):
        # Forward the value positionally to the next class in the MRO.
        super().__init__(params)
14 |
15 |
class KeyBitBase:
    """One component ("bit") of a key built by a key constructor.

    ``params`` is stored as given and is the default argument handed to
    :meth:`get_data`.
    """

    def __init__(self, params=None):
        self.params = params

    def get_data(self, params, view_instance, view_method, request, args, kwargs):
        """Return this bit's contribution to the key.

        The value must be JSON-serialisable. The bundled bits return a str,
        dict, list or None. Subclasses must override this method; the base
        implementation raises ``NotImplementedError``.
        """
        raise NotImplementedError()
25 |
26 |
class KeyBitDictBase(KeyBitBase):
    """Base class for bits that pick values out of a dict-like source.

    Subclasses supply the source via :meth:`get_source_dict` (see
    HeadersKeyBit and QueryParamsKeyBit). ``params`` is an iterable of keys to
    pick, ``'*'`` for every key of the source, or ``None`` for nothing. Keys
    whose value is missing or None are skipped; values go through
    ``force_str``.
    """

    def get_data(self, params, view_instance, view_method, request, args, kwargs):
        if params is None:
            return {}

        source = self.get_source_dict(
            params=params,
            view_instance=view_instance,
            view_method=view_method,
            request=request,
            args=args,
            kwargs=kwargs
        )
        wanted_keys = source.keys() if params == '*' else params

        picked = {}
        for key in wanted_keys:
            value = source.get(self.prepare_key_for_value_retrieving(key))
            if value is None:
                continue
            picked[self.prepare_key_for_value_assignment(key)] = force_str(value)
        return picked

    def get_source_dict(self, params, view_instance, view_method, request, args, kwargs):
        """Return the dict-like object to read values from (subclass hook)."""
        raise NotImplementedError()

    def prepare_key_for_value_retrieving(self, key):
        """Map a requested key to the key used to look it up in the source."""
        return key

    def prepare_key_for_value_assignment(self, key):
        """Map a requested key to the key used in the returned dict."""
        return key
67 |
68 |
class UniqueViewIdKeyBit(KeyBitBase):
    """Dotted identifier of the view class: ``'<module>.<ClassName>'``."""

    def get_data(self, params, view_instance, view_method, request, args, kwargs):
        view_cls = view_instance.__class__
        return '.'.join((view_instance.__module__, view_cls.__name__))


class UniqueMethodIdKeyBit(KeyBitBase):
    """Dotted identifier of the view method: ``'<module>.<ClassName>.<method>'``."""

    def get_data(self, params, view_instance, view_method, request, args, kwargs):
        view_cls = view_instance.__class__
        return '.'.join((view_instance.__module__, view_cls.__name__, view_method.__name__))
84 |
85 |
class LanguageKeyBit(KeyBitBase):
    """
    Currently active translation language (``get_language()``).

    Return example:
        'en'
    """

    def get_data(self, params, view_instance, view_method, request, args, kwargs):
        current_language = get_language()
        return force_str(current_language)


class FormatKeyBit(KeyBitBase):
    """
    Format of the renderer selected for the request.

    Return example for json:
        u'json'

    Return example for html:
        u'html'
    """

    def get_data(self, params, view_instance, view_method, request, args, kwargs):
        renderer = request.accepted_renderer
        return force_str(renderer.format)
108 |
109 |
class UserKeyBit(KeyBitBase):
    """
    Identity of the requesting user.

    Return example for anonymous:
        u'anonymous'

    Return example for authenticated (value is user id):
        u'10'
    """

    def get_data(self, params, view_instance, view_method, request, args, kwargs):
        user = getattr(request, 'user', None)
        if not user or not user.is_authenticated:
            return 'anonymous'
        return force_str(self._get_id_from_user(user))

    def _get_id_from_user(self, user):
        # Hook: override to key on something other than the primary id.
        return user.id
127 |
128 |
class HeadersKeyBit(KeyBitDictBase):
    """
    Selected request headers, read from ``request.META``.

    Return example:
        {'accept-language': u'ru', 'x-geobase-id': '123'}
    """

    def get_source_dict(self, params, view_instance, view_method, request, args, kwargs):
        return request.META

    def prepare_key_for_value_retrieving(self, key):
        from rest_framework_extensions.utils import prepare_header_name

        lowered = key.lower()
        # Translate the header name into its request.META key
        # (presumably HTTP_ACCEPT_LANGUAGE style -- see prepare_header_name).
        return prepare_header_name(lowered)

    def prepare_key_for_value_assignment(self, key):
        # 'Accept-Language' -> 'accept-language'
        return key.lower()
147 |
148 |
class RequestMetaKeyBit(KeyBitDictBase):
    """
    Selected raw entries of ``request.META``, looked up by their exact key.

    Return example:
        {'REMOTE_ADDR': u'127.0.0.2', 'REMOTE_HOST': u'yandex.ru'}

    """

    def get_source_dict(self, params, view_instance, view_method, request, args, kwargs):
        return request.META


class QueryParamsKeyBit(AllArgsMixin, KeyBitDictBase):
    """
    Query-string parameters from ``request.GET``. All of them are used by
    default, because ``AllArgsMixin`` makes ``params`` default to ``'*'``.

    Return example:
        {'part': 'Londo', 'callback': 'jquery_callback'}

    """

    def get_source_dict(self, params, view_instance, view_method, request, args, kwargs):
        return request.GET
169 |
170 |
class PaginationKeyBit(QueryParamsKeyBit):
    """
    Query-string values that control pagination. The parameter names are
    taken from the view's paginator; values are strings.

    Return example:
        {'page_size': '100', 'page': '1'}
    """
    paginator_attrs = [
        'page_query_param', 'page_size_query_param',
        'limit_query_param', 'offset_query_param',
        'cursor_query_param',
    ]

    def get_data(self, **kwargs):
        paginator = getattr(kwargs['view_instance'], 'paginator', None)
        pagination_params = []
        if paginator:
            candidate_names = (getattr(paginator, attr, None) for attr in self.paginator_attrs)
            pagination_params = [name for name in candidate_names if name]
        # Only the paginator's own query parameters contribute to this bit.
        kwargs['params'] = pagination_params
        return super().get_data(**kwargs)
194 |
195 |
class SqlQueryKeyBitBase(KeyBitBase):
    """Helpers for bits that key on the SQL text of a queryset."""

    def _get_queryset_query_string(self, queryset):
        """Return the SQL of ``queryset``, or None when it cannot match anything."""
        if isinstance(queryset, EmptyQuerySet):
            return None
        try:
            sql = force_str(str(queryset.query))
        except EmptyResultSet:
            # Django signals a query that can never return rows.
            return None
        return sql
205 |
206 |
class ModelInstanceKeyBitBase(KeyBitBase):
    """
    Return the actual contents of the query set.
    This class is similar to the `SqlQueryKeyBitBase`.
    """

    def _get_queryset_query_values(self, queryset):
        """Serialise every row of ``queryset`` (via ``values_list()``) to a string.

        :returns: ``None`` for an ``EmptyQuerySet``, an empty result or when
            Django raises ``EmptyResultSet``. Otherwise returns the ``str()`` of
            the list of row tuples, e.g. ``"[(1, True), (2, False)]"``.
        """
        if isinstance(queryset, EmptyQuerySet):
            return None
        try:
            # Materialise every row. Stringifying the QuerySet itself goes
            # through its repr(), which Django truncates after a small number
            # of rows, so changes further down a long list would not alter
            # the key (and ETag-based optimistic locking would miss them).
            # This also replaces the separate count() query.
            rows = list(queryset.values_list())
        except EmptyResultSet:
            return None
        if not rows:
            return None
        return force_str(rows)
222 |
223 |
class ListSqlQueryKeyBit(SqlQueryKeyBitBase):
    """SQL of the view's filtered list queryset (None when it cannot match anything)."""

    def get_data(self, params, view_instance, view_method, request, args, kwargs):
        return self._get_queryset_query_string(
            view_instance.filter_queryset(view_instance.get_queryset())
        )
228 |
229 |
class RetrieveSqlQueryKeyBit(SqlQueryKeyBitBase):
    """SQL of the detail query for the object addressed by the URL.

    The lookup value is read from the URL kwarg ``lookup_url_kwarg``, falling
    back to ``lookup_field``. It is filtered on ``lookup_field``. A
    ``ValueError`` while building the queryset yields None.
    """

    def get_data(self, params, view_instance, view_method, request, args, kwargs):
        url_kwarg = view_instance.lookup_url_kwarg or view_instance.lookup_field
        lookup_value = view_instance.kwargs[url_kwarg]
        lookup = {view_instance.lookup_field: lookup_value}
        try:
            base_queryset = view_instance.filter_queryset(view_instance.get_queryset())
            queryset = base_queryset.filter(**lookup)
        except ValueError:
            return None
        return self._get_queryset_query_string(queryset)
242 |
243 |
class RetrieveModelKeyBit(ModelInstanceKeyBitBase):
    """
    A key bit reflecting the contents of the model instance.
    Return example:
        u"[(3, False)]"

    The lookup value is read from the URL kwarg ``lookup_url_kwarg`` (falling
    back to ``lookup_field``), consistent with ``RetrieveSqlQueryKeyBit``.
    Reading ``kwargs[lookup_field]`` raised KeyError for views whose URL
    kwarg differs from the model field. A ``ValueError`` while building the
    queryset yields None.
    """

    def get_data(self, params, view_instance, view_method, request, args, kwargs):
        url_kwarg = getattr(view_instance, 'lookup_url_kwarg', None) or view_instance.lookup_field
        lookup_value = view_instance.kwargs[url_kwarg]
        try:
            queryset = view_instance.filter_queryset(view_instance.get_queryset()).filter(
                **{view_instance.lookup_field: lookup_value}
            )
        except ValueError:
            return None
        return self._get_queryset_query_values(queryset)
261 |
262 |
class ListModelKeyBit(ModelInstanceKeyBitBase):
    """
    A key bit reflecting the contents of a list of model instances
    (the view's filtered queryset).

    Return example:
        u"[(1, True), (2, True), (3, False)]"
    """

    def get_data(self, params, view_instance, view_method, request, args, kwargs):
        return self._get_queryset_query_values(
            view_instance.filter_queryset(view_instance.get_queryset())
        )
273 |
274 |
class ArgsKeyBit(AllArgsMixin, KeyBitBase):
    """Positional view arguments.

    ``params`` is ``'*'`` (the default) for all of them, an iterable of
    indexes for a selection, or None for none.
    """

    def get_data(self, params, view_instance, view_method, request, args, kwargs):
        if params is None:
            return []
        if params == '*':
            return args
        return [args[index] for index in params]
284 |
285 |
class KwargsKeyBit(AllArgsMixin, KeyBitDictBase):
    """Keyword view arguments, picked with ``KeyBitDictBase`` semantics.

    ``AllArgsMixin`` makes ``params`` default to ``'*'`` (every kwarg).
    """

    def get_source_dict(self, params, view_instance, view_method, request, args, kwargs):
        return kwargs
290 |
--------------------------------------------------------------------------------
/rest_framework_extensions/key_constructor/constructors.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import json
3 |
4 | from rest_framework_extensions.key_constructor import bits
5 | from rest_framework_extensions.settings import extensions_api_settings
6 |
7 |
class KeyConstructor:
    """Build cache/ETag keys from the key bits declared on the class.

    Every class attribute that is a ``bits.KeyBitBase`` instance contributes
    ``{attribute_name: bit.get_data(...)}`` to a dict. That dict is
    JSON-encoded with sorted keys and hashed with MD5.

    :param memoize_for_request: cache computed keys on the request object;
        None means ``DEFAULT_KEY_CONSTRUCTOR_MEMOIZE_FOR_REQUEST``.
    :param params: per-bit ``params`` overrides keyed by bit attribute name.
    """

    def __init__(self, memoize_for_request=None, params=None):
        if memoize_for_request is None:
            self.memoize_for_request = extensions_api_settings.DEFAULT_KEY_CONSTRUCTOR_MEMOIZE_FOR_REQUEST
        else:
            self.memoize_for_request = memoize_for_request
        self.params = {} if params is None else params
        self.bits = self.get_bits()

    def get_bits(self):
        """Return ``{attribute name: bit}`` for every key bit on the class."""
        _bits = {}
        for attr in dir(self.__class__):
            attr_value = getattr(self.__class__, attr)
            if isinstance(attr_value, bits.KeyBitBase):
                _bits[attr] = attr_value
        return _bits

    def __call__(self, **kwargs):
        return self.get_key(**kwargs)

    def get_key(self, view_instance, view_method, request, args, kwargs):
        """Return the key, memoised on ``request`` when ``memoize_for_request`` is set."""
        if not self.memoize_for_request:
            return self._get_key(
                view_instance=view_instance,
                view_method=view_method,
                request=request,
                args=args,
                kwargs=kwargs
            )

        memoization_key = self._get_memoization_key(
            view_instance=view_instance,
            view_method=view_method,
            args=args,
            kwargs=kwargs
        )
        if not hasattr(request, '_key_constructor_cache'):
            request._key_constructor_cache = {}
        cache = request._key_constructor_cache
        if memoization_key not in cache:
            cache[memoization_key] = self._get_key(
                view_instance=view_instance,
                view_method=view_method,
                request=request,
                args=args,
                kwargs=kwargs
            )
        return cache[memoization_key]

    def _get_memoization_key(self, view_instance, view_method, args, kwargs):
        from rest_framework_extensions.utils import get_unique_method_id
        # default=str: URL kwargs/args may be non-JSON types (e.g. UUIDs from
        # <uuid:...> path converters); JSON-native values encode unchanged.
        return json.dumps({
            'unique_method_id': get_unique_method_id(view_instance=view_instance, view_method=view_method),
            'args': args,
            'kwargs': kwargs,
            'instance_id': id(self)
        }, default=str)

    def _get_key(self, view_instance, view_method, request, args, kwargs):
        _kwargs = {
            'view_instance': view_instance,
            'view_method': view_method,
            'request': request,
            'args': args,
            'kwargs': kwargs,
        }
        return self.prepare_key(
            self.get_data_from_bits(**_kwargs)
        )

    def prepare_key(self, key_dict):
        """Hash ``key_dict`` into a hex MD5 digest.

        Values that JSON cannot encode natively are encoded via ``str()``. The
        digest for JSON-native data is unchanged.
        """
        serialized = json.dumps(key_dict, sort_keys=True, default=str)
        return hashlib.md5(serialized.encode('utf-8')).hexdigest()

    def get_data_from_bits(self, **kwargs):
        """Collect ``{bit name: bit data}``, honouring per-constructor params overrides."""
        result_dict = {}
        for bit_name, bit_instance in self.bits.items():
            if bit_name in self.params:
                params = self.params[bit_name]
            else:
                params = getattr(bit_instance, 'params', None)
            result_dict[bit_name] = bit_instance.get_data(
                params=params, **kwargs)
        return result_dict
92 |
93 |
class DefaultKeyConstructor(KeyConstructor):
    """Key from the view method's identity, the response format and the active language."""
    unique_method_id = bits.UniqueMethodIdKeyBit()
    format = bits.FormatKeyBit()
    language = bits.LanguageKeyBit()


class DefaultObjectKeyConstructor(DefaultKeyConstructor):
    """``DefaultKeyConstructor`` plus the SQL of the detail (single object) query."""
    retrieve_sql_query = bits.RetrieveSqlQueryKeyBit()


class DefaultListKeyConstructor(DefaultKeyConstructor):
    """``DefaultKeyConstructor`` plus the SQL of the list query and the pagination parameters."""
    list_sql_query = bits.ListSqlQueryKeyBit()
    pagination = bits.PaginationKeyBit()


class DefaultAPIModelInstanceKeyConstructor(KeyConstructor):
    """
    Use this constructor when the values of the model instance are required
    to identify the resource.
    """
    retrieve_model_values = bits.RetrieveModelKeyBit()


class DefaultAPIModelListKeyConstructor(KeyConstructor):
    """
    Use this constructor when the values of the model instances are required
    to identify many resources.
    """
    list_model_values = bits.ListModelKeyBit()
123 |
--------------------------------------------------------------------------------
/rest_framework_extensions/mixins.py:
--------------------------------------------------------------------------------
1 | from rest_framework_extensions.cache.mixins import CacheResponseMixin
2 | # from rest_framework_extensions.etag.mixins import ReadOnlyETAGMixin, ETAGMixin
3 | from rest_framework_extensions.bulk_operations.mixins import ListUpdateModelMixin, ListDestroyModelMixin
4 | from rest_framework_extensions.settings import extensions_api_settings
5 | from django.core.exceptions import ValidationError
6 | from django.http import Http404
7 | import uuid
8 |
class DetailSerializerMixin:
    """
    Serve a dedicated serializer (and, optionally, queryset) on detail routes.

    A request counts as a detail request when the view's lookup URL kwarg
    (``lookup_url_kwarg``, falling back to ``lookup_field``) is present in
    ``self.kwargs``.
    """
    serializer_detail_class = None
    queryset_detail = None

    def get_serializer_class(self):
        assert self.serializer_detail_class is not None, (
            "'{0}' should include a 'serializer_detail_class' attribute".format(
                self.__class__.__name__)
        )
        if not self._is_request_to_detail_endpoint():
            return super().get_serializer_class()
        return self.serializer_detail_class

    def get_queryset(self, *args, **kwargs):
        is_detail = self._is_request_to_detail_endpoint()
        if not is_detail or self.queryset_detail is None:
            return super().get_queryset(*args, **kwargs)
        # .all() returns a fresh queryset instead of the shared class attribute.
        return self.queryset_detail.all()  # todo: test all()

    def _is_request_to_detail_endpoint(self):
        if not hasattr(self, 'lookup_url_kwarg'):
            return None
        lookup = self.lookup_url_kwarg or self.lookup_field
        return lookup and lookup in self.kwargs
35 |
36 |
class PaginateByMaxMixin:
    """Paginator mixin: ``?<page_size_query_param>=max`` selects ``max_page_size``."""

    def get_page_size(self, request):
        wants_max = (
            self.page_size_query_param
            and self.max_page_size
            and request.query_params.get(self.page_size_query_param) == 'max'
        )
        if not wants_max:
            return super().get_page_size(request)
        return self.max_page_size
43 |
44 |
45 | # class ReadOnlyCacheResponseAndETAGMixin(ReadOnlyETAGMixin, CacheResponseMixin):
46 | # pass
47 |
48 |
49 | # class CacheResponseAndETAGMixin(ETAGMixin, CacheResponseMixin):
50 | # pass
51 |
52 |
class NestedViewSetMixin:
    """Restrict a viewset's queryset to the parent objects captured in the URL.

    URL kwargs starting with ``DEFAULT_PARENT_LOOKUP_KWARG_NAME_PREFIX``
    (``'parent_lookup_'`` by default) become queryset filters, with the prefix
    removed.
    """

    def get_queryset(self):
        return self.filter_queryset_by_parents_lookups(
            super().get_queryset()
        )

    def filter_queryset_by_parents_lookups(self, queryset):
        """Filter ``queryset`` by the parent lookups; invalid values raise ``Http404``.

        Value validation is left to the model fields. Django raises
        ``ValidationError`` (e.g. a malformed UUID for a UUIDField) or
        ``ValueError`` (e.g. a non-numeric id) while building the filter.
        Parsing values as UUIDs based on the lookup *name* (containing "uuid"
        or ending in "_code") wrongly turned valid lookups such as
        ``country_code=US`` into 404s.
        """
        parents_query_dict = self.get_parents_query_dict()
        if not parents_query_dict:
            return queryset
        try:
            return queryset.filter(**parents_query_dict)
        except (ValueError, ValidationError):
            raise Http404

    def get_parents_query_dict(self):
        """Return ``{lookup: value}`` for every URL kwarg carrying the parent-lookup prefix."""
        prefix = extensions_api_settings.DEFAULT_PARENT_LOOKUP_KWARG_NAME_PREFIX
        return {
            kwarg_name[len(prefix):]: kwarg_value
            for kwarg_name, kwarg_value in self.kwargs.items()
            if kwarg_name.startswith(prefix)
        }
92 |
--------------------------------------------------------------------------------
/rest_framework_extensions/permissions.py:
--------------------------------------------------------------------------------
1 | from rest_framework.permissions import DjangoObjectPermissions
2 |
3 |
class ExtendedDjangoObjectPermissions(DjangoObjectPermissions):
    """``DjangoObjectPermissions`` with an option to stop hiding forbidden objects.

    With ``hide_forbidden_for_read_objects = True`` (the default) the parent
    behaviour is used unchanged. Otherwise the check is a plain
    ``user.has_perms(perms, obj)`` for the permissions required by the
    request method.
    """
    hide_forbidden_for_read_objects = True

    def has_object_permission(self, request, view, obj):
        if self.hide_forbidden_for_read_objects:
            return super().has_object_permission(request, view, obj)

        model_cls = getattr(view, 'model', None)
        if model_cls is None:
            queryset = getattr(view, 'queryset', None)
            if queryset is None and hasattr(view, 'get_queryset'):
                # Views that only override get_queryset() have no queryset
                # attribute; without this, model_cls stayed None and
                # get_required_object_permissions() failed.
                queryset = view.get_queryset()
            if queryset is not None:
                model_cls = queryset.model

        perms = self.get_required_object_permissions(
            request.method, model_cls)
        user = request.user

        return user.has_perms(perms, obj)
22 |
--------------------------------------------------------------------------------
/rest_framework_extensions/routers.py:
--------------------------------------------------------------------------------
1 | from rest_framework.routers import DefaultRouter, SimpleRouter
2 | from rest_framework_extensions.utils import compose_parent_pk_kwarg_name
3 |
4 |
class NestedRegistryItem:
    """Handle returned by a nested registration, used to register child routes.

    Each item remembers the prefix and viewset of the route it was created for
    and its parent item. Child routes are prefixed with the whole chain of
    parent URL segments.
    """

    def __init__(self, router, parent_prefix, parent_item=None, parent_viewset=None):
        self.router = router
        self.parent_prefix = parent_prefix
        self.parent_item = parent_item
        self.parent_viewset = parent_viewset

    def register(self, prefix, viewset, basename, parents_query_lookups):
        """Register ``viewset`` below this item and return a handle for further nesting.

        ``parents_query_lookups`` needs one lookup per ancestor level, ordered
        from the outermost parent to the direct parent.
        """
        self.router._register(
            prefix=self.get_prefix(
                current_prefix=prefix,
                parents_query_lookups=parents_query_lookups),
            viewset=viewset,
            basename=basename,
        )
        return NestedRegistryItem(
            router=self.router,
            parent_prefix=prefix,
            parent_item=self,
            parent_viewset=viewset
        )

    def get_prefix(self, current_prefix, parents_query_lookups):
        return '{0}/{1}'.format(
            self.get_parent_prefix(parents_query_lookups),
            current_prefix
        )

    def get_parent_prefix(self, parents_query_lookups):
        """Build the URL regex prefix made of every ancestor's segment and lookup group.

        :raises ValueError: if fewer lookups than ancestor levels are given.
            Negative indexing would otherwise reuse a lookup and produce
            duplicate regex group names.
        """
        ancestors = []
        item = self
        while item:
            ancestors.append(item)
            item = item.parent_item
        if len(parents_query_lookups) < len(ancestors):
            raise ValueError(
                'parents_query_lookups needs one entry per parent level: '
                'expected at least {0}, got {1}.'.format(
                    len(ancestors), len(parents_query_lookups))
            )

        prefix = '/'
        i = len(parents_query_lookups) - 1
        for current_item in ancestors:
            parent_lookup_value_regex = getattr(
                current_item.parent_viewset, 'lookup_value_regex', '[^/.]+')
            prefix = '{parent_prefix}/(?P<{parent_pk_kwarg_name}>{parent_lookup_value_regex})/{prefix}'.format(
                parent_prefix=current_item.parent_prefix,
                parent_pk_kwarg_name=compose_parent_pk_kwarg_name(
                    parents_query_lookups[i]),
                parent_lookup_value_regex=parent_lookup_value_regex,
                prefix=prefix
            )
            i -= 1
        return prefix.strip('/')
50 |
51 |
class NestedRouterMixin:
    """Router mixin whose ``register`` returns a handle for registering nested routes."""

    def _register(self, *args, **kwargs):
        # Plain registration on the underlying DRF router.
        return super().register(*args, **kwargs)

    def register(self, *args, **kwargs):
        self._register(*args, **kwargs)
        latest = self.registry[-1]
        return NestedRegistryItem(
            router=self,
            parent_prefix=latest[0],
            parent_viewset=latest[1]
        )
63 |
64 |
class ExtendedRouterMixin(NestedRouterMixin):
    """Bundle of this package's router extensions (currently nested-route registration)."""
    pass


class ExtendedSimpleRouter(ExtendedRouterMixin, SimpleRouter):
    """DRF ``SimpleRouter`` with nested-route registration."""
    pass


class ExtendedDefaultRouter(ExtendedRouterMixin, DefaultRouter):
    """DRF ``DefaultRouter`` with nested-route registration."""
    pass
75 |
--------------------------------------------------------------------------------
/rest_framework_extensions/serializers.py:
--------------------------------------------------------------------------------
1 | from rest_framework_extensions.utils import get_model_opts_concrete_fields
2 |
3 |
def get_fields_for_partial_update(opts, init_data, fields, init_files=None):
    """Return the sorted, de-duplicated concrete model fields touched by a partial update.

    :param opts: an object whose ``.model`` is the serializer's model (the
        serializer ``Meta`` or, in recursion, the model's ``_meta``).
    :param init_data: submitted data; its keys are the candidate field names.
    :param fields: serializer fields by name; each field's ``source`` maps it
        to a model attribute.
    :param init_files: submitted files, treated like ``init_data`` keys.
    """
    opts = opts.model._meta.concrete_model._meta
    partial_fields = [*(init_data or {}).keys(), *(init_files or {}).keys()]

    # Names (and attnames, e.g. "<fk>_id") of all non-primary-key concrete fields.
    concrete_field_names = []
    for model_field in get_model_opts_concrete_fields(opts):
        if model_field.primary_key:
            continue
        concrete_field_names.append(model_field.name)
        if model_field.attname != model_field.name:
            concrete_field_names.append(model_field.attname)

    update_fields = []
    for field_name in partial_fields:
        if field_name not in fields:
            continue
        model_field_name = fields[field_name].source or field_name
        if model_field_name in concrete_field_names:
            update_fields.append(model_field_name)

    # Nested serializers bound to the same instance (source='*') carry their
    # own field names; recurse into them with the same model options.
    for key, value in (init_data or {}).items():
        if isinstance(value, dict) and key in fields and fields[key].source == '*':
            update_fields.extend(
                get_fields_for_partial_update(opts, value, fields[key].fields.fields))

    return sorted(set(update_fields))
30 |
31 |
class PartialUpdateSerializerMixin:
    """Serializer mixin that saves only the changed concrete fields on partial updates.

    ``save(update_fields=[...])`` is honoured as given. Otherwise the fields
    are derived from the submitted data via ``get_fields_for_partial_update``.
    """

    def save(self, **kwargs):
        # Remember explicitly requested update_fields for update().
        self._update_fields = kwargs.get('update_fields', None)
        return super().save(**kwargs)

    def update(self, instance, validated_attrs):
        for attr, value in validated_attrs.items():
            if hasattr(getattr(instance, attr, None), 'set'):
                # Values exposing .set() (presumably related managers such as
                # many-to-many) are assigned through it.
                getattr(instance, attr).set(value)
            else:
                setattr(instance, attr, value)
        if self.partial and isinstance(instance, self.Meta.model):
            # update() may be called directly, without save(); then
            # _update_fields was never set and must not raise AttributeError.
            update_fields = getattr(self, '_update_fields', None) or get_fields_for_partial_update(
                opts=self.Meta,
                init_data=self.get_initial(),
                fields=self.fields.fields
            )
            instance.save(update_fields=update_fields)
        else:
            instance.save()
        return instance
54 |
--------------------------------------------------------------------------------
/rest_framework_extensions/settings.py:
--------------------------------------------------------------------------------
1 | from django.conf import settings
2 |
3 | from rest_framework.settings import APISettings
4 |
5 |
# User overrides from Django's ``settings.REST_FRAMEWORK_EXTENSIONS`` (None when absent).
USER_SETTINGS = getattr(settings, 'REST_FRAMEWORK_EXTENSIONS', None)

# Default value for every supported setting; user overrides take precedence.
DEFAULTS = {
    # caching
    'DEFAULT_USE_CACHE': 'default',
    'DEFAULT_CACHE_RESPONSE_TIMEOUT': None,
    'DEFAULT_CACHE_ERRORS': True,
    'DEFAULT_CACHE_KEY_FUNC': 'rest_framework_extensions.utils.default_cache_key_func',
    'DEFAULT_OBJECT_CACHE_KEY_FUNC': 'rest_framework_extensions.utils.default_object_cache_key_func',
    'DEFAULT_LIST_CACHE_KEY_FUNC': 'rest_framework_extensions.utils.default_list_cache_key_func',

    # ETAG
    'DEFAULT_ETAG_FUNC': 'rest_framework_extensions.utils.default_etag_func',
    'DEFAULT_OBJECT_ETAG_FUNC': 'rest_framework_extensions.utils.default_object_etag_func',
    'DEFAULT_LIST_ETAG_FUNC': 'rest_framework_extensions.utils.default_list_etag_func',

    # API - ETAG
    'DEFAULT_API_OBJECT_ETAG_FUNC': 'rest_framework_extensions.utils.default_api_object_etag_func',
    'DEFAULT_API_LIST_ETAG_FUNC': 'rest_framework_extensions.utils.default_api_list_etag_func',

    # other
    'DEFAULT_KEY_CONSTRUCTOR_MEMOIZE_FOR_REQUEST': False,
    'DEFAULT_BULK_OPERATION_HEADER_NAME': 'X-BULK-OPERATION',
    'DEFAULT_PARENT_LOOKUP_KWARG_NAME_PREFIX': 'parent_lookup_'
}

# Settings whose values are dotted import paths; presumably resolved to the
# referenced objects by DRF's APISettings when accessed (TODO confirm per DRF version).
IMPORT_STRINGS = [
    'DEFAULT_CACHE_KEY_FUNC',
    'DEFAULT_OBJECT_CACHE_KEY_FUNC',
    'DEFAULT_LIST_CACHE_KEY_FUNC',
    'DEFAULT_ETAG_FUNC',
    'DEFAULT_OBJECT_ETAG_FUNC',
    'DEFAULT_LIST_ETAG_FUNC',
    # API - ETAG
    'DEFAULT_API_OBJECT_ETAG_FUNC',
    'DEFAULT_API_LIST_ETAG_FUNC',
]


# Package-wide settings object used as ``extensions_api_settings.<NAME>``.
extensions_api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS)
46 |
--------------------------------------------------------------------------------
/rest_framework_extensions/test.py:
--------------------------------------------------------------------------------
1 | # Leaving this module here for backwards compatibility but this is just proxy
2 | # for rest_framework.test
3 |
4 | import warnings
5 |
6 | from rest_framework.test import (
7 | force_authenticate,
8 | APIRequestFactory,
9 | ForceAuthClientHandler,
10 | APIClient,
11 | APITransactionTestCase,
12 | APITestCase
13 | )
14 |
15 |
# Public names re-exported from rest_framework.test.
# Each entry is a separate string: previously the commas sat inside the
# quotes, so implicit concatenation produced one string (not a tuple) and
# ``from rest_framework_extensions.test import *`` failed.
__all__ = (
    'force_authenticate',
    'APIRequestFactory',
    'ForceAuthClientHandler',
    'APIClient',
    'APITransactionTestCase',
    'APITestCase',
)

# Emitted once, when this compatibility module is first imported.
warnings.warn(
    'Use of this module is deprecated! Use rest_framework.test instead.',
    DeprecationWarning
)
29 |
--------------------------------------------------------------------------------
/rest_framework_extensions/utils.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from packaging.version import Version
3 |
4 | import rest_framework
5 |
6 | from rest_framework_extensions.key_constructor.constructors import (
7 | DefaultKeyConstructor,
8 | DefaultObjectKeyConstructor,
9 | DefaultListKeyConstructor,
10 | DefaultAPIModelInstanceKeyConstructor,
11 | DefaultAPIModelListKeyConstructor
12 | )
13 | from rest_framework_extensions.settings import extensions_api_settings
14 |
15 |
def get_rest_framework_version():
    """Return the release segment of the installed DRF version, e.g. ``(3, 14, 0)``."""
    parsed_version = Version(rest_framework.VERSION)
    return parsed_version.release
18 |
19 |
def flatten(list_of_lists):
    """Chain an iterable of iterables into one lazy iterator over their items.

    The outer iterable is consumed lazily (``chain.from_iterable``), so it may
    be a generator, including an unbounded one. ``itertools.chain(*x)``
    would exhaust it up front.

    >>> list(flatten([[1, 2], (3,), []]))
    [1, 2, 3]
    """
    return itertools.chain.from_iterable(list_of_lists)
27 |
28 |
def prepare_header_name(name):
    """Convert an HTTP header name to the key Django uses in ``request.META``.

    Surrounding whitespace is stripped, dashes become underscores, the
    ``HTTP_`` prefix is added and the whole key is upper-cased.

    >>> prepare_header_name('Accept-Language')
    'HTTP_ACCEPT_LANGUAGE'
    """
    return 'HTTP_{0}'.format(name.strip().replace('-', '_').upper())
35 |
36 |
def get_unique_method_id(view_instance, view_method):
    """Build a dotted ``module.ClassName.method`` identifier for a view method.

    The module comes from ``view_instance.__module__``, the class name from
    ``view_instance.__class__`` and the last part from ``view_method.__name__``.
    """
    # todo: test me as UniqueMethodIdKeyBit
    module_name = view_instance.__module__
    class_name = view_instance.__class__.__name__
    method_name = view_method.__name__
    return '.'.join((module_name, class_name, method_name))
44 |
45 |
def get_model_opts_concrete_fields(opts):
    """Return ``opts.concrete_fields``, computing and caching it when missing.

    When ``opts`` has no ``concrete_fields`` attribute, it is built from
    ``opts.fields`` (keeping fields whose ``column`` is not None), stored on
    ``opts`` and then returned.
    """
    # todo: test me
    try:
        return opts.concrete_fields
    except AttributeError:
        pass
    opts.concrete_fields = [
        field for field in opts.fields if field.column is not None
    ]
    return opts.concrete_fields
51 |
52 |
def compose_parent_pk_kwarg_name(value):
    """Prefix ``value`` with ``DEFAULT_PARENT_LOOKUP_KWARG_NAME_PREFIX``.

    With the default prefix ``'parent_lookup_'``, ``'user'`` becomes
    ``'parent_lookup_user'``.
    """
    prefix = extensions_api_settings.DEFAULT_PARENT_LOOKUP_KWARG_NAME_PREFIX
    return f'{prefix}{value}'
58 |
59 |
# Key constructors used as the default cache-key functions.
(
    default_cache_key_func,
    default_object_cache_key_func,
    default_list_cache_key_func,
) = (
    DefaultKeyConstructor(),
    DefaultObjectKeyConstructor(),
    DefaultListKeyConstructor(),
)

# The default ETag functions reuse the very same constructor objects.
default_etag_func, default_object_etag_func, default_list_etag_func = (
    default_cache_key_func,
    default_object_cache_key_func,
    default_list_cache_key_func,
)

# API (object-centered) ETag functions.
default_api_object_etag_func, default_api_list_etag_func = (
    DefaultAPIModelInstanceKeyConstructor(),
    DefaultAPIModelListKeyConstructor(),
)
71 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
[bdist_wheel]
# The package targets Python 3 only (see the classifiers in setup.py), so
# wheels must not be tagged as universal py2.py3 wheels.
universal = 0
3 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | from setuptools import setup
3 | import re
4 | import os
5 | import sys
6 |
7 |
def get_version(package):
    """
    Return package version as listed in `__version__` in `__init__.py`.

    The file is read as UTF-8 and closed afterwards. The ``__version__``
    assignment may be on any line, not only the first one.

    Raises RuntimeError if no ``__version__`` assignment is found.
    """
    init_path = os.path.join(package, '__init__.py')
    with open(init_path, encoding='utf-8') as init_file:
        init_py = init_file.read()
    match = re.search(r"^__version__ = ['\"]([^'\"]+)['\"]", init_py, re.MULTILINE)
    if match is None:
        raise RuntimeError('Unable to find __version__ in {0}'.format(init_path))
    return match.group(1)
14 |
15 |
def get_packages(package):
    """
    Return the root package directory plus every sub-directory below it
    that contains an ``__init__.py``, in ``os.walk`` order.
    """
    packages = []
    for dirpath, _dirnames, _filenames in os.walk(package):
        if os.path.exists(os.path.join(dirpath, '__init__.py')):
            packages.append(dirpath)
    return packages
23 |
24 |
def get_package_data(package):
    """
    Collect the files living under ``package`` in directories that are not
    packages themselves (no ``__init__.py``).

    Paths are made relative to ``package`` and returned as
    ``{package: [relative paths]}``, ready for setuptools' ``package_data``.
    """
    prefix = package + os.sep
    filepaths = []
    for dirpath, _dirnames, filenames in os.walk(package):
        if os.path.exists(os.path.join(dirpath, '__init__.py')):
            continue
        base = dirpath.replace(prefix, '', 1)
        filepaths.extend(os.path.join(base, filename) for filename in filenames)
    return {package: filepaths}
39 |
40 |
version = get_version('rest_framework_extensions')


if sys.argv[-1] == 'publish':
    # NOTE(review): recent setuptools releases no longer support
    # ``setup.py upload``. Building with ``python -m build`` and uploading with
    # ``twine`` is the usual replacement; confirm before relying on this branch.
    os.system("python setup.py sdist upload")
    os.system("python setup.py bdist_wheel upload")
    print("You probably want to also tag the version now:")
    print(" git tag -a %s -m 'version %s'" % (version, version))
    print(" git push --tags")
    sys.exit()


setup(
    name='drf-extensions',
    version=version,
    url='https://github.com/chibisov/drf-extensions',
    download_url='https://pypi.python.org/pypi/drf-extensions/',
    # NOTE(review): ``license`` says BSD but the classifier below says MIT;
    # check the LICENSE file and make the two agree.
    license='BSD',
    install_requires=['djangorestframework>=3.10.3', 'packaging>=24.1', 'Django>=2.2,<6.0'],
    description='Extensions for Django REST Framework',
    long_description='DRF-extensions is a collection of custom extensions for Django REST Framework',
    author='Asif Saif Uddin, Gennady Chibisov',
    author_email='auvipy@gmail.com',
    packages=get_packages('rest_framework_extensions'),
    package_data=get_package_data('rest_framework_extensions'),
    # NOTE(review): no ``runtests`` module is visible in the package tree and
    # ``test_suite`` is deprecated in setuptools; consider removing it.
    test_suite='rest_framework_extensions.runtests.runtests.main',
    classifiers=[
        'Development Status :: 5 - Production/Stable',
        'Environment :: Web Environment',
        'Framework :: Django',
        'Intended Audience :: Developers',
        'License :: OSI Approved :: MIT License',
        'Operating System :: OS Independent',
        'Framework :: Django :: 3.2',
        'Framework :: Django :: 4.2',
        'Framework :: Django :: 5.2',
        'Programming Language :: Python',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.9',
        'Programming Language :: Python :: 3.10',
        'Programming Language :: Python :: 3.11',
        'Programming Language :: Python :: 3.12',
        'Topic :: Internet :: WWW/HTTP',
    ]
)
87 |
--------------------------------------------------------------------------------
/tests_app/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/__init__.py
--------------------------------------------------------------------------------
/tests_app/plugins.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import os
3 |
4 | from django_nose.plugin import AlwaysOnPlugin
5 |
6 | from django.test import TestCase
7 | from django.core.cache import cache
8 | from django.conf import settings
9 |
10 |
class UnitTestDiscoveryPlugin(AlwaysOnPlugin):
    """
    Restrict nose to unittest-style discovery: collect only ``TestCase``
    subclasses and methods whose name starts with "test" (case-insensitive),
    and never collect plain test functions.
    """
    enabled = True

    # In nose's selection hooks, False rejects a candidate, while None means
    # "no opinion" and leaves the decision to the default selector.

    def wantModule(self, module):
        return True

    def wantFile(self, file):
        return True if file.endswith('.py') else None

    def wantClass(self, cls):
        return None if issubclass(cls, TestCase) else False

    def wantMethod(self, method):
        looks_like_test = method.__name__.lower().startswith('test')
        return None if looks_like_test else False

    def wantFunction(self, function):
        return False
35 |
36 |
class PrepareRestFrameworkSettingsPlugin(AlwaysOnPlugin):
    # When the test run begins, add test-client defaults to DRF's DEFAULTS.
    def begin(self):
        self._monkeypatch_default_settings()

    def _monkeypatch_default_settings(self):
        # Imported inside the method, i.e. only when the run actually begins.
        from rest_framework import settings

        patch = {
            # Testing
            'TEST_REQUEST_RENDERER_CLASSES': (
                'rest_framework.renderers.MultiPartRenderer',
                'rest_framework.renderers.JSONRenderer'
            ),
            'TEST_REQUEST_DEFAULT_FORMAT': 'multipart',
        }

        # Keys DRF already defines win; only missing keys are added.
        for key, value in patch.items():
            settings.DEFAULTS.setdefault(key, value)
56 |
57 |
class PrepareFileStorageDir(AlwaysOnPlugin):
    # Creates settings.MEDIA_ROOT before the run and removes it afterwards.

    def begin(self):
        # exist_ok=True replaces the isdir()-then-makedirs() pair, which could
        # race with another process creating the directory in between. A
        # non-directory at that path still raises FileExistsError, as before.
        os.makedirs(settings.MEDIA_ROOT, exist_ok=True)

    def finalize(self, result):
        # Best-effort cleanup: errors while deleting are ignored.
        shutil.rmtree(settings.MEDIA_ROOT, ignore_errors=True)
65 |
66 |
class FlushCache(AlwaysOnPlugin):
    # Clears the Django cache before every django.test.TestCase run.
    # TestCase.run is patched because hooking startTest did not work.

    def begin(self):
        self._monkeypatch_testcase()

    def _monkeypatch_testcase(self):
        original_run = TestCase.run

        def new_run(*args, **kwargs):
            cache.clear()
            return original_run(*args, **kwargs)

        # Each call wraps the current TestCase.run once more.
        TestCase.run = new_run
79 |
--------------------------------------------------------------------------------
/tests_app/requirements.txt:
--------------------------------------------------------------------------------
pynose
django-nose
django-filter>=2.1.0
# Required by tests_app/settings.py ('guardian' app and ObjectPermissionBackend).
django-guardian
mock
ipdb
--------------------------------------------------------------------------------
/tests_app/settings.py:
--------------------------------------------------------------------------------
1 | # Django settings for testproject project.
2 | import multiprocessing
3 |
DEBUG = True
# Let exceptions raised in views reach the test client instead of being
# turned into 500 responses.
DEBUG_PROPAGATE_EXCEPTIONS = True

ALLOWED_HOSTS = ['*']

ADMINS = (
    # ('Your Name', 'your_email@domain.com'),
)

MANAGERS = ADMINS

DATABASES = {
    'default': {
        'ENGINE': 'django.db.backends.sqlite3',
        'NAME': 'drf_extensions',
        # Test-database options belong under the nested 'TEST' key. The flat
        # 'TEST_CHARSET' key used previously is no longer read by Django.
        'TEST': {
            'CHARSET': 'utf8',
        },
    },
}
22 |
# Three independent in-process caches, all using the local-memory backend.
# NOTE(review): the non-default aliases are presumably selected explicitly by
# cache tests; confirm before removing any of them.
CACHES = {
    alias: {'BACKEND': 'django.core.cache.backends.locmem.LocMemCache'}
    for alias in ('default', 'special_cache', 'another_special_cache')
}

# Time zone for this test project (names as in the tz database, see
# https://en.wikipedia.org/wiki/List_of_tz_database_time_zones).
TIME_ZONE = 'Europe/London'

# Default language of the project.
# NOTE(review): 'uk' is the code for Ukrainian; British English is 'en-gb'.
# Django presumably falls back to the generic 'en' for 'en-uk'. Switching may
# change translated messages that tests assert on, so verify first.
LANGUAGE_CODE = 'en-uk'

SITE_ID = 1

# Keep Django's translation machinery enabled.
USE_I18N = True

# Locale-aware formatting of dates and numbers.
# NOTE(review): USE_L10N is deprecated since Django 4.0 and removed in 5.0.
# It is kept because setup.py still allows Django < 4.0, where it defaults to False.
USE_L10N = True
57 |
# Directory (relative to the working directory) for uploaded test files.
MEDIA_ROOT = 'tests_app/tests/files'

# Public URL prefix for MEDIA_ROOT; empty in the test project.
MEDIA_URL = ''

# Test-only secret key; never reuse it in a real deployment.
SECRET_KEY = 'u@x-aj9(hoh#rb-^ymf#g2jx_hp0vj7u5#b@ag1n^seu9e!%cy'

# Django template engine configured with the standard context processors.
TEMPLATES = [
    dict(
        BACKEND='django.template.backends.django.DjangoTemplates',
        DIRS=[],
        APP_DIRS=True,
        OPTIONS=dict(
            context_processors=[
                'django.template.context_processors.debug',
                'django.template.context_processors.request',
                'django.contrib.auth.context_processors.auth',
                'django.contrib.messages.context_processors.messages',
            ],
        ),
    ),
]

MIDDLEWARE = [
    'django.middleware.security.SecurityMiddleware',
    'django.contrib.sessions.middleware.SessionMiddleware',
    'django.middleware.common.CommonMiddleware',
    'django.middleware.csrf.CsrfViewMiddleware',
    'django.contrib.auth.middleware.AuthenticationMiddleware',
    'django.contrib.messages.middleware.MessageMiddleware',
    'django.middleware.clickjacking.XFrameOptionsMiddleware',
]

ROOT_URLCONF = 'urls'
99 |
100 |
# The stale "uncomment to enable the admin" lines were removed:
# 'django.contrib.admin' is already enabled below.
INSTALLED_APPS = (
    'django.contrib.auth',
    'django.contrib.contenttypes',
    'django.contrib.sessions',
    'django.contrib.sites',
    'django.contrib.messages',
    'django.contrib.admin',
    'django_nose',
    # django-guardian: object permissions (see AUTHENTICATION_BACKENDS).
    'guardian',
    'rest_framework_extensions',

    # Test packages registered as apps (the functional one holds migrations).
    'tests_app.tests.functional',
    'tests_app.tests.unit',
)

STATIC_URL = '/static/'
121 |
# Password validation
# https://docs.djangoproject.com/en/stable/ref/settings/#auth-password-validators

AUTH_PASSWORD_VALIDATORS = [
    {
        'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
    },
    {
        'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
    },
    {
        'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
    },
    {
        'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
    },
]

AUTH_USER_MODEL = 'auth.User'
141 |
142 |
TEST_RUNNER = 'django.test.runner.DiscoverRunner'

# django-nose options.
# NOTE(review): TEST_RUNNER above is Django's DiscoverRunner, so NOSE_ARGS and
# NOSE_PLUGINS presumably only take effect when the nose runner is selected
# explicitly (e.g. via --testrunner); confirm how CI invokes the tests.
NOSE_ARGS = [
    f'--processes={multiprocessing.cpu_count()}',
    '--process-timeout=100',
    '--nocapture',
]

# Plugin classes defined in tests_app/plugins.py.
NOSE_PLUGINS = [
    'plugins.' + plugin_name
    for plugin_name in (
        'UnitTestDiscoveryPlugin',
        'PrepareRestFrameworkSettingsPlugin',
        'FlushCache',
        'PrepareFileStorageDir',
    )
]

# django-guardian
ANONYMOUS_USER_ID = -1

AUTHENTICATION_BACKENDS = (
    'django.contrib.auth.backends.ModelBackend',  # Django's default backend
    'guardian.backends.ObjectPermissionBackend',
)

DEFAULT_AUTO_FIELD = 'django.db.models.AutoField'
167 |
--------------------------------------------------------------------------------
/tests_app/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/tests/__init__.py
--------------------------------------------------------------------------------
/tests_app/tests/functional/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/tests/functional/__init__.py
--------------------------------------------------------------------------------
/tests_app/tests/functional/_concurrency/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/tests/functional/_concurrency/__init__.py
--------------------------------------------------------------------------------
/tests_app/tests/functional/_concurrency/conditional_request/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/tests/functional/_concurrency/conditional_request/__init__.py
--------------------------------------------------------------------------------
/tests_app/tests/functional/_concurrency/conditional_request/models.py:
--------------------------------------------------------------------------------
1 | from django.db import models
2 |
3 |
class Book(models.Model):
    """A sample model for conditional requests."""

    # All three fields are optional: nullable, may be blank and default to None.
    # Declaration order is kept as-is because it fixes the column order and
    # the migrations generated for this model.
    name = models.CharField(max_length=100, default=None, blank=True, null=True)
    author = models.CharField(max_length=100, default=None, blank=True, null=True)
    issn = models.CharField(max_length=100, default=None, blank=True, null=True)
10 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/_concurrency/conditional_request/serializers.py:
--------------------------------------------------------------------------------
1 | from rest_framework import serializers
2 | from .models import Book
3 |
4 |
class BookSerializer(serializers.ModelSerializer):
    # Exposes every field of Book ('__all__'), including the implicit pk.
    class Meta:
        model = Book
        fields = '__all__'
9 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/_concurrency/conditional_request/urls.py:
--------------------------------------------------------------------------------
1 | from django.urls import re_path, include
2 | from rest_framework import routers
3 | from .views import (BookViewSet, BookListCreateView, BookChangeView, BookCustomDestroyView,
4 | BookUnconditionalDestroyView, BookUnconditionalUpdateView)
5 |
router = routers.DefaultRouter()
router.register(r'books', BookViewSet)

# The detail routes capture the book id as ``pk`` (the views read
# ``kwargs['pk']``). Previously the group had no name, ``(?P[0-9]+)``, which
# is not a valid regular expression.
# Every pattern is anchored with ^/$: re_path matches with search semantics,
# so an unanchored 'books_view/' would also match 'books_view/1/' and shadow
# the detail route declared after it.
urlpatterns = [
    # manually add endpoints for APIView instances
    re_path(r'^books_view/(?P<pk>[0-9]+)/custom/delete/$', BookCustomDestroyView.as_view(),
            name='book_view-custom_delete'),
    re_path(r'^books_view/(?P<pk>[0-9]+)/unconditional/delete/$', BookUnconditionalDestroyView.as_view(),
            name='book_view-unconditional_delete'),
    re_path(r'^books_view/(?P<pk>[0-9]+)/unconditional/update/$', BookUnconditionalUpdateView.as_view(),
            name='book_view-unconditional_update'),
    re_path(r'^books_view/$', BookListCreateView.as_view(), name='book_view-list'),
    re_path(r'^books_view/(?P<pk>[0-9]+)/$', BookChangeView.as_view(), name='book_view-detail'),

    # include the URLs from the default viewset
    re_path(r'^', include(router.urls)),
]
21 | ]
22 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/_concurrency/conditional_request/views.py:
--------------------------------------------------------------------------------
1 | from rest_framework import viewsets
2 | from rest_framework import generics
3 | from rest_framework import status
4 | from rest_framework.response import Response
5 | from rest_framework_extensions.etag.mixins import APIETAGMixin
6 | from rest_framework_extensions.etag.decorators import api_etag
7 | from rest_framework_extensions.utils import default_api_object_etag_func
8 | from .models import Book
9 | from .serializers import BookSerializer
10 |
11 |
# APIETAGMixin is listed first so its methods take precedence over
# ModelViewSet's in the MRO.
class BookViewSet(APIETAGMixin,
                  viewsets.ModelViewSet):
    """Test the mixin with DRF viewset."""

    queryset = Book.objects.all()
    serializer_class = BookSerializer
18 |
19 |
# Retrieve/update/destroy a single Book. APIETAGMixin comes first in the MRO.
class BookChangeView(APIETAGMixin,
                     generics.RetrieveUpdateDestroyAPIView):
    """Test the mixin with DRF generic API views."""

    queryset = Book.objects.all()
    serializer_class = BookSerializer
26 |
27 |
# List and create Books. APIETAGMixin comes first in the MRO.
class BookListCreateView(APIETAGMixin,
                         generics.ListCreateAPIView):
    """Test the mixin with DRF generic API views."""

    queryset = Book.objects.all()
    serializer_class = BookSerializer
34 |
35 |
class BookCustomDestroyView(generics.DestroyAPIView):
    """Test the decorator with DRF generic API views."""

    # include the queryset here to enable the object lookup in `@api_etag`
    queryset = Book.objects.all()

    @api_etag(etag_func=default_api_object_etag_func)
    def delete(self, request, *args, **kwargs):
        # get_object() looks the book up through ``queryset`` using the ``pk``
        # URL kwarg and raises Http404 for an unknown pk. The previous
        # Book.objects.get() raised Book.DoesNotExist (a server error) instead.
        self.get_object().delete()
        return Response(status=status.HTTP_204_NO_CONTENT)
47 |
48 |
class BookUnconditionalDestroyView(generics.DestroyAPIView):
    """Test the decorator with DRF generic API views."""

    # include the queryset here to enable the object lookup in `@api_etag`
    queryset = Book.objects.all()

    # An empty precondition_map presumably means no If-Match precondition is
    # required for this delete; confirm against api_etag.
    @api_etag(etag_func=default_api_object_etag_func, precondition_map={})
    def delete(self, request, *args, **kwargs):
        # get_object() looks the book up through ``queryset`` using the ``pk``
        # URL kwarg and raises Http404 for an unknown pk. The previous
        # Book.objects.get() raised Book.DoesNotExist (a server error) instead.
        self.get_object().delete()
        return Response(status=status.HTTP_204_NO_CONTENT)
60 |
61 |
class BookUnconditionalUpdateView(generics.UpdateAPIView):
    """Test the decorator with DRF generic API views."""

    # include the queryset here to enable the object lookup in `@api_etag`
    queryset = Book.objects.all()
    serializer_class = BookSerializer

    # An empty precondition_map presumably means no If-Match precondition is
    # required for updates; confirm against
    # rest_framework_extensions.etag.decorators.api_etag.
    @api_etag(etag_func=default_api_object_etag_func, precondition_map={})
    def update(self, request, *args, **kwargs):
        # Delegates unchanged to UpdateAPIView.update; the override only exists
        # to carry the decorator.
        return super().update(request, *args, **kwargs)
--------------------------------------------------------------------------------
/tests_app/tests/functional/_examples/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/_examples/etags/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/_examples/etags/remove_etag_gzip_postfix/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/_examples/etags/remove_etag_gzip_postfix/middleware.py:
--------------------------------------------------------------------------------
try:
    from django.utils.deprecation import MiddlewareMixin
except ImportError:
    # Django versions without MiddlewareMixin: use a plain base class.
    MiddlewareMixin = object


class RemoveEtagGzipPostfix(MiddlewareMixin):
    # Rewrites an ETag of the form '"value;gzip"' to '"value"'; every other
    # response passes through unchanged.
    def process_response(self, request, response):
        gzip_suffix = ';gzip"'
        if response.has_header('ETag'):
            etag = response['ETag']
            if etag.endswith(gzip_suffix):
                response['ETag'] = etag[:-len(gzip_suffix)] + '"'
        return response
12 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/_examples/etags/remove_etag_gzip_postfix/tests.py:
--------------------------------------------------------------------------------
1 | from django.test import TestCase, override_settings
2 | from django.conf import settings
3 |
@override_settings(ROOT_URLCONF='tests_app.tests.functional._examples.etags.remove_etag_gzip_postfix.urls')
class RemoveEtagGzipPostfixTest(TestCase):
    # Requests /remove-etag-gzip-postfix/ with Accept-Encoding: gzip and checks
    # the resulting ETag, without and with the RemoveEtagGzipPostfix middleware.
    #
    # NOTE(review): both tests override MIDDLEWARE_CLASSES, but the test
    # project configures middleware through MIDDLEWARE (tests_app/settings.py)
    # and Django 2.0+ no longer reads MIDDLEWARE_CLASSES. These overrides are
    # therefore likely no-ops, so neither GZipMiddleware nor
    # RemoveEtagGzipPostfix may actually run. Confirm before relying on them.

    @override_settings(MIDDLEWARE_CLASSES=(
        'django.middleware.common.CommonMiddleware',
        'django.middleware.gzip.GZipMiddleware',
    ))
    def test_without_middleware(self):
        response = self.client.get('/remove-etag-gzip-postfix/', **{
            'HTTP_ACCEPT_ENCODING': 'gzip'
        })

        self.assertEqual(response.status_code, 200)
        # Previously the ETag was '"etag_value;gzip"' rather than '"etag_value"';
        # GZipMiddleware does not append a ';gzip' suffix after encoding.
        self.assertEqual(response['ETag'], '"etag_value"')

    @override_settings(MIDDLEWARE_CLASSES=(
        'django.middleware.common.CommonMiddleware',
        'django.middleware.gzip.GZipMiddleware',
        'tests_app.tests.functional._examples.etags.remove_etag_gzip_postfix.middleware.RemoveEtagGzipPostfix',
    ))
    def test_with_middleware(self):
        response = self.client.get('/remove-etag-gzip-postfix/', **{
            'HTTP_ACCEPT_ENCODING': 'gzip'
        })
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response['ETag'], '"etag_value"')
31 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/_examples/etags/remove_etag_gzip_postfix/urls.py:
--------------------------------------------------------------------------------
1 | from django.urls import re_path
2 |
3 | from .views import MyView
4 |
5 |
6 | urlpatterns = [
7 | re_path(r'^remove-etag-gzip-postfix/$', MyView.as_view()),
8 | ]
9 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/_examples/etags/remove_etag_gzip_postfix/views.py:
--------------------------------------------------------------------------------
1 | from django.views import View
2 | from django.http import HttpResponse
3 |
4 |
class MyView(View):
    def get(self, request):
        """
        Return a 300-character body with the fixed ETag '"etag_value"'.

        The body is deliberately long, because GZipMiddleware leaves a response
        uncompressed when:
        * the body is shorter than 200 bytes,
        * the response already has a Content-Encoding header, or
        * the request's Accept-Encoding header does not include gzip.
        """
        body = 'r' * 300
        response = HttpResponse(body)
        response['ETag'] = '"etag_value"'
        return response
16 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/cache/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/cache/decorators/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/cache/decorators/tests.py:
--------------------------------------------------------------------------------
1 | from django.test import TestCase, override_settings
2 | from django.utils.encoding import force_str
3 |
4 |
@override_settings(ROOT_URLCONF='tests_app.tests.functional.cache.decorators.urls')
class TestCacheResponseFunctionally(TestCase):
    # End-to-end checks of @cache_response on HelloView, served at /hello/.

    def test_should_return_response(self):
        body = force_str(self.client.get('/hello/').content)
        self.assertEqual(body, '"Hello world"')

    def test_should_return_same_response_if_cached(self):
        # Two consecutive requests must produce byte-identical bodies.
        first_body, second_body = [
            self.client.get('/hello/').content for _ in range(2)
        ]
        self.assertEqual(first_body, second_body)
16 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/cache/decorators/urls.py:
--------------------------------------------------------------------------------
1 | from django.urls import re_path
2 |
3 | from .views import HelloView
4 |
5 |
# Single route for the cache_response functional tests: /hello/ -> HelloView.
urlpatterns = [
    re_path(r'^hello/$', HelloView.as_view(), name='hello'),
]
9 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/cache/decorators/views.py:
--------------------------------------------------------------------------------
1 | from rest_framework import views
2 | from rest_framework.response import Response
3 |
4 | from rest_framework_extensions.cache.decorators import cache_response
5 |
6 |
class HelloView(views.APIView):
    # GET returns a constant payload; @cache_response() caches the response.
    @cache_response()
    def get(self, request, *args, **kwargs):
        greeting = 'Hello world'
        return Response(data=greeting)
11 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/key_constructor/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/key_constructor/bits/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/key_constructor/bits/models.py:
--------------------------------------------------------------------------------
1 | from django.db import models
2 |
3 |
# Target of KeyConstructorUserModel.property; the tests filter users by its pk.
class KeyConstructorUserProperty(models.Model):
    name = models.CharField(max_length=100)
6 |
7 |
class KeyConstructorUserModel(models.Model):
    # Deleting the referenced property also deletes this row (CASCADE).
    # NOTE: the field name shadows the ``property`` builtin inside this class
    # body. It is kept because the viewset filters on ``property`` and
    # renaming it would need a migration.
    property = models.ForeignKey(
        KeyConstructorUserProperty,
        on_delete=models.CASCADE
    )
13 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/key_constructor/bits/serializers.py:
--------------------------------------------------------------------------------
1 | from rest_framework import serializers
2 |
3 | from .models import KeyConstructorUserModel
4 |
5 |
class UserModelSerializer(serializers.ModelSerializer):
    # Exposes every field of KeyConstructorUserModel ('__all__').
    class Meta:
        model = KeyConstructorUserModel
        fields = '__all__'
10 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/key_constructor/bits/tests.py:
--------------------------------------------------------------------------------
1 | from django.test import override_settings
2 |
3 | from rest_framework.test import APITestCase
4 |
5 | from .models import KeyConstructorUserProperty
6 |
7 |
@override_settings(ROOT_URLCONF='tests_app.tests.functional.key_constructor.bits.urls')
class ListSqlQueryKeyBitTestBehaviour(APITestCase):
    """Regression tests for https://github.com/chibisov/drf-extensions/issues/28#issuecomment-51711927

    ``UserModelViewSet`` filters on ``property`` through
    ``django_filters.rest_framework.DjangoFilterBackend`` and its default
    ``FilterSet``. When the requested ``property`` pk does not exist in the
    database, the filter form is invalid with errors like:
    {'property': ['Select a valid choice. That choice is not one of the available choices.']}
    The tests expect such requests to answer 400 (presumably django-filter
    raises a validation error for the invalid form; confirm).
    """

    def test_with_fk_in_db(self):
        KeyConstructorUserProperty.objects.create(name='some property')

        # list: the filter value is valid, so the (empty) list is returned
        response = self.client.get('/users/?property=1')
        self.assertEqual(response.status_code, 200)

        # retrieve: no user with pk 1 was created, so 404
        response = self.client.get('/users/1/?property=1')
        self.assertEqual(response.status_code, 404)

    def test_without_fk_in_db(self):
        # list: property 1 does not exist, so the filter is invalid
        response = self.client.get('/users/?property=1')
        self.assertEqual(response.status_code, 400)

        # retrieve: the invalid filter is reported before the object lookup
        response = self.client.get('/users/1/?property=1')
        self.assertEqual(response.status_code, 400)
37 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/key_constructor/bits/urls.py:
--------------------------------------------------------------------------------
1 | from rest_framework import routers
2 |
3 | from .views import UserModelViewSet
4 |
5 |
# Route UserModelViewSet under /users/ (list) and /users/<pk>/ (detail).
viewset_router = routers.DefaultRouter()
viewset_router.register(prefix='users', viewset=UserModelViewSet)
urlpatterns = viewset_router.urls
9 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/key_constructor/bits/views.py:
--------------------------------------------------------------------------------
1 | import django_filters
2 | from rest_framework import viewsets
3 |
4 | from .models import KeyConstructorUserModel as UserModel
5 | from .serializers import UserModelSerializer
6 |
7 |
class UserModelViewSet(viewsets.ModelViewSet):
    """CRUD endpoint for users, filterable with ``?property=<pk>`` through django-filter."""

    filter_backends = (django_filters.rest_framework.DjangoFilterBackend,)
    filterset_fields = ('property',)
    queryset = UserModel.objects.all()
    serializer_class = UserModelSerializer
13 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/migrations/0001_initial.py:
--------------------------------------------------------------------------------
1 | # Generated by Django 2.2 on 2019-04-16 11:51
2 |
3 | from django.db import migrations, models
4 | import django.db.models.deletion
5 |
6 |
class Migration(migrations.Migration):
    # Initial schema for the ``functional`` test app: creates the tables for the
    # test models that the ``to='functional.…'`` references below point at.

    initial = True

    # Required because NestedRouterMixinCommentModel has a ForeignKey to
    # contenttypes.ContentType (see its CreateModel below).
    dependencies = [
        ('contenttypes', '0002_remove_content_type_name'),
    ]

    # Referenced models are created before the fields that point at them; e.g.
    # the ``permissions`` M2M fields are added via AddField once both the group
    # and permission models exist.
    operations = [
        migrations.CreateModel(
            name='Comment',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('email', models.EmailField(max_length=254)),
                ('content', models.CharField(max_length=200)),
                ('created', models.DateTimeField(auto_now_add=True)),
            ],
        ),
        migrations.CreateModel(
            name='CommentForListDestroyModelMixin',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('email', models.EmailField(max_length=254)),
            ],
        ),
        migrations.CreateModel(
            name='CommentForListUpdateModelMixin',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('email', models.EmailField(max_length=254)),
            ],
        ),
        migrations.CreateModel(
            name='CommentForPaginateByMaxMixin',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('email', models.EmailField(max_length=254)),
                ('content', models.CharField(max_length=200)),
                ('created', models.DateTimeField(auto_now_add=True)),
            ],
            options={
                'ordering': ['id'],
            },
        ),
        migrations.CreateModel(
            name='DefaultRouterGroupModel',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('name', models.CharField(max_length=10)),
            ],
        ),
        migrations.CreateModel(
            name='DefaultRouterPermissionModel',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('name', models.CharField(max_length=10)),
            ],
        ),
        migrations.CreateModel(
            name='KeyConstructorUserProperty',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('name', models.CharField(max_length=100)),
            ],
        ),
        migrations.CreateModel(
            name='NestedRouterMixinBookModel',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('title', models.CharField(max_length=30)),
            ],
        ),
        migrations.CreateModel(
            name='NestedRouterMixinGroupModel',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('name', models.CharField(max_length=10)),
            ],
        ),
        migrations.CreateModel(
            name='NestedRouterMixinPermissionModel',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('name', models.CharField(max_length=10)),
            ],
        ),
        migrations.CreateModel(
            name='NestedRouterMixinTaskModel',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('title', models.CharField(max_length=30)),
            ],
        ),
        migrations.CreateModel(
            name='PermissionsComment',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('text', models.CharField(max_length=100)),
            ],
        ),
        migrations.CreateModel(
            name='RouterTestModel',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('uuid', models.CharField(max_length=20)),
                ('text', models.CharField(max_length=200)),
            ],
        ),
        migrations.CreateModel(
            name='UserForListUpdateModelMixin',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('email', models.EmailField(max_length=254)),
                ('name', models.CharField(max_length=10)),
                ('age', models.IntegerField()),
                ('last_name', models.CharField(max_length=10)),
                ('password', models.CharField(max_length=100)),
            ],
        ),
        # Created after NestedRouterMixinGroupModel, which its ``groups`` M2M targets.
        migrations.CreateModel(
            name='NestedRouterMixinUserModel',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('email', models.EmailField(blank=True, max_length=254, null=True)),
                ('name', models.CharField(max_length=10)),
                ('groups', models.ManyToManyField(related_name='user_groups', to='functional.NestedRouterMixinGroupModel')),
            ],
        ),
        migrations.AddField(
            model_name='nestedroutermixingroupmodel',
            name='permissions',
            field=models.ManyToManyField(to='functional.NestedRouterMixinPermissionModel'),
        ),
        # Generic-relation style comment: optional content_type + object_id pair.
        migrations.CreateModel(
            name='NestedRouterMixinCommentModel',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('object_id', models.PositiveIntegerField(blank=True, null=True)),
                ('text', models.CharField(max_length=30)),
                ('content_type', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to='contenttypes.ContentType')),
            ],
        ),
        migrations.CreateModel(
            name='KeyConstructorUserModel',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('property', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='functional.KeyConstructorUserProperty')),
            ],
        ),
        migrations.CreateModel(
            name='DefaultRouterUserModel',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('name', models.CharField(max_length=10)),
                ('groups', models.ManyToManyField(related_name='user_groups', to='functional.DefaultRouterGroupModel')),
            ],
        ),
        migrations.AddField(
            model_name='defaultroutergroupmodel',
            name='permissions',
            field=models.ManyToManyField(to='functional.DefaultRouterPermissionModel'),
        ),
        migrations.CreateModel(
            name='Book',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('name', models.CharField(max_length=100, default=None, blank=True, null=True)),
                ('author', models.CharField(max_length=100, default=None, blank=True, null=True)),
                ('issn', models.CharField(max_length=100, default=None, blank=True, null=True)),
            ],
        ),
    ]
179 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/migrations/0002_nestedroutermixinusermodel_code.py:
--------------------------------------------------------------------------------
1 | # Generated by Django 3.2.18 on 2023-03-26 11:14
2 |
3 | from django.db import migrations, models
4 | import uuid
5 |
6 |
class Migration(migrations.Migration):
    # Adds the ``code`` UUID column to NestedRouterMixinUserModel.

    dependencies = [
        ('functional', '0001_initial'),
    ]

    operations = [
        migrations.AddField(
            model_name='nestedroutermixinusermodel',
            name='code',
            # ``uuid.uuid4`` is passed uncalled, so new rows get a fresh UUID.
            # NOTE(review): per Django's migration docs, a callable default is
            # evaluated once when backfilling existing rows during AddField, so
            # pre-existing rows share one value — fine while ``code`` is not unique.
            field=models.UUIDField(default=uuid.uuid4),
        ),
    ]
20 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/migrations/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/tests/functional/migrations/__init__.py
--------------------------------------------------------------------------------
/tests_app/tests/functional/mixins/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/tests/functional/mixins/__init__.py
--------------------------------------------------------------------------------
/tests_app/tests/functional/mixins/detail_serializer_mixin/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/mixins/detail_serializer_mixin/models.py:
--------------------------------------------------------------------------------
1 | from django.db import models
2 |
3 |
class Comment(models.Model):
    """Comment fixture for the DetailSerializerMixin tests."""

    email = models.EmailField()
    content = models.CharField(max_length=200)
    # Set automatically when the row is first saved.
    created = models.DateTimeField(auto_now_add=True)
8 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/mixins/detail_serializer_mixin/serializers.py:
--------------------------------------------------------------------------------
1 | from rest_framework import serializers
2 |
3 | from .models import Comment
4 |
5 |
class CommentSerializer(serializers.ModelSerializer):
    """Compact comment representation (``id`` and ``email`` only), used for list responses."""

    class Meta:
        model = Comment
        fields = ('id', 'email')
13 |
14 |
class CommentDetailSerializer(serializers.ModelSerializer):
    """Full comment representation (adds ``content``), used for detail responses."""

    class Meta:
        model = Comment
        fields = ('id', 'email', 'content')
--------------------------------------------------------------------------------
/tests_app/tests/functional/mixins/detail_serializer_mixin/tests.py:
--------------------------------------------------------------------------------
1 | import datetime
2 |
3 | from django.test import TestCase, override_settings
4 |
5 | # todo: use from rest_framework when released
6 | from rest_framework.test import APIRequestFactory
7 | from .models import Comment
8 |
9 |
# Shared request factory for this module; the test cases below issue their
# requests through ``self.client`` rather than this factory.
factory = APIRequestFactory()
11 |
12 |
@override_settings(ROOT_URLCONF='tests_app.tests.functional.mixins.detail_serializer_mixin.urls')
class DetailSerializerMixinTest_serializer_detail_class(TestCase):
    """List routes use ``serializer_class``; detail routes use ``serializer_detail_class``."""

    def setUp(self):
        # ``created`` is an auto_now_add field, so Django sets it on insert and
        # ignores any explicit value; it is therefore not passed here.
        self.comment = Comment.objects.create(
            id=1,
            email='example@ya.ru',
            content='Hello world',
        )

    def test_serializer_class_response(self):
        resp = self.client.get('/comments/')
        expected = [{
            'id': 1,
            'email': 'example@ya.ru'
        }]
        self.assertEqual(resp.data, expected)

    def test_serializer_detail_class_response(self):
        resp = self.client.get('/comments/1/')
        expected = {
            'id': 1,
            'email': 'example@ya.ru',
            'content': 'Hello world',
        }
        self.assertEqual(resp.data, expected, 'should use detail serializer for detail endpoint')

    def test_view_with_mixin_and_without__serializer_detail_class__should_raise_exception(self):
        msg = "'CommentWithoutDetailSerializerClassViewSet' should include a 'serializer_detail_class' attribute"
        self.assertRaisesMessage(AssertionError, msg, self.client.get, '/comments-2/')
44 |
45 |
@override_settings(ROOT_URLCONF='tests_app.tests.functional.mixins.detail_serializer_mixin.urls')
class DetailSerializerMixin_queryset_detail(TestCase):
    """Detail routes use ``queryset_detail`` when it is set, otherwise ``queryset``."""

    def setUp(self):
        # ``created`` is auto_now_add and filled in by Django on insert, so no
        # explicit timestamp is passed.
        self.comments = [
            Comment.objects.create(
                id=1,
                email='example@ya.ru',
                content='Hello world',
            ),
            Comment.objects.create(
                id=2,
                email='example2@ya.ru',
                content='Hello world 2',
            ),
        ]

    def test_list_should_use_default_queryset_method(self):
        resp = self.client.get('/comments-3/')
        expected = [{
            'id': 2,
            'email': 'example2@ya.ru'
        }]
        self.assertEqual(resp.data, expected)

    def test_detail_view_should_use_default_queryset_if_queryset_detail_not_specified(self):
        resp = self.client.get('/comments-3/1/')
        self.assertEqual(resp.status_code, 404)

        resp = self.client.get('/comments-3/2/')
        expected = {
            'id': 2,
            'email': 'example2@ya.ru',
            'content': 'Hello world 2',
        }
        self.assertEqual(resp.data, expected)

    def test_list_should_use_default_queryset_method_if_queryset_detail_specified(self):
        resp = self.client.get('/comments-4/')
        expected = [{
            'id': 2,
            'email': 'example2@ya.ru'
        }]
        self.assertEqual(resp.data, expected)

    def test_detail_view_should_use_custom_queryset_if_queryset_detail_specified(self):
        resp = self.client.get('/comments-4/2/')
        self.assertEqual(resp.status_code, 404)

        resp = self.client.get('/comments-4/1/')
        expected = {
            'id': 1,
            'email': 'example@ya.ru',
            'content': 'Hello world',
        }
        self.assertEqual(resp.data, expected)

    def test_nested_model_view_with_mixin_should_use_get_detail_queryset(self):
        """
        Regression tests for https://github.com/chibisov/drf-extensions/pull/24
        """
        resp = self.client.get('/comments-5/1/')
        expected = {
            'id': 1,
            'email': 'example@ya.ru',
            'content': 'Hello world',
        }
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(resp.data, expected)
117 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/mixins/detail_serializer_mixin/urls.py:
--------------------------------------------------------------------------------
1 | from rest_framework import routers
2 |
3 | from .views import (
4 | CommentViewSet,
5 | CommentWithoutDetailSerializerClassViewSet,
6 | CommentWithIdTwoViewSet,
7 | CommentWithIdTwoAndIdOneForDetailViewSet,
8 | CommentWithDetailSerializerAndNoArgsForGetQuerySetViewSet
9 | )
10 |
11 |
# Every viewset below uses the Comment model, so each needs an explicit,
# distinct basename; otherwise they would all derive the same default one.
viewset_router = routers.DefaultRouter()
viewset_router.register(prefix='comments', viewset=CommentViewSet, basename='alt1')
viewset_router.register(prefix='comments-2', viewset=CommentWithoutDetailSerializerClassViewSet, basename='alt2')
viewset_router.register(prefix='comments-3', viewset=CommentWithIdTwoViewSet, basename='alt3')
viewset_router.register(prefix='comments-4', viewset=CommentWithIdTwoAndIdOneForDetailViewSet, basename='alt4')
viewset_router.register(prefix='comments-5', viewset=CommentWithDetailSerializerAndNoArgsForGetQuerySetViewSet, basename='alt5')

urlpatterns = viewset_router.urls
20 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/mixins/detail_serializer_mixin/views.py:
--------------------------------------------------------------------------------
1 | from rest_framework import viewsets
2 | from rest_framework_extensions.mixins import DetailSerializerMixin
3 |
4 | from .models import Comment
5 | from .serializers import CommentSerializer, CommentDetailSerializer
6 |
7 |
class CommentViewSet(DetailSerializerMixin, viewsets.ReadOnlyModelViewSet):
    """Read-only comments: compact list serializer, full detail serializer."""

    queryset = Comment.objects.all()
    serializer_class = CommentSerializer
    serializer_detail_class = CommentDetailSerializer
12 |
13 |
class CommentWithoutDetailSerializerClassViewSet(DetailSerializerMixin, viewsets.ReadOnlyModelViewSet):
    """Deliberately omits ``serializer_detail_class``; the tests expect the mixin to assert."""

    queryset = Comment.objects.all()
    serializer_class = CommentSerializer
17 |
18 |
class CommentWithIdTwoViewSet(DetailSerializerMixin, viewsets.ReadOnlyModelViewSet):
    """No ``queryset_detail``: list and detail routes both see only the id=2 comment."""

    queryset = Comment.objects.filter(id=2)
    serializer_class = CommentSerializer
    serializer_detail_class = CommentDetailSerializer
23 |
24 |
class CommentWithIdTwoAndIdOneForDetailViewSet(DetailSerializerMixin, viewsets.ReadOnlyModelViewSet):
    """List route sees only id=2; detail route uses ``queryset_detail`` (only id=1)."""

    queryset = Comment.objects.filter(id=2)
    queryset_detail = Comment.objects.filter(id=1)
    serializer_class = CommentSerializer
    serializer_detail_class = CommentDetailSerializer
30 |
31 |
class CommentWithDetailSerializerAndNoArgsForGetQuerySetViewSet(DetailSerializerMixin, viewsets.ModelViewSet):
    """
    For regression tests https://github.com/chibisov/drf-extensions/pull/24
    """
    queryset = Comment.objects.all()
    queryset_detail = Comment.objects.filter(id=1)
    serializer_class = CommentSerializer
    serializer_detail_class = CommentDetailSerializer

    def get_queryset(self):
        # Intentionally a pass-through override with no extra arguments: the
        # regression covers views that override get_queryset() this way.
        queryset = super().get_queryset()
        return queryset
43 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/mixins/list_destroy_model_mixin/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/mixins/list_destroy_model_mixin/models.py:
--------------------------------------------------------------------------------
1 | from django.db import models
2 |
3 |
class CommentForListDestroyModelMixin(models.Model):
    """Minimal comment fixture for the ListDestroyModelMixin (bulk delete) tests."""

    email = models.EmailField()
6 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/mixins/list_destroy_model_mixin/tests.py:
--------------------------------------------------------------------------------
1 | from django.test import override_settings
2 |
3 | from rest_framework.test import APITestCase
4 | from rest_framework_extensions.settings import extensions_api_settings
5 | from rest_framework_extensions import utils
6 |
7 | from .models import CommentForListDestroyModelMixin as Comment
8 | from tests_app.testutils import override_extensions_api_settings
9 |
10 |
@override_settings(ROOT_URLCONF='tests_app.tests.functional.mixins.list_destroy_model_mixin.urls')
class ListDestroyModelMixinTest(APITestCase):
    """Single-object and bulk (list-route DELETE) behaviour of ListDestroyModelMixin."""

    def setUp(self):
        self.comments = [
            Comment.objects.create(
                id=1,
                email='example@ya.ru'
            ),
            Comment.objects.create(
                id=2,
                email='example@gmail.com'
            )
        ]
        self.protection_headers = {
            utils.prepare_header_name(extensions_api_settings.DEFAULT_BULK_OPERATION_HEADER_NAME): 'true'
        }

    def test_simple_response(self):
        resp = self.client.get('/comments/')
        expected = [
            {
                'id': 1,
                'email': 'example@ya.ru'
            },
            {
                'id': 2,
                'email': 'example@gmail.com'
            }
        ]
        self.assertEqual(resp.data, expected)

    def test_filter_works(self):
        resp = self.client.get('/comments/?id=1')
        expected = [
            {
                'id': 1,
                'email': 'example@ya.ru'
            }
        ]
        self.assertEqual(resp.data, expected)

    def test_destroy_instance(self):
        resp = self.client.delete('/comments/1/')
        self.assertEqual(resp.status_code, 204)
        self.assertFalse(Comment.objects.filter(pk=1).exists())

    def test_bulk_destroy__without_protection_header(self):
        resp = self.client.delete('/comments/')
        self.assertEqual(resp.status_code, 400)
        expected_message = {
            'detail': 'Header \'{0}\' should be provided for bulk operation.'.format(
                extensions_api_settings.DEFAULT_BULK_OPERATION_HEADER_NAME
            )
        }
        self.assertEqual(resp.data, expected_message)

    def test_bulk_destroy__with_protection_header(self):
        resp = self.client.delete('/comments/', **self.protection_headers)
        self.assertEqual(resp.status_code, 204)
        self.assertEqual(Comment.objects.count(), 0)

    @override_extensions_api_settings(DEFAULT_BULK_OPERATION_HEADER_NAME=None)
    def test_bulk_destroy__without_protection_header__and_with_turned_off_protection_header(self):
        resp = self.client.delete('/comments/')
        self.assertEqual(resp.status_code, 204)
        self.assertEqual(Comment.objects.count(), 0)

    def test_bulk_destroy__should_destroy_filtered_queryset(self):
        resp = self.client.delete('/comments/?id=1', **self.protection_headers)
        self.assertEqual(resp.status_code, 204)
        self.assertEqual(Comment.objects.count(), 1)
        self.assertEqual(Comment.objects.get(), self.comments[1])

    def test_bulk_destroy__should_not_destroy_if_client_has_no_permissions(self):
        # The route is registered as 'comments-with-permissions'; the former
        # singular URL matched no route and only "passed" through a 404.
        resp = self.client.delete('/comments-with-permissions/', **self.protection_headers)
        # Anonymous requests fail DjangoModelPermissions; DRF answers 401 or 403
        # depending on the configured authentication classes.
        self.assertIn(resp.status_code, (401, 403))
        self.assertEqual(Comment.objects.count(), 2)
89 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/mixins/list_destroy_model_mixin/urls.py:
--------------------------------------------------------------------------------
1 | from rest_framework import routers
2 |
3 | from .views import CommentViewSet, CommentViewSetWithPermissions
4 |
5 |
# Both viewsets share the Comment model, hence the explicit basenames.
viewset_router = routers.DefaultRouter()
viewset_router.register(prefix='comments', viewset=CommentViewSet, basename='alt1')
viewset_router.register(prefix='comments-with-permissions', viewset=CommentViewSetWithPermissions, basename='alt2')
urlpatterns = viewset_router.urls
10 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/mixins/list_destroy_model_mixin/views.py:
--------------------------------------------------------------------------------
1 | import django_filters
2 | from rest_framework import viewsets, serializers
3 | from rest_framework import filters
4 | from rest_framework.permissions import DjangoModelPermissions
5 | from rest_framework_extensions.bulk_operations.mixins import ListDestroyModelMixin
6 |
7 | from .models import CommentForListDestroyModelMixin as Comment
8 |
9 |
class CommentFilter(django_filters.FilterSet):
    """Allow narrowing comments with ``?id=<pk>``."""

    class Meta:
        fields = ['id']
        model = Comment
15 | ]
16 |
17 |
class CommentSerializer(serializers.ModelSerializer):
    """Serialize every column of the bulk-delete comment model."""

    class Meta:
        fields = '__all__'
        model = Comment
22 |
23 |
class CommentViewSet(ListDestroyModelMixin, viewsets.ModelViewSet):
    """Comment CRUD plus bulk DELETE on the list route, filterable via CommentFilter."""

    filter_backends = (django_filters.rest_framework.DjangoFilterBackend,)
    filterset_class = CommentFilter
    queryset = Comment.objects.all()
    serializer_class = CommentSerializer
29 |
30 |
class CommentViewSetWithPermissions(CommentViewSet):
    """Same endpoints as CommentViewSet, guarded by DjangoModelPermissions."""

    permission_classes = (DjangoModelPermissions,)
33 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/mixins/list_update_model_mixin/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/mixins/list_update_model_mixin/models.py:
--------------------------------------------------------------------------------
1 | from django.db import models
2 |
3 |
class CommentForListUpdateModelMixin(models.Model):
    """Minimal comment fixture for the ListUpdateModelMixin (bulk PATCH) tests."""

    email = models.EmailField()
6 |
7 |
class UserForListUpdateModelMixin(models.Model):
    """User fixture whose serializer mixes renamed, read-only, write-only and hidden fields."""

    email = models.EmailField()
    name = models.CharField(max_length=10)
    age = models.IntegerField()
    last_name = models.CharField(max_length=10)
    password = models.CharField(max_length=100)
14 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/mixins/list_update_model_mixin/serializers.py:
--------------------------------------------------------------------------------
1 | from rest_framework import serializers
2 | from .models import (
3 | UserForListUpdateModelMixin as User,
4 | CommentForListUpdateModelMixin as Comment,
5 | )
6 |
7 |
class UserSerializer(serializers.ModelSerializer):
    """User representation exercising several field options at once.

    ``surname`` is backed by ``last_name``, ``password`` is write-only, ``name``
    is read-only, and ``email`` is not listed in ``fields`` at all.
    """

    surname = serializers.CharField(source='last_name')

    class Meta:
        model = User
        fields = ['id', 'age', 'name', 'surname', 'password']
        read_only_fields = ('name',)
        extra_kwargs = {'password': {'write_only': True}}
22 |
23 |
class CommentSerializer(serializers.ModelSerializer):
    """Serialize every column of the bulk-update comment model."""

    class Meta:
        fields = '__all__'
        model = Comment
28 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/mixins/list_update_model_mixin/tests.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import unittest
4 |
5 | import django
6 | from django.test import override_settings
7 |
8 | from rest_framework.test import APITestCase
9 | from rest_framework_extensions.settings import extensions_api_settings
10 | from rest_framework_extensions import utils
11 |
12 | from .models import (
13 | CommentForListUpdateModelMixin as Comment,
14 | UserForListUpdateModelMixin as User
15 | )
16 | from tests_app.testutils import override_extensions_api_settings
17 |
18 |
@override_settings(ROOT_URLCONF='tests_app.tests.functional.mixins.list_update_model_mixin.urls')
class ListUpdateModelMixinTest(APITestCase):
    """Single-object and bulk (list-route PATCH) behaviour of ListUpdateModelMixin."""

    def setUp(self):
        self.comments = [
            Comment.objects.create(
                id=1,
                email='example@ya.ru'
            ),
            Comment.objects.create(
                id=2,
                email='example@gmail.com'
            )
        ]
        self.protection_headers = {
            utils.prepare_header_name(extensions_api_settings.DEFAULT_BULK_OPERATION_HEADER_NAME): 'true'
        }
        self.patch_data = {
            'email': 'example@yandex.ru'
        }

    def _send_json(self, method, url, data, **headers):
        """Send ``data`` as a JSON body with the given client ``method`` and extra headers."""
        return getattr(self.client, method)(
            url, data=json.dumps(data), content_type='application/json', **headers
        )

    def test_simple_response(self):
        resp = self.client.get('/comments/')
        expected = [
            {
                'id': 1,
                'email': 'example@ya.ru'
            },
            {
                'id': 2,
                'email': 'example@gmail.com'
            }
        ]
        self.assertEqual(resp.data, expected)

    def test_filter_works(self):
        resp = self.client.get('/comments/?id=1')
        expected = [
            {
                'id': 1,
                'email': 'example@ya.ru'
            }
        ]
        self.assertEqual(resp.data, expected)

    def test_update_instance(self):
        data = {
            'id': 1,
            'email': 'example@yandex.ru'
        }
        resp = self._send_json('put', '/comments/1/', data)
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(Comment.objects.get(pk=1).email, 'example@yandex.ru')

    def test_partial_update_instance(self):
        data = {
            'id': 1,
            'email': 'example@yandex.ru'
        }
        resp = self._send_json('patch', '/comments/1/', data)
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(Comment.objects.get(pk=1).email, 'example@yandex.ru')

    def test_bulk_partial_update__without_protection_header(self):
        resp = self._send_json('patch', '/comments/', self.patch_data)
        self.assertEqual(resp.status_code, 400)
        expected_message = {
            'detail': 'Header \'{0}\' should be provided for bulk operation.'.format(
                extensions_api_settings.DEFAULT_BULK_OPERATION_HEADER_NAME
            )
        }
        self.assertEqual(resp.data, expected_message)

    def test_bulk_partial_update__with_protection_header(self):
        resp = self._send_json('patch', '/comments/', self.patch_data, **self.protection_headers)
        self.assertEqual(resp.status_code, 204)
        for comment in Comment.objects.all():
            self.assertEqual(comment.email, self.patch_data['email'])

    @override_extensions_api_settings(DEFAULT_BULK_OPERATION_HEADER_NAME=None)
    def test_bulk_partial_update__without_protection_header__and_with_turned_off_protection_header(self):
        # No protection header is sent: with the check turned off, the bulk
        # update must succeed without it (previously the header was sent anyway).
        resp = self._send_json('patch', '/comments/', self.patch_data)
        self.assertEqual(resp.status_code, 204)
        for comment in Comment.objects.all():
            self.assertEqual(comment.email, self.patch_data['email'])

    def test_bulk_partial_update__should_update_filtered_queryset(self):
        resp = self._send_json('patch', '/comments/?id=1', self.patch_data, **self.protection_headers)
        self.assertEqual(resp.status_code, 204)
        self.assertEqual(Comment.objects.get(pk=1).email, self.patch_data['email'])
        self.assertEqual(Comment.objects.get(pk=2).email, self.comments[1].email)

    def test_bulk_partial_update__should_not_update_if_client_has_no_permissions(self):
        # The route is registered as 'comments-with-permissions'; the former
        # singular URL matched no route and only "passed" through a 404.
        resp = self._send_json(
            'patch', '/comments-with-permissions/', self.patch_data, **self.protection_headers
        )
        # Anonymous requests fail DjangoModelPermissions; DRF answers 401 or 403
        # depending on the configured authentication classes.
        self.assertIn(resp.status_code, (401, 403))
        # Compare in pk order: an unordered queryset has no guaranteed order.
        self.assertEqual(
            list(Comment.objects.order_by('pk').values_list('email', flat=True)),
            [comment.email for comment in self.comments],
        )
116 |
117 |
@override_settings(ROOT_URLCONF='tests_app.tests.functional.mixins.list_update_model_mixin.urls')
class ListUpdateModelMixinTestBehaviour__serializer_fields(APITestCase):
    """Bulk PATCH on /users/ must honour the serializer's field configuration."""

    def setUp(self):
        self.user = User.objects.create(
            id=1,
            name='Gennady',
            age=24,
            last_name='Chibisov',
            email='example@ya.ru',
            password='somepassword'
        )
        header_name = utils.prepare_header_name(
            extensions_api_settings.DEFAULT_BULK_OPERATION_HEADER_NAME
        )
        self.headers = {header_name: 'true'}

    def _bulk_patch(self, payload):
        """Send ``payload`` as a JSON bulk PATCH to /users/ with the protection header."""
        return self.client.patch(
            '/users/',
            data=json.dumps(payload),
            content_type='application/json',
            **self.headers
        )

    def get_fresh_user(self):
        return User.objects.get(pk=self.user.pk)

    def test_simple_response(self):
        expected = [{'id': 1, 'age': 24, 'name': 'Gennady', 'surname': 'Chibisov'}]
        self.assertEqual(self.client.get('/users/').data, expected)

    def test_invalid_for_db_data(self):
        try:
            resp = self._bulk_patch({'age': 'Not integer value'})
        except ValueError:
            self.fail('Errors with invalid for DB data should be caught')

        self.assertEqual(resp.status_code, 400)
        # The database-level error message wording changed in Django 3.0.
        if django.VERSION >= (3, 0, 0):
            detail = "Field 'age' expected a number but got 'Not integer value'."
        else:
            detail = "invalid literal for int() with base 10: 'Not integer value'"
        self.assertEqual(resp.data, {'detail': detail})

    def test_should_use_source_if_it_set_in_serializer(self):
        payload = {'surname': 'Ivanov'}
        resp = self._bulk_patch(payload)
        self.assertEqual(resp.status_code, 204)
        self.assertEqual(self.get_fresh_user().last_name, payload['surname'])

    def test_should_update_write_only_fields(self):
        payload = {'password': '123'}
        resp = self._bulk_patch(payload)
        self.assertEqual(resp.status_code, 204)
        self.assertEqual(self.get_fresh_user().password, payload['password'])

    def test_should_not_update_read_only_fields(self):
        resp = self._bulk_patch({'name': 'Ivan'})
        self.assertEqual(resp.status_code, 204)
        self.assertEqual(self.get_fresh_user().name, self.user.name)

    def test_should_not_update_hidden_fields(self):
        resp = self._bulk_patch({'email': 'example@gmail.com'})
        self.assertEqual(resp.status_code, 204)
        self.assertEqual(self.get_fresh_user().email, self.user.email)
200 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/mixins/list_update_model_mixin/urls.py:
--------------------------------------------------------------------------------
1 | from rest_framework import routers
2 |
3 | from .views import CommentViewSet, CommentViewSetWithPermissions, UserViewSet
4 |
5 |
viewset_router = routers.DefaultRouter()
# Both comment viewsets share the Comment queryset, so explicit, distinct
# basenames are needed to keep the generated route names apart.
for _prefix, _viewset, _basename in (
    ('comments', CommentViewSet, 'alt1'),
    ('comments-with-permissions', CommentViewSetWithPermissions, 'alt2'),
    ('users', UserViewSet, 'alt3'),
):
    viewset_router.register(_prefix, _viewset, basename=_basename)
del _prefix, _viewset, _basename
urlpatterns = viewset_router.urls
11 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/mixins/list_update_model_mixin/views.py:
--------------------------------------------------------------------------------
1 | import django_filters
2 | from rest_framework import viewsets
3 | from rest_framework import filters
4 | from rest_framework.permissions import DjangoModelPermissions
5 | from rest_framework_extensions.mixins import ListUpdateModelMixin
6 |
7 | from .models import (
8 | CommentForListUpdateModelMixin as Comment,
9 | UserForListUpdateModelMixin as User
10 | )
11 | from .serializers import UserSerializer, CommentSerializer
12 |
13 |
class CommentFilter(django_filters.FilterSet):
    """Let the comment list be narrowed down by ``id``."""

    class Meta:
        fields = ['id']
        model = Comment
20 |
21 |
class CommentViewSet(ListUpdateModelMixin, viewsets.ModelViewSet):
    # Comment CRUD plus list-level (bulk) update from ListUpdateModelMixin;
    # the list can be filtered by ``id`` through CommentFilter.
    filter_backends = [django_filters.rest_framework.DjangoFilterBackend]
    filterset_class = CommentFilter
    serializer_class = CommentSerializer
    queryset = Comment.objects.all()
27 |
28 |
class CommentViewSetWithPermissions(CommentViewSet):
    # Same endpoints as CommentViewSet, gated by Django model permissions.
    permission_classes = [DjangoModelPermissions]
31 |
32 |
class UserViewSet(ListUpdateModelMixin, viewsets.ModelViewSet):
    # User CRUD plus list-level (bulk) update from ListUpdateModelMixin.
    serializer_class = UserSerializer
    queryset = User.objects.all()
--------------------------------------------------------------------------------
/tests_app/tests/functional/mixins/paginate_by_max_mixin/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/mixins/paginate_by_max_mixin/models.py:
--------------------------------------------------------------------------------
1 | from django.db import models
2 |
3 |
class CommentForPaginateByMaxMixin(models.Model):
    """Comment fixture for the PaginateByMaxMixin pagination tests."""
    email = models.EmailField()
    content = models.CharField(max_length=200)
    # auto_now_add: Django stamps this itself when the row is first inserted.
    created = models.DateTimeField(auto_now_add=True)

    class Meta:
        # Deterministic ordering so paginated pages are stable.
        ordering = ['id']
11 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/mixins/paginate_by_max_mixin/pagination.py:
--------------------------------------------------------------------------------
1 | from rest_framework_extensions.mixins import PaginateByMaxMixin
2 | from rest_framework.pagination import PageNumberPagination
3 |
4 |
class WithMaxPagination(PaginateByMaxMixin, PageNumberPagination):
    """Page size chosen with ``?limit=`` and capped at ``max_page_size``;
    PaginateByMaxMixin additionally accepts ``?limit=max``."""

    max_page_size = 20
    page_size_query_param = 'limit'
    page_size = 10
9 |
10 |
class FlexiblePagination(PageNumberPagination):
    """Client-sizable via ``?page_size=``, without a ``max_page_size`` cap."""

    page_size_query_param = 'page_size'
    page_size = 10
14 |
15 |
class FixedPagination(PageNumberPagination):
    """Ten items per page; no query parameter can change the size."""

    page_size = 10
18 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/mixins/paginate_by_max_mixin/serializers.py:
--------------------------------------------------------------------------------
1 | from rest_framework import serializers
2 |
3 | from .models import CommentForPaginateByMaxMixin
4 |
5 |
class CommentSerializer(serializers.ModelSerializer):
    """Expose only the ``id`` and ``email`` of a paginated comment."""

    class Meta:
        fields = ('id', 'email')
        model = CommentForPaginateByMaxMixin
13 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/mixins/paginate_by_max_mixin/tests.py:
--------------------------------------------------------------------------------
1 | import datetime
2 |
3 | from django.test import TestCase, override_settings
4 |
5 | from .models import CommentForPaginateByMaxMixin
6 |
7 |
@override_settings(ROOT_URLCONF='tests_app.tests.functional.mixins.paginate_by_max_mixin.urls')
class PaginateByMaxMixinTest(TestCase):
    """Page sizes produced by the three pagination setups in ``views.py``.

    Thirty comments are created, so every expected size below (at most 20)
    is bounded by the page size and not by the number of rows.
    """

    def setUp(self):
        # ``created`` is an ``auto_now_add`` field: Django always sets it on
        # insert, so no value is passed for it here.
        for _ in range(30):
            CommentForPaginateByMaxMixin.objects.create(
                email='example@ya.ru',
                content='Hello world',
            )

    def test_default_page_size(self):
        resp = self.client.get('/comments/')
        self.assertEqual(len(resp.data['results']), 10)

    def test_custom_page_size__less_then_maximum(self):
        resp = self.client.get('/comments/?limit=15')
        self.assertEqual(len(resp.data['results']), 15)

    def test_custom_page_size__more_then_maximum(self):
        resp = self.client.get('/comments/?limit=25')
        self.assertEqual(len(resp.data['results']), 20)

    def test_custom_page_size_with_max_value(self):
        resp = self.client.get('/comments/?limit=max')
        self.assertEqual(len(resp.data['results']), 20)

    def test_custom_page_size_with_max_value__for_view_without__paginate_by_param__attribute(self):
        resp = self.client.get(
            '/comments-without-paginate-by-param-attribute/?page_size=max')
        self.assertEqual(len(resp.data['results']), 10)

    def test_custom_page_size_with_max_value__for_view_without__max_paginate_by__attribute(self):
        resp = self.client.get(
            '/comments-without-max-paginate-by-attribute/?page_size=max')
        self.assertEqual(len(resp.data['results']), 10)
44 |
45 |
@override_settings(ROOT_URLCONF='tests_app.tests.functional.mixins.paginate_by_max_mixin.urls')
class PaginateByMaxMixinTestBehavior__should_not_affect_view_if_DRF_does_not_supports__max_paginate_by(TestCase):
    """``?limit=`` handling of the max-capped pagination on ``/comments/``."""

    def setUp(self):
        # ``created`` is an ``auto_now_add`` field: Django always sets it on
        # insert, so no value is passed for it here.
        for _ in range(30):
            CommentForPaginateByMaxMixin.objects.create(
                email='example@ya.ru',
                content='Hello world',
            )

    def test_default_page_size(self):
        resp = self.client.get('/comments/')
        self.assertEqual(len(resp.data['results']), 10)

    def test_custom_page_size__less_then_maximum(self):
        resp = self.client.get('/comments/?limit=15')
        self.assertEqual(len(resp.data['results']), 15)

    def test_custom_page_size__more_then_maximum(self):
        resp = self.client.get('/comments/?limit=25')
        self.assertEqual(len(resp.data['results']), 20)

    def test_custom_page_size_with_max_value(self):
        resp = self.client.get('/comments/?limit=max')
        self.assertEqual(len(resp.data['results']), 20)
72 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/mixins/paginate_by_max_mixin/urls.py:
--------------------------------------------------------------------------------
1 | from rest_framework import routers
2 |
3 | from .views import (
4 | CommentViewSet,
5 | CommentWithoutPaginateByParamViewSet,
6 | CommentWithoutMaxPaginateByAttributeViewSet,
7 | )
8 |
9 |
viewset_router = routers.DefaultRouter()
# One route per pagination setup; all three viewsets serve the same model,
# so each needs its own explicit basename.
for _prefix, _viewset, _basename in (
    ('comments', CommentViewSet, '1'),
    ('comments-without-paginate-by-param-attribute', CommentWithoutPaginateByParamViewSet, '2'),
    ('comments-without-max-paginate-by-attribute', CommentWithoutMaxPaginateByAttributeViewSet, '3'),
):
    viewset_router.register(_prefix, _viewset, basename=_basename)
del _prefix, _viewset, _basename
urlpatterns = viewset_router.urls
15 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/mixins/paginate_by_max_mixin/views.py:
--------------------------------------------------------------------------------
1 | from rest_framework import viewsets
2 |
3 | from .pagination import WithMaxPagination, FixedPagination, FlexiblePagination
4 | from .models import CommentForPaginateByMaxMixin
5 | from .serializers import CommentSerializer
6 |
7 |
class CommentViewSet(viewsets.ReadOnlyModelViewSet):
    # Read-only comments paginated by WithMaxPagination (?limit=N / ?limit=max).
    queryset = CommentForPaginateByMaxMixin.objects.order_by('id')
    pagination_class = WithMaxPagination
    serializer_class = CommentSerializer
12 |
13 |
class CommentWithoutPaginateByParamViewSet(viewsets.ReadOnlyModelViewSet):
    # Fixed page size: the pagination class exposes no size query parameter.
    queryset = CommentForPaginateByMaxMixin.objects.all()
    pagination_class = FixedPagination
    serializer_class = CommentSerializer
18 |
19 |
class CommentWithoutMaxPaginateByAttributeViewSet(viewsets.ReadOnlyModelViewSet):
    # ?page_size= is honoured, but there is no maximum page size.
    queryset = CommentForPaginateByMaxMixin.objects.all()
    serializer_class = CommentSerializer
    pagination_class = FlexiblePagination
24 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/models.py:
--------------------------------------------------------------------------------
1 | # from .concurrency.conditional_request.models import *
2 | from .key_constructor.bits.models import *
3 | from .mixins.detail_serializer_mixin.models import *
4 | from .mixins.list_destroy_model_mixin.models import *
5 | from .mixins.list_update_model_mixin.models import *
6 | from .mixins.paginate_by_max_mixin.models import *
7 | from .permissions.extended_django_object_permissions.models import *
8 | from .routers.models import *
9 | from .routers.extended_default_router.models import *
10 | from .routers.nested_router_mixin.models import *
11 | from ._concurrency.conditional_request.models import *
12 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/permissions/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/permissions/extended_django_object_permissions/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/permissions/extended_django_object_permissions/models.py:
--------------------------------------------------------------------------------
1 | import django
2 | from django.db import models
3 |
4 |
class PermissionsComment(models.Model):
    """Object whose per-instance permissions the permission tests exercise."""
    text = models.CharField(max_length=100)

    class Meta:
        # Django 2.1+ creates the ``view_<model>`` permission itself; older
        # versions need it declared explicitly here.
        if django.VERSION < (2, 1):
            permissions = (
                ('view_permissionscomment', 'Can view comment'),
                # add, change, delete built in to django
            )
14 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/permissions/extended_django_object_permissions/tests.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from django.contrib.auth.models import User, Group, Permission
4 | from django.contrib.contenttypes.models import ContentType
5 | from django.test import override_settings
6 |
7 | from rest_framework import status
8 | from rest_framework.test import APITestCase
9 |
10 | from tests_app.testutils import basic_auth_header
11 | from .models import PermissionsComment
12 |
13 |
class ExtendedDjangoObjectPermissionTestMixin:
    """Fixture shared by the ExtendedDjangoObjectPermissions tests.

    Mix into a TestCase (listed before it). ``setUp`` creates four users, all
    holding model-level view/change/delete permissions on PermissionsComment,
    plus one comment with ``id=1`` and these object-level permissions on it:

    * ``fullaccess`` -- view, change and delete
    * ``readonly``   -- view
    * ``writeonly``  -- change
    * ``deleteonly`` -- delete

    ``self.credentials`` maps each username to a value for the
    ``HTTP_AUTHORIZATION`` header built by ``basic_auth_header``.
    """

    def setUp(self):
        # Imported here rather than at module level (presumably so the module
        # can be imported without django-guardian installed -- TODO confirm).
        from guardian.shortcuts import assign_perm

        super().setUp()

        password = 'password'
        users = {
            username: User.objects.create_user(
                username, '{0}@example.com'.format(username), password)
            for username in ('fullaccess', 'readonly', 'writeonly', 'deleteonly')
        }

        # Django < 2.1 has no built-in "view" permission (see the model's
        # Meta), so make sure it exists before it is assigned below.
        Permission.objects.get_or_create(
            codename='view_permissionscomment',
            content_type=ContentType.objects.get_for_model(PermissionsComment),
            defaults={'name': 'Can view comment'},
        )

        opts = PermissionsComment._meta
        perms = {
            action: '{0}_{1}'.format(action, opts.model_name)
            for action in ('view', 'change', 'delete')
        }

        # Everyone gets model-level permissions: only object-level checks
        # are under test.
        everyone = Group.objects.create(name='everyone')
        for codename in perms.values():
            assign_perm('{0}.{1}'.format(opts.app_label, codename), everyone)
        everyone.user_set.add(*users.values())

        # Object-level permissions on the single comment.
        readers = Group.objects.create(name='readers')
        writers = Group.objects.create(name='writers')
        deleters = Group.objects.create(name='deleters')

        comment = PermissionsComment.objects.create(text='foo', id=1)

        assign_perm(perms['view'], readers, comment)
        assign_perm(perms['change'], writers, comment)
        assign_perm(perms['delete'], deleters, comment)

        readers.user_set.add(users['fullaccess'], users['readonly'])
        writers.user_set.add(users['fullaccess'], users['writeonly'])
        deleters.user_set.add(users['fullaccess'], users['deleteonly'])

        self.credentials = {
            user.username: basic_auth_header(user.username, password)
            for user in users.values()
        }
67 |
68 |
@override_settings(ROOT_URLCONF='tests_app.tests.functional.permissions.extended_django_object_permissions.urls')
class ExtendedDjangoObjectPermissionsTest_should_inherit_standard(ExtendedDjangoObjectPermissionTestMixin,
                                                                  APITestCase):
    """Default behaviour: a user who cannot read an object gets 404 for it,
    while a reader lacking the requested permission gets 403."""

    def _auth(self, username):
        # Extra request environ carrying *username*'s basic-auth credentials.
        return {'HTTP_AUTHORIZATION': self.credentials[username]}

    def _patch_text(self, url, username):
        # PATCH the comment text to 'foobar' as *username*.
        return self.client.patch(
            url,
            content_type='application/json',
            data=json.dumps({'text': 'foobar'}),
            **self._auth(username)
        )

    # Delete
    def test_can_delete_permissions(self):
        response = self.client.delete('/comments/1/', **self._auth('deleteonly'))
        self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)

    def test_cannot_delete_permissions(self):
        response = self.client.delete('/comments/1/', **self._auth('readonly'))
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    # Update
    def test_can_update_permissions(self):
        response = self._patch_text('/comments/1/', 'writeonly')
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data.get('text'), 'foobar')

    def test_cannot_update_permissions(self):
        response = self._patch_text('/comments/1/', 'deleteonly')
        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

    def test_cannot_update_permissions_non_existing(self):
        response = self._patch_text('/comments/999/', 'deleteonly')
        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

    # Read
    def test_can_read_permissions(self):
        response = self.client.get('/comments/1/', **self._auth('readonly'))
        self.assertEqual(response.status_code, status.HTTP_200_OK)

    def test_cannot_read_permissions(self):
        response = self.client.get('/comments/1/', **self._auth('writeonly'))
        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

    # Read list
    def test_can_read_list_permissions(self):
        response = self.client.get('/comments-permission-filter-backend/', **self._auth('readonly'))
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data[0].get('id'), 1)

    def test_cannot_read_list_permissions(self):
        response = self.client.get('/comments-permission-filter-backend/', **self._auth('writeonly'))
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertListEqual(response.data, [])
150 |
151 |
@override_settings(ROOT_URLCONF='tests_app.tests.functional.permissions.extended_django_object_permissions.urls')
class ExtendedDjangoObjectPermissionsTest_without_hiding_forbidden_objects(ExtendedDjangoObjectPermissionTestMixin,
                                                                           APITestCase):
    """With forbidden objects not hidden, a missing permission yields 403
    even for users who cannot read the object; a missing object is 404."""

    def _auth(self, username):
        # Extra request environ carrying *username*'s basic-auth credentials.
        return {'HTTP_AUTHORIZATION': self.credentials[username]}

    def _patch_text(self, url, username):
        # PATCH the comment text to 'foobar' as *username*.
        return self.client.patch(
            url,
            content_type='application/json',
            data=json.dumps({'text': 'foobar'}),
            **self._auth(username)
        )

    # Delete
    def test_can_delete_permissions(self):
        response = self.client.delete(
            '/comments-without-hiding-forbidden-objects/1/', **self._auth('deleteonly'))
        self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)

    def test_cannot_delete_permissions(self):
        response = self.client.delete(
            '/comments-without-hiding-forbidden-objects/1/', **self._auth('readonly'))
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    # Update
    def test_can_update_permissions(self):
        response = self._patch_text('/comments-without-hiding-forbidden-objects/1/', 'writeonly')
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data.get('text'), 'foobar')

    def test_cannot_update_permissions(self):
        response = self._patch_text('/comments-without-hiding-forbidden-objects/1/', 'deleteonly')
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    def test_cannot_update_permissions_non_existing(self):
        response = self._patch_text('/comments-without-hiding-forbidden-objects/999/', 'deleteonly')
        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

    # Read
    def test_can_read_permissions(self):
        response = self.client.get(
            '/comments-without-hiding-forbidden-objects/1/', **self._auth('readonly'))
        self.assertEqual(response.status_code, status.HTTP_200_OK)

    def test_cannot_read_permissions(self):
        response = self.client.get(
            '/comments-without-hiding-forbidden-objects/1/', **self._auth('writeonly'))
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    # Read list
    def test_can_read_list_permissions(self):
        response = self.client.get(
            '/comments-without-hiding-forbidden-objects-permission-filter-backend/',
            **self._auth('readonly')
        )
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data[0].get('id'), 1)

    def test_cannot_read_list_permissions(self):
        response = self.client.get(
            '/comments-without-hiding-forbidden-objects-permission-filter-backend/',
            **self._auth('writeonly')
        )
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertListEqual(response.data, [])
237 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/permissions/extended_django_object_permissions/urls.py:
--------------------------------------------------------------------------------
1 | from rest_framework import routers
2 |
3 | from .views import (
4 | CommentViewSet,
5 | CommentViewSetPermissionFilterBackend,
6 | CommentViewSetWithoutHidingForbiddenObjects,
7 | CommentViewSetWithoutHidingForbiddenObjectsPermissionFilterBackend
8 | )
9 |
10 |
viewset_router = routers.DefaultRouter()
# All four viewsets serve PermissionsComment, so each needs its own basename.
for _prefix, _viewset, _basename in (
    ('comments', CommentViewSet, 'alt1'),
    ('comments-permission-filter-backend', CommentViewSetPermissionFilterBackend, 'alt2'),
    ('comments-without-hiding-forbidden-objects', CommentViewSetWithoutHidingForbiddenObjects, 'alt3'),
    (
        'comments-without-hiding-forbidden-objects-permission-filter-backend',
        CommentViewSetWithoutHidingForbiddenObjectsPermissionFilterBackend,
        'alt4',
    ),
):
    viewset_router.register(_prefix, _viewset, basename=_basename)
del _prefix, _viewset, _basename
urlpatterns = viewset_router.urls
21 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/permissions/extended_django_object_permissions/views.py:
--------------------------------------------------------------------------------
1 | from rest_framework import viewsets, serializers
2 | from rest_framework import authentication
3 |
# DjangoObjectPermissionsFilter lives in djangorestframework-guardian for
# newer DRF; fall back to the in-DRF location for older versions.
try:
    # djangorestframework >= 3.9
    from rest_framework_guardian.filters import DjangoObjectPermissionsFilter
except ImportError:
    from rest_framework.filters import DjangoObjectPermissionsFilter

# NOTE(review): if this import fails, an empty stand-in class is used so the
# module still imports, but the permission checks the tests rely on would then
# be missing -- confirm this fallback is still needed.
try:
    from rest_framework_extensions.permissions import ExtendedDjangoObjectPermissions
except ImportError:
    class ExtendedDjangoObjectPermissions:
        pass
16 | from .models import PermissionsComment
17 |
18 |
class CommentObjectPermissions(ExtendedDjangoObjectPermissions):
    """Require the matching model permission for every HTTP method,
    including ``view`` for the read-only methods."""

    # HTTP method -> required permission templates. DRF fills in
    # %(app_label)s / %(model_name)s from the view's model.
    perms_map = {
        method: ['%(app_label)s.' + action + '_%(model_name)s']
        for method, action in (
            ('GET', 'view'),
            ('OPTIONS', 'view'),
            ('HEAD', 'view'),
            ('POST', 'add'),
            ('PUT', 'change'),
            ('PATCH', 'change'),
            ('DELETE', 'delete'),
        )
    }
29 |
30 |
class PermissionsCommentSerializer(serializers.ModelSerializer):
    """Serialize every field of ``PermissionsComment``."""

    class Meta:
        fields = '__all__'
        model = PermissionsComment
35 |
36 |
class CommentObjectPermissionsWithoutHidingForbiddenObjects(CommentObjectPermissions):
    """Same permission map, but objects the user cannot read are not hidden:
    such requests get 403 Forbidden instead of 404 Not Found."""
    hide_forbidden_for_read_objects = False
39 |
40 |
class CommentViewSet(viewsets.ModelViewSet):
    # Basic-auth comment CRUD guarded by object-level permissions.
    permission_classes = [CommentObjectPermissions]
    authentication_classes = [authentication.BasicAuthentication]
    serializer_class = PermissionsCommentSerializer
    queryset = PermissionsComment.objects.all()
46 |
47 |
class CommentViewSetPermissionFilterBackend(CommentViewSet):
    # The list only contains objects the user holds view permission on.
    filter_backends = [DjangoObjectPermissionsFilter]
50 |
51 |
class CommentViewSetWithoutHidingForbiddenObjects(CommentViewSet):
    # Forbidden objects answer 403 rather than being hidden behind a 404.
    permission_classes = [CommentObjectPermissionsWithoutHidingForbiddenObjects]
54 |
55 |
class CommentViewSetWithoutHidingForbiddenObjectsPermissionFilterBackend(CommentViewSetWithoutHidingForbiddenObjects):
    # Non-hiding permissions, plus a list filtered down to viewable objects.
    filter_backends = [DjangoObjectPermissionsFilter]
--------------------------------------------------------------------------------
/tests_app/tests/functional/routers/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/routers/extended_default_router/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/routers/extended_default_router/models.py:
--------------------------------------------------------------------------------
1 | from django.db import models
2 |
3 |
class DefaultRouterUserModel(models.Model):
    """User fixture for the ExtendedDefaultRouter index-page test."""
    name = models.CharField(max_length=10)
    # ``user_groups`` is the reverse lookup the nested users/groups route uses.
    groups = models.ManyToManyField('DefaultRouterGroupModel', related_name='user_groups')
7 |
8 |
class DefaultRouterGroupModel(models.Model):
    """Group fixture; nested under users and holding permissions."""
    name = models.CharField(max_length=10)
    permissions = models.ManyToManyField('DefaultRouterPermissionModel')
12 |
13 |
class DefaultRouterPermissionModel(models.Model):
    """Permission fixture, the innermost level of the nested routes."""
    name = models.CharField(max_length=10)
16 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/routers/extended_default_router/tests.py:
--------------------------------------------------------------------------------
1 | from django.test import override_settings
2 | from django.urls import NoReverseMatch
3 |
4 | from rest_framework.test import APITestCase
5 |
6 |
@override_settings(ROOT_URLCONF='tests_app.tests.functional.routers.extended_default_router.urls')
class ExtendedDefaultRouterTestBehaviour(APITestCase):
    """The ExtendedDefaultRouter API root lists only the top-level routes."""

    def test_index_page(self):
        expected = {
            'users': 'http://testserver/users/',
            'groups': 'http://testserver/groups/',
            'permissions': 'http://testserver/permissions/',
        }
        try:
            response = self.client.get('/')
        except NoReverseMatch:
            # Regression check for issue #14: nested routes must be skipped on
            # the index page instead of breaking DefaultRouter's reversing.
            self.fail(
                'DefaultRouter tries to reverse nested routes and breaks with error. NoReverseMatch should be '
                'handled for nested routes. They must be excluded from index page. '
                'https://github.com/chibisov/drf-extensions/issues/14'
            )
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.data, expected)
25 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/routers/extended_default_router/urls.py:
--------------------------------------------------------------------------------
1 | from rest_framework_extensions.routers import ExtendedDefaultRouter
2 |
3 | from .views import (
4 | UserViewSet,
5 | GroupViewSet,
6 | PermissionViewSet,
7 | )
8 |
9 |
router = ExtendedDefaultRouter()
# nested routes
# users -> groups of a user -> permissions of such a group. Each ``.register``
# on the returned object nests the next viewset under the previous route;
# ``parents_query_lookups`` presumably name the lookups used to filter the
# child queryset by the parent ids (see NestedViewSetMixin) -- confirm.
# NOTE(review): the permission model has no field or related_name called
# ``group``; confirm these lookups resolve if the nested permission routes are
# ever requested (this package's test only fetches the API root).
(
    router.register(r'users', UserViewSet)
    .register(r'groups', GroupViewSet, 'users-group', parents_query_lookups=['user_groups'])
    .register(r'permissions', PermissionViewSet, 'users-groups-permission', parents_query_lookups=['group__user', 'group'])
)
# simple routes
# Together with ``users`` these are the entries the API root is expected to list.
router.register(r'groups', GroupViewSet, 'group')
router.register(r'permissions', PermissionViewSet, 'permission')

urlpatterns = router.urls
22 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/routers/extended_default_router/views.py:
--------------------------------------------------------------------------------
1 | from rest_framework.viewsets import ModelViewSet
2 |
3 | from rest_framework_extensions.mixins import NestedViewSetMixin
4 |
5 | from .models import (
6 | DefaultRouterUserModel,
7 | DefaultRouterGroupModel,
8 | DefaultRouterPermissionModel,
9 | )
10 |
11 |
class UserViewSet(NestedViewSetMixin, ModelViewSet):
    # No serializer_class: this package's test only requests the API root.
    queryset = DefaultRouterUserModel.objects.all()
14 |
15 |
class GroupViewSet(NestedViewSetMixin, ModelViewSet):
    # No serializer_class: this package's test only requests the API root.
    queryset = DefaultRouterGroupModel.objects.all()
18 |
19 |
class PermissionViewSet(NestedViewSetMixin, ModelViewSet):
    # No serializer_class: this package's test only requests the API root.
    queryset = DefaultRouterPermissionModel.objects.all()
22 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/routers/models.py:
--------------------------------------------------------------------------------
1 | from django.db import models
2 |
3 |
class RouterTestModel(models.Model):
    """Simple model with a string ``uuid`` and free text, for router tests."""
    uuid = models.CharField(max_length=20)
    text = models.CharField(max_length=200)
--------------------------------------------------------------------------------
/tests_app/tests/functional/routers/nested_router_mixin/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/routers/nested_router_mixin/models.py:
--------------------------------------------------------------------------------
1 | import uuid
2 |
3 | from django.db import models
4 | from django.contrib.contenttypes.fields import GenericForeignKey
5 |
6 |
class NestedRouterMixinUserModel(models.Model):
    """User fixture for the nested-router tests."""
    email = models.EmailField(blank=True, null=True)
    # Callable default: every new row gets a fresh uuid4.
    code = models.UUIDField(default=uuid.uuid4)
    name = models.CharField(max_length=10)
    # ``user_groups`` is the reverse lookup the nested users/groups route uses.
    groups = models.ManyToManyField(
        'NestedRouterMixinGroupModel', related_name='user_groups')
13 |
14 |
class NestedRouterMixinGroupModel(models.Model):
    """Group fixture; ``permissions`` backs the permissions/groups lookup."""
    name = models.CharField(max_length=10)
    permissions = models.ManyToManyField('NestedRouterMixinPermissionModel')
18 |
19 |
class NestedRouterMixinPermissionModel(models.Model):
    """Permission fixture for the nested-router tests."""
    name = models.CharField(max_length=10)
22 |
23 |
class NestedRouterMixinTaskModel(models.Model):
    """Task fixture for the nested-router tests."""
    title = models.CharField(max_length=30)
26 |
27 |
class NestedRouterMixinBookModel(models.Model):
    """Book fixture for the nested-router tests."""
    title = models.CharField(max_length=30)
30 |
31 |
class NestedRouterMixinCommentModel(models.Model):
    """Comment attachable to any model instance through a generic relation."""
    content_type = models.ForeignKey(
        "contenttypes.ContentType",
        blank=True,
        null=True,
        on_delete=models.CASCADE,
    )
    object_id = models.PositiveIntegerField(blank=True, null=True)
    # Uses GenericForeignKey's default field names: content_type / object_id.
    content_object = GenericForeignKey()
    text = models.CharField(max_length=30)
42 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/routers/nested_router_mixin/serializers.py:
--------------------------------------------------------------------------------
1 | from rest_framework import serializers
2 |
3 | from .models import (
4 | NestedRouterMixinUserModel as UserModel,
5 | NestedRouterMixinGroupModel as GroupModel,
6 | NestedRouterMixinPermissionModel as PermissionModel,
7 | NestedRouterMixinTaskModel as TaskModel,
8 | NestedRouterMixinBookModel as BookModel,
9 | NestedRouterMixinCommentModel as CommentModel,
10 | )
11 |
12 |
class UserSerializer(serializers.ModelSerializer):
    """Serialize a nested-router user as ``id`` and ``name``."""

    class Meta:
        fields = ('id', 'name')
        model = UserModel
20 |
21 |
class GroupSerializer(serializers.ModelSerializer):
    """Serialize a nested-router group as ``id`` and ``name``."""

    class Meta:
        fields = ('id', 'name')
        model = GroupModel
29 |
30 |
class PermissionSerializer(serializers.ModelSerializer):
    """Serialize a nested-router permission as ``id`` and ``name``."""

    class Meta:
        fields = ('id', 'name')
        model = PermissionModel
38 |
39 |
class TaskSerializer(serializers.ModelSerializer):
    # Exposes only the primary key and the title of a task.
    class Meta:
        model = TaskModel
        fields = ('id', 'title')
47 |
48 |
class BookSerializer(serializers.ModelSerializer):
    # Exposes only the primary key and the title of a book.
    class Meta:
        model = BookModel
        fields = ('id', 'title')
56 |
57 |
class CommentSerializer(serializers.ModelSerializer):
    # Serializes the raw generic-relation columns (content_type, object_id)
    # rather than the resolved `content_object`.
    class Meta:
        model = CommentModel
        fields = ('id', 'content_type', 'object_id', 'text')
--------------------------------------------------------------------------------
/tests_app/tests/functional/routers/nested_router_mixin/urls.py:
--------------------------------------------------------------------------------
1 | from rest_framework_extensions.routers import ExtendedSimpleRouter
2 |
3 | from .views import (
4 | UserViewSet,
5 | GroupViewSet,
6 | PermissionViewSet,
7 | )
8 |
9 |
router = ExtendedSimpleRouter()

# main routes: users -> groups -> permissions, three levels deep
users_routes = router.register(r'users', UserViewSet)
users_groups_routes = users_routes.register(
    r'groups', GroupViewSet, 'users-group',
    parents_query_lookups=['user_groups'],
)
# NOTE(review): GroupModel.permissions declares no related_name, so Django's default
# reverse query name on PermissionModel would be `nestedroutermixingroupmodel`,
# not `group`; confirm these lookups resolve if the permission routes are hit.
users_groups_routes.register(
    r'permissions', PermissionViewSet, 'users-groups-permission',
    parents_query_lookups=['group__user', 'group'],
)

# register on one depth: two independent children hang off the permissions root
permissions_routes = router.register(r'permissions', PermissionViewSet)
permissions_routes.register(
    r'groups', GroupViewSet, 'permissions-group',
    parents_query_lookups=['permissions'],
)
permissions_routes.register(
    r'users', UserViewSet, 'permissions-user',
    parents_query_lookups=['groups__permissions'],
)

# simple (non-nested) routes with explicit basenames
router.register(r'groups', GroupViewSet, 'group')
router.register(r'permissions', PermissionViewSet, 'permission')

urlpatterns = router.urls
28 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/routers/nested_router_mixin/urls_generic_relations.py:
--------------------------------------------------------------------------------
1 | from rest_framework_extensions.routers import ExtendedSimpleRouter
2 |
3 | from .views import (
4 | TaskViewSet,
5 | TaskCommentViewSet,
6 | BookViewSet,
7 | BookCommentViewSet
8 | )
9 |
10 |
router = ExtendedSimpleRouter()

# tasks route: comments nested under a task, filtered by the comment's object_id
tasks_routes = router.register(r'tasks', TaskViewSet)
tasks_routes.register(
    r'comments', TaskCommentViewSet, 'tasks-comment',
    parents_query_lookups=['object_id'],
)

# books route: same shape as tasks, but with the book-specific comment viewset
books_routes = router.register(r'books', BookViewSet)
books_routes.register(
    r'comments', BookCommentViewSet, 'books-comment',
    parents_query_lookups=['object_id'],
)

urlpatterns = router.urls
24 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/routers/nested_router_mixin/urls_parent_viewset_lookup.py:
--------------------------------------------------------------------------------
1 | from rest_framework_extensions.routers import ExtendedSimpleRouter
2 |
3 | from .views import (
4 | UserViewSetWithEmailLookup,
5 | UserViewSetWithUUIDLookup,
6 | GroupViewSet,
7 | )
8 |
9 |
router = ExtendedSimpleRouter()

# main routes: users looked up by e-mail, groups nested under each user.
# NOTE(review): this e-mail route uses basename 'users-by-uuid' while the UUID
# route below derives its own basename; looks like a copy-paste name, but tests
# may reverse it, so it is kept as-is.
email_users_routes = router.register(
    r'users', UserViewSetWithEmailLookup, basename='users-by-uuid',
)
email_users_routes.register(
    r'groups', GroupViewSet, 'users-group',
    parents_query_lookups=['user_groups__email'],
)

# uuid routes: users looked up by their `code` UUID, groups nested under each user
uuid_users_routes = router.register(r'users-by-uuid', UserViewSetWithUUIDLookup)
uuid_users_routes.register(
    r'groups', GroupViewSet, 'users-group-uuid',
    parents_query_lookups=['user_groups__code'],
)

urlpatterns = router.urls
25 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/routers/nested_router_mixin/views.py:
--------------------------------------------------------------------------------
1 | from django.contrib.contenttypes.models import ContentType
2 | from django.core.exceptions import ValidationError
3 | from django.http import Http404
4 |
5 | from rest_framework.decorators import action
6 | from rest_framework.response import Response
7 | from rest_framework.viewsets import ModelViewSet
8 | import uuid
9 | from rest_framework_extensions.mixins import NestedViewSetMixin
10 |
11 | from .models import (
12 | NestedRouterMixinUserModel as UserModel,
13 | NestedRouterMixinGroupModel as GroupModel,
14 | NestedRouterMixinPermissionModel as PermissionModel,
15 | NestedRouterMixinTaskModel as TaskModel,
16 | NestedRouterMixinBookModel as BookModel,
17 | NestedRouterMixinCommentModel as CommentModel
18 | )
19 | from .serializers import (
20 | UserSerializer,
21 | GroupSerializer,
22 | PermissionSerializer,
23 | TaskSerializer,
24 | BookSerializer,
25 | CommentSerializer
26 | )
27 |
28 |
class UserViewSet(NestedViewSetMixin, ModelViewSet):
    # CRUD over users plus two POST-only extra actions that echo a fixed string.
    serializer_class = UserSerializer
    queryset = UserModel.objects.all()

    @action(methods=['post'], detail=False, url_path='users-list-action')
    def users_list_action(self, request, *args, **kwargs):
        # Collection-level action (no object lookup in the URL).
        payload = 'users list action'
        return Response(payload)

    @action(methods=['post'], detail=True, url_path='users-action')
    def users_action(self, request, *args, **kwargs):
        # Object-level action (URL includes the user lookup).
        payload = 'users action'
        return Response(payload)
40 |
41 |
class GroupViewSet(NestedViewSetMixin, ModelViewSet):
    # CRUD over groups plus two GET extra actions (DRF's default method set).
    serializer_class = GroupSerializer
    queryset = GroupModel.objects.all()

    @action(url_path='groups-list-link', detail=False)
    def groups_list_link(self, request, *args, **kwargs):
        # Collection-level link.
        payload = 'groups list link'
        return Response(payload)

    @action(url_path='groups-link', detail=True)
    def groups_link(self, request, *args, **kwargs):
        # Object-level link.
        payload = 'groups link'
        return Response(payload)
53 |
54 |
class PermissionViewSet(NestedViewSetMixin, ModelViewSet):
    # CRUD over permissions plus two POST-only extra actions that echo a fixed string.
    serializer_class = PermissionSerializer
    queryset = PermissionModel.objects.all()

    @action(methods=['post'], detail=False, url_path='permissions-list-action')
    def permissions_list_action(self, request, *args, **kwargs):
        # Collection-level action.
        payload = 'permissions list action'
        return Response(payload)

    @action(methods=['post'], detail=True, url_path='permissions-action')
    def permissions_action(self, request, *args, **kwargs):
        # Object-level action.
        payload = 'permissions action'
        return Response(payload)
66 |
67 |
class TaskViewSet(NestedViewSetMixin, ModelViewSet):
    # Plain CRUD over tasks; used as the parent of nested task-comment routes.
    serializer_class = TaskSerializer
    queryset = TaskModel.objects.all()
71 |
72 |
class BookViewSet(NestedViewSetMixin, ModelViewSet):
    # Plain CRUD over books; used as the parent of nested book-comment routes.
    serializer_class = BookSerializer
    queryset = BookModel.objects.all()
76 |
77 |
class CommentViewSet(NestedViewSetMixin, ModelViewSet):
    # CRUD over all comments regardless of what they are attached to;
    # subclasses narrow the queryset to one content type.
    serializer_class = CommentSerializer
    queryset = CommentModel.objects.all()
81 |
82 |
class TaskCommentViewSet(CommentViewSet):
    # Comments whose generic relation targets a task.
    def get_queryset(self):
        comments = super().get_queryset()
        task_type = ContentType.objects.get_for_model(TaskModel)
        return comments.filter(content_type=task_type)
88 |
89 |
class BookCommentViewSet(CommentViewSet):
    # Comments whose generic relation targets a book.
    def get_queryset(self):
        comments = super().get_queryset()
        book_type = ContentType.objects.get_for_model(BookModel)
        return comments.filter(content_type=book_type)
95 |
96 |
class UserViewSetWithEmailLookup(NestedViewSetMixin, ModelViewSet):
    # Users addressed by e-mail. The custom lookup regex accepts "." and "@",
    # which DRF's default `[^/.]+` lookup pattern would reject.
    serializer_class = UserSerializer
    queryset = UserModel.objects.all()
    lookup_value_regex = r'[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+'
    lookup_field = 'email'
102 |
103 |
class UserViewSetWithUUIDLookup(NestedViewSetMixin, ModelViewSet):
    # Users addressed by their `code` field, which the URL must supply as a UUID.
    queryset = UserModel.objects.all()
    serializer_class = UserSerializer
    lookup_field = 'code'

    def get_object(self):
        # Return the user for the URL's lookup value, answering 404 (not 500)
        # when that value is not a well-formed UUID.
        #
        # Raises:
        #     Http404: the lookup value is malformed, the ORM rejects it, or no
        #         matching object exists.
        lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
        if lookup_url_kwarg in self.kwargs:
            # Validate before querying so a bad value never reaches the database.
            try:
                uuid.UUID(str(self.kwargs[lookup_url_kwarg]))
            except ValueError:
                raise Http404 from None
        try:
            return super().get_object()
        except (ValueError, ValidationError):
            # The ORM may still reject the value while coercing it for the lookup.
            raise Http404 from None
121 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/routers/tests.py:
--------------------------------------------------------------------------------
1 | from django.test import TestCase
2 |
3 | from rest_framework_extensions.routers import ExtendedSimpleRouter
4 |
5 | from tests_app.testutils import get_url_pattern_by_regex_pattern
6 | from .views import RouterViewSet
7 |
8 |
class TestTrailingSlashIncluded(TestCase):
    def test_urls_have_trailing_slash_by_default(self):
        # A default ExtendedSimpleRouter must end every generated pattern with "/".
        router = ExtendedSimpleRouter()
        router.register(r'router-viewset', RouterViewSet)
        urls = router.urls

        # `(?P<pk>[^/.]+)` is the detail lookup group DRF generates for a
        # ModelViewSet. Without the `<pk>` name the expected string could never
        # match a generated pattern.
        for exp in ['^router-viewset/$',
                    '^router-viewset/(?P<pk>[^/.]+)/$',
                    '^router-viewset/list_controller/$',
                    '^router-viewset/(?P<pk>[^/.]+)/detail_controller/$']:
            msg = 'Should find url pattern with regexp %s' % exp
            self.assertIsNotNone(get_url_pattern_by_regex_pattern(urls, exp), msg=msg)
21 |
22 |
class TestTrailingSlashRemoved(TestCase):
    def test_urls_can_have_trailing_slash_removed(self):
        # With trailing_slash=False no generated pattern may end with "/".
        router = ExtendedSimpleRouter(trailing_slash=False)
        router.register(r'router-viewset', RouterViewSet)
        urls = router.urls

        # `(?P<pk>[^/.]+)` is the detail lookup group DRF generates for a
        # ModelViewSet. Without the `<pk>` name the expected string could never
        # match a generated pattern.
        for exp in ['^router-viewset$',
                    '^router-viewset/(?P<pk>[^/.]+)$',
                    '^router-viewset/list_controller$',
                    '^router-viewset/(?P<pk>[^/.]+)/detail_controller$']:
            msg = 'Should find url pattern with regexp %s' % exp
            self.assertIsNotNone(get_url_pattern_by_regex_pattern(urls, exp), msg=msg)
35 |
--------------------------------------------------------------------------------
/tests_app/tests/functional/routers/views.py:
--------------------------------------------------------------------------------
1 | from rest_framework import viewsets
2 | from rest_framework.decorators import action
3 |
4 | from .models import RouterTestModel
5 |
6 |
class RouterViewSet(viewsets.ModelViewSet):
    # Fixture viewset whose generated URL patterns the router tests inspect.
    queryset = RouterTestModel.objects.all()

    @action(detail=True)
    def detail_controller(self, request, *args, **kwargs):
        # Object-level extra action. DRF passes `request` to every handler, and
        # the handler must return a Response, so both are required for it to be
        # dispatchable.
        from rest_framework.response import Response
        return Response()

    @action(detail=False)
    def list_controller(self, request, *args, **kwargs):
        # Collection-level extra action; see detail_controller for the signature.
        from rest_framework.response import Response
        return Response()
17 |
--------------------------------------------------------------------------------
/tests_app/tests/unit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/tests/unit/__init__.py
--------------------------------------------------------------------------------
/tests_app/tests/unit/_etag/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/tests/unit/_etag/__init__.py
--------------------------------------------------------------------------------
/tests_app/tests/unit/_etag/decorators/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests_app/tests/unit/cache/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/tests/unit/cache/__init__.py
--------------------------------------------------------------------------------
/tests_app/tests/unit/cache/decorators/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests_app/tests/unit/decorators/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/tests/unit/decorators/__init__.py
--------------------------------------------------------------------------------
/tests_app/tests/unit/decorators/tests.py:
--------------------------------------------------------------------------------
1 | from django.test import TestCase
2 | from rest_framework import pagination, viewsets
3 | from rest_framework_extensions.decorators import paginate
4 |
5 |
class TestPaginateDecorator(TestCase):
    # Uses TestCase assertion methods rather than bare `assert`, so the checks
    # still run under `python -O`.

    def test_empty_pagination_class(self):
        msg = "@paginate missing required argument: 'pagination_class'"
        with self.assertRaisesMessage(AssertionError, msg):
            @paginate()
            class MockGenericViewSet(viewsets.GenericViewSet):
                pass

    def test_adding_page_number_pagination(self):
        """
        Other default pagination classes' test result will be same as this even if kwargs changed to anything.
        """

        @paginate(pagination_class=pagination.PageNumberPagination, page_size=5, ordering='-created_at')
        class MockGenericViewSet(viewsets.GenericViewSet):
            pass

        # GenericViewSet always defines `pagination_class`, so a hasattr check
        # would be vacuous. Verify the installed paginator really is the
        # requested kind.
        paginator = MockGenericViewSet.pagination_class()
        self.assertIsInstance(paginator, pagination.PageNumberPagination)
        self.assertEqual(paginator.page_size, 5)
        self.assertEqual(paginator.ordering, '-created_at')

    def test_adding_custom_pagination(self):
        class CustomPagination(pagination.BasePagination):
            pass

        @paginate(pagination_class=CustomPagination, kwarg1='kwarg1', kwarg2='kwarg2')
        class MockGenericViewSet(viewsets.GenericViewSet):
            pass

        paginator = MockGenericViewSet.pagination_class()
        self.assertIsInstance(paginator, CustomPagination)
        self.assertEqual(paginator.kwarg1, 'kwarg1')
        self.assertEqual(paginator.kwarg2, 'kwarg2')
39 |
--------------------------------------------------------------------------------
/tests_app/tests/unit/key_constructor/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests_app/tests/unit/key_constructor/bits/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests_app/tests/unit/key_constructor/bits/models.py:
--------------------------------------------------------------------------------
1 | from django.db import models
2 |
3 |
class BitTestModel(models.Model):
    # Minimal model with a single boolean flag, used by key-constructor bit tests.
    is_active = models.BooleanField(
        default=False,
    )
6 |
--------------------------------------------------------------------------------
/tests_app/tests/unit/key_constructor/constructor/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests_app/tests/unit/migrations/0001_initial.py:
--------------------------------------------------------------------------------
1 | # Generated by Django 2.2 on 2019-04-16 11:51
2 |
3 | from django.db import migrations, models
4 | import django.db.models.deletion
5 |
6 |
class Migration(migrations.Migration):
    # Initial schema for the `unit` test app (generated by Django 2.2 per the file
    # header). Creates BitTestModel, the three NestedRouterMixin* models, UserModel
    # and CommentModel. Generated file: edit via makemigrations, not by hand.

    initial = True

    # First migration of the app, so there is nothing to depend on.
    dependencies = [
    ]

    operations = [
        migrations.CreateModel(
            name='BitTestModel',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('is_active', models.BooleanField(default=False)),
            ],
        ),
        migrations.CreateModel(
            name='NestedRouterMixinGroupModel',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('name', models.CharField(max_length=10)),
            ],
        ),
        migrations.CreateModel(
            name='NestedRouterMixinPermissionModel',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('name', models.CharField(max_length=10)),
            ],
        ),
        migrations.CreateModel(
            name='UserModel',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('name', models.CharField(max_length=20)),
            ],
        ),
        migrations.CreateModel(
            name='NestedRouterMixinUserModel',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('name', models.CharField(max_length=10)),
                # Reverse accessor from groups is `user_groups`.
                ('groups', models.ManyToManyField(related_name='user_groups', to='unit.NestedRouterMixinGroupModel')),
            ],
        ),
        # The group -> permission many-to-many is added as a separate operation,
        # after both models exist.
        migrations.AddField(
            model_name='nestedroutermixingroupmodel',
            name='permissions',
            field=models.ManyToManyField(to='unit.NestedRouterMixinPermissionModel'),
        ),
        migrations.CreateModel(
            name='CommentModel',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('title', models.CharField(max_length=20)),
                ('text', models.CharField(max_length=200)),
                ('attachment', models.FileField(blank=True, max_length=500, null=True, upload_to='test_serializers')),
                ('hidden_text', models.CharField(blank=True, max_length=200, null=True)),
                # Deleting a UserModel cascades to its comments.
                ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='comments', to='unit.UserModel')),
                ('users_liked', models.ManyToManyField(blank=True, to='unit.UserModel')),
            ],
        ),
    ]
69 |
--------------------------------------------------------------------------------
/tests_app/tests/unit/migrations/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/tests/unit/migrations/__init__.py
--------------------------------------------------------------------------------
/tests_app/tests/unit/models.py:
--------------------------------------------------------------------------------
1 | from .key_constructor.bits.models import *
2 | from .routers.nested_router_mixin.models import *
3 | from .serializers.models import *
4 |
--------------------------------------------------------------------------------
/tests_app/tests/unit/routers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chibisov/drf-extensions/6f8db9763bf6ff100287d9c920159938ce534f9d/tests_app/tests/unit/routers/__init__.py
--------------------------------------------------------------------------------
/tests_app/tests/unit/routers/nested_router_mixin/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests_app/tests/unit/routers/nested_router_mixin/models.py:
--------------------------------------------------------------------------------
1 | from django.db import models
2 |
3 |
class NestedRouterMixinPermissionModel(models.Model):
    # A named permission; referenced by NestedRouterMixinGroupModel.permissions.
    name = models.CharField(
        max_length=10,
    )
6 |
7 |
class NestedRouterMixinGroupModel(models.Model):
    # A named group holding a many-to-many set of permissions.
    name = models.CharField(max_length=10)
    permissions = models.ManyToManyField('NestedRouterMixinPermissionModel')
12 |
13 |
class NestedRouterMixinUserModel(models.Model):
    # A named user belonging to many groups; groups reach users via `user_groups`.
    name = models.CharField(max_length=10)
    groups = models.ManyToManyField(
        'NestedRouterMixinGroupModel',
        related_name='user_groups',
    )
18 |
--------------------------------------------------------------------------------
/tests_app/tests/unit/routers/nested_router_mixin/tests.py:
--------------------------------------------------------------------------------
1 | from rest_framework.test import APITestCase
2 | from rest_framework_extensions.routers import ExtendedSimpleRouter
3 | from rest_framework_extensions.utils import compose_parent_pk_kwarg_name
4 | from .views import (
5 | UserViewSet,
6 | GroupViewSet,
7 | PermissionViewSet,
8 | CustomRegexUserViewSet,
9 | CustomRegexGroupViewSet,
10 | CustomRegexPermissionViewSet,
11 | )
12 |
13 |
def get_regex_pattern(urlpattern):
    """Return the source string of the compiled regex behind a URL pattern object."""
    compiled = urlpattern.pattern.regex
    return compiled.pattern
16 |
17 |
class NestedRouterMixinTest(APITestCase):
    """Check URL names and regexes produced by ExtendedSimpleRouter for nested registrations."""

    def get_lookup_regex(self, value):
        """Return the named group for lookup kwarg ``value`` with DRF's default value regex."""
        return '(?P<{name}>[^/.]+)'.format(name=value)

    def get_parent_lookup_regex(self, value):
        """Return the default-regex named group for the parent lookup kwarg derived from ``value``."""
        return self.get_lookup_regex(compose_parent_pk_kwarg_name(value))

    def get_custom_regex_lookup(self, pk_kwarg_name, lookup_value_regex):
        """ Build lookup regex with custom regular expression. """
        return '(?P<{name}>{regex})'.format(name=pk_kwarg_name, regex=lookup_value_regex)

    def get_custom_regex_parent_lookup(self, parent_pk_kwarg_name,
                                       parent_lookup_value_regex):
        """ Build parent lookup regex with custom regular expression. """
        parent_kwarg = compose_parent_pk_kwarg_name(parent_pk_kwarg_name)
        return self.get_custom_regex_lookup(parent_kwarg, parent_lookup_value_regex)

    def _assert_routes(self, router, expected_routes):
        """Assert the router's URLs start with ``expected_routes`` ((name, regex) pairs) in order."""
        urls = router.urls
        for position, (expected_name, expected_regex) in enumerate(expected_routes):
            self.assertEqual(urls[position].name, expected_name)
            self.assertEqual(get_regex_pattern(urls[position]), expected_regex)

    def test_one_route(self):
        router = ExtendedSimpleRouter()
        router.register(r'users', UserViewSet, 'user')

        pk = self.get_lookup_regex('pk')
        self._assert_routes(router, [
            ('user-list', r'^users/$'),
            ('user-detail', r'^users/{0}/$'.format(pk)),
        ])

    def test_nested_route(self):
        router = ExtendedSimpleRouter()
        users = router.register(r'users', UserViewSet, 'user')
        users.register(r'groups', GroupViewSet, 'users-group', parents_query_lookups=['user'])

        pk = self.get_lookup_regex('pk')
        parent_user = self.get_parent_lookup_regex('user')
        self._assert_routes(router, [
            ('user-list', r'^users/$'),
            ('user-detail', r'^users/{0}/$'.format(pk)),
            ('users-group-list', r'^users/{0}/groups/$'.format(parent_user)),
            ('users-group-detail', r'^users/{0}/groups/{1}/$'.format(parent_user, pk)),
        ])

    def test_nested_route_depth_3(self):
        router = ExtendedSimpleRouter()
        users = router.register(r'users', UserViewSet, 'user')
        groups = users.register(r'groups', GroupViewSet, 'users-group', parents_query_lookups=['user'])
        groups.register(r'permissions', PermissionViewSet, 'users-groups-permission',
                        parents_query_lookups=['group__user', 'group'])

        pk = self.get_lookup_regex('pk')
        parent_user = self.get_parent_lookup_regex('user')
        parent_group_user = self.get_parent_lookup_regex('group__user')
        parent_group = self.get_parent_lookup_regex('group')
        self._assert_routes(router, [
            ('user-list', r'^users/$'),
            ('user-detail', r'^users/{0}/$'.format(pk)),
            ('users-group-list', r'^users/{0}/groups/$'.format(parent_user)),
            ('users-group-detail', r'^users/{0}/groups/{1}/$'.format(parent_user, pk)),
            ('users-groups-permission-list',
             r'^users/{0}/groups/{1}/permissions/$'.format(parent_group_user, parent_group)),
            ('users-groups-permission-detail',
             r'^users/{0}/groups/{1}/permissions/{2}/$'.format(parent_group_user, parent_group, pk)),
        ])

    def test_nested_route_depth_3_custom_regex(self):
        """
        Nested routes with over two level of depth should respect all parents'
        `lookup_value_regex` attribute.
        """
        router = ExtendedSimpleRouter()
        users = router.register(r'users', CustomRegexUserViewSet, 'user')
        groups = users.register(r'groups', CustomRegexGroupViewSet, 'users-group',
                                parents_query_lookups=['user'])
        groups.register(r'permissions', CustomRegexPermissionViewSet,
                        'users-groups-permission',
                        parents_query_lookups=['group__user', 'group'])

        # custom regex configuration
        user_regex = CustomRegexUserViewSet.lookup_value_regex
        group_regex = CustomRegexGroupViewSet.lookup_value_regex
        perm_regex = CustomRegexPermissionViewSet.lookup_value_regex

        parent_user = self.get_custom_regex_parent_lookup('user', user_regex)
        parent_group_user = self.get_custom_regex_parent_lookup('group__user', user_regex)
        parent_group = self.get_custom_regex_parent_lookup('group', group_regex)
        self._assert_routes(router, [
            ('user-list', r'^users/$'),
            ('user-detail',
             r'^users/{0}/$'.format(self.get_custom_regex_lookup('pk', user_regex))),
            ('users-group-list', r'^users/{0}/groups/$'.format(parent_user)),
            ('users-group-detail',
             r'^users/{0}/groups/{1}/$'.format(
                 parent_user, self.get_custom_regex_lookup('pk', group_regex))),
            ('users-groups-permission-list',
             r'^users/{0}/groups/{1}/permissions/$'.format(parent_group_user, parent_group)),
            ('users-groups-permission-detail',
             r'^users/{0}/groups/{1}/permissions/{2}/$'.format(
                 parent_group_user, parent_group,
                 self.get_custom_regex_lookup('pk', perm_regex))),
        ])
196 |
--------------------------------------------------------------------------------
/tests_app/tests/unit/routers/nested_router_mixin/views.py:
--------------------------------------------------------------------------------
1 | from rest_framework.viewsets import ModelViewSet
2 |
3 | from .models import (
4 | NestedRouterMixinUserModel as UserModel,
5 | NestedRouterMixinGroupModel as GroupModel,
6 | NestedRouterMixinPermissionModel as PermissionModel,
7 | )
8 |
9 |
class UserViewSet(ModelViewSet):
    # Route-generation fixture; the tests always pass an explicit basename.
    # NOTE(review): DRF 3 no longer reads a `model` attribute on viewsets.
    model = UserModel
12 |
13 |
class GroupViewSet(ModelViewSet):
    # Route-generation fixture; the tests always pass an explicit basename.
    # NOTE(review): DRF 3 no longer reads a `model` attribute on viewsets.
    model = GroupModel
16 |
17 |
class PermissionViewSet(ModelViewSet):
    # Route-generation fixture; the tests always pass an explicit basename.
    # NOTE(review): DRF 3 no longer reads a `model` attribute on viewsets.
    model = PermissionModel
20 |
21 |
class CustomRegexUserViewSet(ModelViewSet):
    # Distinct single-letter lookup regex, so tests can tell which viewset's
    # regex ended up in each URL segment.
    model = UserModel
    lookup_value_regex = 'a'
25 |
26 |
class CustomRegexGroupViewSet(ModelViewSet):
    # Distinct single-letter lookup regex, so tests can tell which viewset's
    # regex ended up in each URL segment.
    model = GroupModel
    lookup_value_regex = 'b'
30 |
31 |
class CustomRegexPermissionViewSet(ModelViewSet):
    # Distinct single-letter lookup regex, so tests can tell which viewset's
    # regex ended up in each URL segment.
    model = PermissionModel
    lookup_value_regex = 'c'
35 |
--------------------------------------------------------------------------------
/tests_app/tests/unit/routers/tests.py:
--------------------------------------------------------------------------------
1 | from django.test import TestCase
2 |
3 | from rest_framework import viewsets
4 | from rest_framework.decorators import action
5 | from rest_framework.response import Response
6 | from rest_framework_extensions.routers import ExtendedDefaultRouter
7 |
8 |
9 | class ExtendedDefaultRouterTest(TestCase):
10 | def setUp(self):
11 | self.router = ExtendedDefaultRouter()
12 |
13 | def get_routes_names(self, routes):
14 | return [i.name for i in routes]
15 |
16 | def get_dynamic_route_by_def_name(self, def_name, routes):
17 | try:
18 | return [i for i in routes if def_name in i.mapping.values()][0]
19 | except IndexError:
20 | return None
21 |
22 | def test_dynamic_list_route_should_come_before_detail_route(self):
23 | class BasicViewSet(viewsets.ViewSet):
24 | def list(self, request, *args, **kwargs):
25 | return Response({'method': 'list'})
26 |
27 | @action(detail=False)
28 | def detail1(self, request, *args, **kwargs):
29 | return Response({'method': 'detail1'})
30 |
31 | routes = self.router.get_routes(BasicViewSet)
32 | expected = [
33 | '{basename}-list',
34 | '{basename}-detail1',
35 | '{basename}-detail'
36 | ]
37 | msg = '@detail_route methods should come first in routes order'
38 | self.assertEqual(self.get_routes_names(routes), expected, msg)
39 |
40 | def test_detail_route(self):
41 | class BasicViewSet(viewsets.ViewSet):
42 | @action(detail=True)
43 | def action1(self, request, *args, **kwargs):
44 | pass
45 |
46 | routes = self.router.get_routes(BasicViewSet)
47 | action1_route = self.get_dynamic_route_by_def_name('action1', routes)
48 |
49 | msg = '@detail_route should map methods to def name'
50 | self.assertEqual(action1_route.mapping, {'get': 'action1'}, msg)
51 |
52 | msg = '@detail_route should use url with detail lookup'
53 | self.assertEqual(action1_route.url, u'^{prefix}/{lookup}/action1{trailing_slash}$', msg)
54 |
55 | def test_detail_route__with_methods(self):
56 | class BasicViewSet(viewsets.ViewSet):
57 | @action(detail=True, methods=['post'])
58 | def action1(self, request, *args, **kwargs):
59 | pass
60 |
61 | routes = self.router.get_routes(BasicViewSet)
62 | action1_route = self.get_dynamic_route_by_def_name('action1', routes)
63 |
64 | msg = '@detail_route should map methods to def name'
65 | self.assertEqual(action1_route.mapping, {'post': 'action1'}, msg)
66 |
67 | msg = '@detail_route should use url with detail lookup'
68 | self.assertEqual(action1_route.url, u'^{prefix}/{lookup}/action1{trailing_slash}$', msg)
69 |
70 | def test_detail_route__with_methods__and__with_url_path(self):
71 | class BasicViewSet(viewsets.ViewSet):
72 | @action(detail=True, methods=['post'], url_path='action-one')
73 | def action1(self, request, *args, **kwargs):
74 | pass
75 |
76 | routes = self.router.get_routes(BasicViewSet)
77 | action1_route = self.get_dynamic_route_by_def_name('action1', routes)
78 |
79 | msg = '@detail_route should map methods to "url_path"'
80 | self.assertEqual(action1_route.mapping, {'post': 'action1'}, msg)
81 |
82 | msg = '@detail_route should use url with detail lookup and "url_path" value'
83 | self.assertEqual(action1_route.url, u'^{prefix}/{lookup}/action-one{trailing_slash}$', msg)
84 |
def test_list_route(self):
    """A list action defaults to GET and is routed without the detail lookup."""
    class BasicViewSet(viewsets.ViewSet):
        @action(detail=False)
        def action1(self, request, *args, **kwargs):
            pass

    all_routes = self.router.get_routes(BasicViewSet)
    found = self.get_dynamic_route_by_def_name('action1', all_routes)

    mapping_msg = '@list_route should map methods to def name'
    url_msg = '@list_route should use url in list scope'
    self.assertEqual(found.mapping, {'get': 'action1'}, mapping_msg)
    self.assertEqual(found.url, u'^{prefix}/action1{trailing_slash}$', url_msg)
99 |
def test_list_route__with_methods(self):
    """A POST-only list action maps just 'post' and stays in list scope."""
    class BasicViewSet(viewsets.ViewSet):
        @action(detail=False, methods=['post'])
        def action1(self, request, *args, **kwargs):
            pass

    route = self.get_dynamic_route_by_def_name(
        'action1', self.router.get_routes(BasicViewSet))

    self.assertEqual(
        route.mapping,
        {'post': 'action1'},
        '@list_route should map methods to def name',
    )
    self.assertEqual(
        route.url,
        u'^{prefix}/action1{trailing_slash}$',
        '@list_route should use url in list scope',
    )
114 |
def test_list_route__with_methods__and__with_url_path(self):
    """``url_path`` changes only the URL segment of a list action.

    The HTTP-method mapping must still point at the Python def name
    (``action1``), while the list-scope URL uses the ``url_path`` value.
    """
    class BasicViewSet(viewsets.ViewSet):
        @action(detail=False, methods=['post'], url_path='action-one')
        def action1(self, request, *args, **kwargs):
            pass

    routes = self.router.get_routes(BasicViewSet)
    action1_route = self.get_dynamic_route_by_def_name('action1', routes)

    # The asserted mapping value is the def name, not url_path; the failure
    # message previously claimed the opposite.
    msg = '@list_route should map methods to def name even when "url_path" is set'
    self.assertEqual(action1_route.mapping, {'post': 'action1'}, msg)

    msg = '@list_route should use url in list scope with "url_path" value'
    self.assertEqual(action1_route.url, u'^{prefix}/action-one{trailing_slash}$', msg)
129 |
def test_list_route_and_detail_route_with_exact_names(self):
    """A list and a detail action may share a url_path; the lookup keeps them apart."""
    class BasicViewSet(viewsets.ViewSet):
        @action(detail=False, url_path='action-one')
        def action1(self, request, *args, **kwargs):
            pass

        @action(detail=True, url_path='action-one')
        def action1_detail(self, request, *args, **kwargs):
            pass

    routes = self.router.get_routes(BasicViewSet)
    # Resolve both routes before asserting anything.
    list_route, detail_route = [
        self.get_dynamic_route_by_def_name(def_name, routes)
        for def_name in ('action1', 'action1_detail')
    ]

    self.assertEqual(list_route.mapping, {'get': 'action1'})
    self.assertEqual(list_route.url, u'^{prefix}/action-one{trailing_slash}$')

    self.assertEqual(detail_route.mapping, {'get': 'action1_detail'})
    self.assertEqual(detail_route.url, u'^{prefix}/{lookup}/action-one{trailing_slash}$')
149 |
def test_list_route_and_detail_route_names(self):
    """Route names are built from the basename and the def name."""
    class BasicViewSet(viewsets.ViewSet):
        @action(detail=False)
        def action1(self, request, *args, **kwargs):
            pass

        @action(detail=True)
        def action2(self, request, *args, **kwargs):
            pass

    routes = self.router.get_routes(BasicViewSet)
    list_route, detail_route = [
        self.get_dynamic_route_by_def_name(def_name, routes)
        for def_name in ('action1', 'action2')
    ]

    self.assertEqual(list_route.name, u'{basename}-action1')
    self.assertEqual(detail_route.name, u'{basename}-action2')
166 |
def test_list_route_and_detail_route_default_names__with_endpoints(self):
    """Without url_name, the route name comes from the def name, not url_path."""
    class BasicViewSet(viewsets.ViewSet):
        @action(detail=False, url_path='action_one')
        def action1(self, request, *args, **kwargs):
            pass

        @action(detail=True, url_path='action-two')
        def action2(self, request, *args, **kwargs):
            pass

    routes = self.router.get_routes(BasicViewSet)
    list_route, detail_route = [
        self.get_dynamic_route_by_def_name(def_name, routes)
        for def_name in ('action1', 'action2')
    ]

    self.assertEqual(list_route.name, u'{basename}-action1')
    self.assertEqual(detail_route.name, u'{basename}-action2')
183 |
def test_list_route_and_detail_route_names__with_endpoints(self):
    """An explicit url_name replaces the def name in the route name."""
    class BasicViewSet(viewsets.ViewSet):
        @action(detail=False, url_path='action_one', url_name='action_one')
        def action1(self, request, *args, **kwargs):
            pass

        @action(detail=True, url_path='action-two', url_name='action-two')
        def action2(self, request, *args, **kwargs):
            pass

    routes = self.router.get_routes(BasicViewSet)
    list_route, detail_route = [
        self.get_dynamic_route_by_def_name(def_name, routes)
        for def_name in ('action1', 'action2')
    ]

    self.assertEqual(list_route.name, u'{basename}-action_one')
    self.assertEqual(detail_route.name, u'{basename}-action-two')
200 |
--------------------------------------------------------------------------------
/tests_app/tests/unit/serializers/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests_app/tests/unit/serializers/models.py:
--------------------------------------------------------------------------------
1 | from django.db import models
2 |
3 |
class UserModel(models.Model):
    """Minimal user model used by the serializer unit tests."""
    # Short display name; the tests create users such as 'gena' and 'vova'.
    name = models.CharField(max_length=20)
6 |
7 |
class CommentModel(models.Model):
    """Comment model used by the partial-update serializer tests.

    It combines a forward FK (``user``), an M2M (``users_liked``), plain char
    columns, a file column, and a column (``hidden_text``) that the test
    serializers do not expose.
    """
    # Forward FK; related_name='comments' gives UserModel a reverse
    # ``comments`` accessor, which UserSerializer lists in its fields.
    user = models.ForeignKey(
        UserModel,
        related_name='comments',
        on_delete=models.CASCADE,
    )
    # M2M relation; Django rejects M2M names in save(update_fields=...), so the
    # tests check that partial updates keep it out of that list.
    users_liked = models.ManyToManyField(UserModel, blank=True)
    title = models.CharField(max_length=20)
    text = models.CharField(max_length=200)
    # Optional upload stored under 'test_serializers' in the configured storage.
    attachment = models.FileField(
        upload_to='test_serializers', blank=True, null=True, max_length=500)
    # Not listed in any test serializer's ``fields``; the tests check that
    # partial updates leave it untouched.
    hidden_text = models.CharField(max_length=200, blank=True, null=True)
20 |
--------------------------------------------------------------------------------
/tests_app/tests/unit/serializers/serializers.py:
--------------------------------------------------------------------------------
1 | from rest_framework import serializers
2 | from rest_framework_extensions import serializers as drf_serializers
3 |
4 | from .models import CommentModel, UserModel
5 |
6 |
class UserSerializer(drf_serializers.PartialUpdateSerializerMixin,
                     serializers.ModelSerializer):
    """UserModel serializer exposing ``name`` and the reverse ``comments`` relation.

    PartialUpdateSerializerMixin is the behavior under test; presumably it
    limits the columns saved on partial updates (see tests.py) — confirm
    against rest_framework_extensions.serializers.
    """
    class Meta:
        model = UserModel
        fields = (
            'name',
            # Reverse side of CommentModel.user (related_name='comments').
            'comments'
        )
15 |
16 |
class CommentSerializer(drf_serializers.PartialUpdateSerializerMixin,
                        serializers.ModelSerializer):
    """Main CommentModel serializer used by the partial-update tests.

    ``hidden_text`` is deliberately absent from ``fields``.
    """
    # Second, optional alias for the ``title`` column; lets the tests check
    # that the model field is resolved through ``source``.
    title_from_source = serializers.CharField(source='title', required=False)

    class Meta:
        model = CommentModel
        # Kept as a tuple: CommentSerializerWithAllowedUserId concatenates
        # another tuple onto it.
        fields = (
            'id',
            'user',
            'users_liked',
            'title',
            'text',
            'attachment',
            'title_from_source'
        )
32 |
33 |
class CommentTextSerializer(drf_serializers.PartialUpdateSerializerMixin,
                            serializers.ModelSerializer):
    """Serializer for only the ``title`` and ``text`` columns of CommentModel.

    It is nested with ``source='*'`` by CommentSerializerWithGroupedFields.
    """

    class Meta:
        model = CommentModel
        fields = (
            'title',
            'text'
        )
43 |
44 |
class CommentSerializerWithGroupedFields(CommentSerializer):
    """CommentSerializer variant that groups ``title``/``text`` under ``text_content``."""
    # source='*' hands the whole comment to the nested serializer, so the
    # grouped values are read from and written to the comment's own columns.
    text_content = CommentTextSerializer(source='*')

    class Meta(CommentSerializer.Meta):
        # 'title' and 'text' are reachable only through 'text_content'
        # (and 'title' also through 'title_from_source').
        fields = (
            'id',
            'user',
            'users_liked',
            'attachment',
            'title_from_source',
            'text_content'
        )
57 |
58 |
class CommentSerializerWithAllowedUserId(CommentSerializer):
    """CommentSerializer variant that also accepts the raw FK column ``user_id``."""
    # Explicit field for the FK attname, so partial updates may write user_id directly.
    user_id = serializers.IntegerField()

    class Meta(CommentSerializer.Meta):
        fields = ('user_id',) + CommentSerializer.Meta.fields
64 |
65 |
class CommentSerializerWithExpandedUsersLiked(drf_serializers.PartialUpdateSerializerMixin,
                                              serializers.ModelSerializer):
    """CommentModel serializer that nests the full ``user`` representation.

    NOTE(review): despite the class name, this expands ``user`` (via
    UserSerializer), not ``users_liked``; ``users_liked`` is not in ``fields``.
    """
    user = UserSerializer()

    class Meta:
        model = CommentModel
        fields = (
            'title',
            'user'
        )
76 |
--------------------------------------------------------------------------------
/tests_app/tests/unit/serializers/tests.py:
--------------------------------------------------------------------------------
1 | from django.test import TestCase
2 | from django.core.files.base import ContentFile
3 | from rest_framework_extensions.serializers import get_fields_for_partial_update
4 |
5 | from .serializers import CommentSerializer, UserSerializer, \
6 | CommentSerializerWithGroupedFields, CommentSerializerWithAllowedUserId
7 | from .models import UserModel, CommentModel
8 |
9 |
class PartialUpdateSerializerMixinTest(TestCase):
    """Tests for ``PartialUpdateSerializerMixin`` via the serializers in ``.serializers``.

    Each test saves through a serializer and then re-reads the row from the
    database to check which columns were (and were not) written.
    """

    def setUp(self):
        # files[0] is attached to the fixture comment; files[1] is the
        # replacement used by the update tests.
        self.files = [
            ContentFile(u'file one'.encode('utf-8'), name='file1.txt'),
            ContentFile(u'file two'.encode('utf-8'), name='file2.txt'),
        ]
        self.files[0].size = 8
        self.files[1].size = 8
        self.user = UserModel.objects.create(name='gena')
        self.comment = CommentModel.objects.create(
            user=self.user,
            title='hello',
            text='world',
            attachment=self.files[0]
        )

    def get_comment(self):
        """Return a fresh copy of the fixture comment loaded from the database."""
        return CommentModel.objects.get(pk=self.comment.pk)

    def test_should_use_default_saving_without_partial(self):
        serializer = CommentSerializer(data={
            'user': self.user.id,
            'title': 'hola',
            'text': 'amigos',
        })

        self.assertTrue(serializer.is_valid())

        saved_object = serializer.save()
        self.assertEqual(saved_object.user, self.user)
        self.assertEqual(saved_object.title, 'hola')
        self.assertEqual(saved_object.text, 'amigos')

    def test_should_save_partial(self):
        serializer = CommentSerializer(
            instance=self.comment, data={'title': 'hola'}, partial=True)
        self.assertTrue(serializer.is_valid())
        saved_object = serializer.save()
        self.assertEqual(saved_object.user, self.user)
        self.assertEqual(saved_object.title, 'hola')
        self.assertEqual(saved_object.text, 'world')

    def test_update_fields_correctly_determined(self):
        serializer = CommentSerializerWithGroupedFields(
            instance=self.comment,
            data={'text_content': {'title': 'hola', 'text': 'group'},
                  'title_from_source': 'hola', 'attachment': self.files[1]},
            partial=True)
        update_fields = get_fields_for_partial_update(
            serializer.Meta,
            serializer.get_initial(),
            serializer.fields.fields)
        self.assertListEqual(update_fields, ['attachment', 'text', 'title'])

    def test_should_save_grouped_partial(self):
        serializer = CommentSerializerWithGroupedFields(
            instance=self.comment,
            data={'text_content': {'title': 'hola', 'text': 'group'}},
            partial=True)
        self.assertTrue(serializer.is_valid())
        serializer.save()
        self.comment.refresh_from_db()
        self.assertEqual(self.comment.user, self.user)
        self.assertEqual(self.comment.title, 'hola')
        self.assertEqual(self.comment.text, 'group')

    def test_should_save_only_fields_from_data_for_partial_update(self):
        # Each serializer gets its own Comment instance, because save() mutates
        # the instance it was given; sharing one would mask the behavior.
        serializer_one = CommentSerializer(
            instance=self.get_comment(),
            data={'title': 'goodbye'}, partial=True)
        serializer_two = CommentSerializer(
            instance=self.get_comment(), data={'text': 'moon'}, partial=True)
        serializer_three = CommentSerializer(
            instance=self.get_comment(),
            data={'attachment': self.files[1]},
            partial=True)
        self.assertTrue(serializer_one.is_valid())
        self.assertTrue(serializer_two.is_valid())
        self.assertTrue(serializer_three.is_valid())

        # Saving all three must not let one overwrite another's column.
        serializer_one.save()
        serializer_two.save()
        serializer_three.save()

        fresh_instance = self.get_comment()

        # The with-block closes the file even if the assertion fails.
        with fresh_instance.attachment as attachment:
            self.assertEqual(attachment.read(), u'file two'.encode('utf-8'))

        self.assertEqual(fresh_instance.text, 'moon')
        self.assertEqual(fresh_instance.title, 'goodbye')

    def test_should_use_related_field_name_for_update_field_list(self):
        another_user = UserModel.objects.create(name='vova')
        data = {
            'title': 'goodbye',
            'user': another_user.pk
        }
        serializer = CommentSerializer(
            instance=self.get_comment(), data=data, partial=True)
        self.assertTrue(serializer.is_valid())
        serializer.save()
        fresh_instance = self.get_comment()
        self.assertEqual(fresh_instance.title, 'goodbye')
        self.assertEqual(fresh_instance.user, another_user)

    def test_should_use_field_source_value_for_searching_model_concrete_fields(self):
        data = {
            'title_from_source': 'goodbye'
        }
        serializer = CommentSerializer(
            instance=self.get_comment(), data=data, partial=True)
        self.assertTrue(serializer.is_valid())
        serializer.save()
        fresh_instance = self.get_comment()
        self.assertEqual(fresh_instance.title, 'goodbye')

    def test_should_not_use_m2m_field_name_for_update_field_list(self):
        another_user = UserModel.objects.create(name='vova')
        data = {
            'title': 'goodbye',
            'users_liked': [self.user.pk, another_user.pk]
        }
        serializer = CommentSerializer(
            instance=self.get_comment(), data=data, partial=True)
        self.assertTrue(serializer.is_valid())
        try:
            serializer.save()
        except ValueError:
            self.fail(
                'If m2m field used in partial update then it should not be used in update_fields list')
        fresh_instance = self.get_comment()
        self.assertEqual(fresh_instance.title, 'goodbye')
        users_liked = set(
            fresh_instance.users_liked.all().values_list('pk', flat=True))
        self.assertEqual(users_liked, {self.user.pk, another_user.pk})

    def test_should_not_use_related_set_field_name_for_update_field_list(self):
        another_user = UserModel.objects.create(name='vova')
        another_comment = CommentModel.objects.create(
            user=another_user,
            title='goodbye',
            text='moon',
        )
        data = {
            'name': 'vova',
            'comments': [another_comment.pk]
        }
        serializer = UserSerializer(instance=another_user, data=data, partial=True)
        self.assertTrue(serializer.is_valid())
        # Save only inside the try: an earlier unguarded save() would raise
        # the ValueError as an error instead of reaching self.fail below.
        try:
            serializer.save()
        except ValueError:
            self.fail('If related set field used in partial update then it should not be used in update_fields list')
        fresh_comment = CommentModel.objects.get(pk=another_comment.pk)
        fresh_user = UserModel.objects.get(pk=another_user.pk)
        self.assertEqual(fresh_comment.user, another_user)
        self.assertEqual(fresh_user.name, 'vova')

    def test_should_not_try_to_update_fields_that_are_not_in_model(self):
        data = {
            'title': 'goodbye',
            'not_existing_field': 'moon'
        }
        serializer = CommentSerializer(instance=self.get_comment(), data=data, partial=True)
        self.assertTrue(serializer.is_valid())
        try:
            serializer.save()
        except ValueError:
            msg = 'Should not pass values to update_fields from data, if they are not in model'
            self.fail(msg)
        fresh_instance = self.get_comment()
        self.assertEqual(fresh_instance.title, 'goodbye')
        self.assertEqual(fresh_instance.text, 'world')

    def test_should_not_try_to_update_fields_that_are_not_allowed_from_serializer(self):
        data = {
            'title': 'goodbye',
            'hidden_text': 'do not change me'
        }
        serializer = CommentSerializer(instance=self.get_comment(), data=data, partial=True)
        self.assertTrue(serializer.is_valid())
        serializer.save()
        fresh_instance = self.get_comment()
        self.assertEqual(fresh_instance.title, 'goodbye')
        self.assertEqual(fresh_instance.text, 'world')
        self.assertEqual(fresh_instance.hidden_text, None)

    def test_should_use_list_of_fields_to_update_from_arguments_if_it_passed(self):
        data = {
            'title': 'goodbye',
            'text': 'moon'
        }
        serializer = CommentSerializer(instance=self.get_comment(), data=data, partial=True)
        self.assertTrue(serializer.is_valid())
        serializer.save(update_fields=['title'])
        fresh_instance = self.get_comment()
        self.assertEqual(fresh_instance.title, 'goodbye')
        self.assertEqual(fresh_instance.text, 'world')

    def test_should_not_use_field_attname_for_update_fields__if_attname_not_allowed_in_serializer_fields(self):
        another_user = UserModel.objects.create(name='vova')
        data = {
            'title': 'goodbye',
            'user_id': another_user.id
        }
        serializer = CommentSerializer(
            instance=self.get_comment(), data=data, partial=True)
        self.assertTrue(serializer.is_valid())
        serializer.save()
        fresh_instance = self.get_comment()
        self.assertEqual(fresh_instance.user_id, self.user.id)

    def test_should_use_field_attname_for_update_fields__if_attname_allowed_in_serializer_fields(self):
        another_user = UserModel.objects.create(name='vova')
        data = {
            'title': 'goodbye',
            'user_id': another_user.id
        }
        serializer = CommentSerializerWithAllowedUserId(
            instance=self.get_comment(), data=data, partial=True)
        self.assertTrue(serializer.is_valid())
        serializer.save()
        fresh_instance = self.get_comment()
        self.assertEqual(fresh_instance.user_id, another_user.id)

    def test_should_not_use_pk_field_for_update_fields(self):
        old_pk = self.get_comment().pk
        data = {
            'id': old_pk + 1,
            'title': 'goodbye'
        }
        serializer = CommentSerializer(
            instance=self.get_comment(), data=data, partial=True)
        self.assertTrue(serializer.is_valid())
        try:
            serializer.save()
        except ValueError:
            self.fail(
                'Primary key field should be excluded from update_fields list')
        fresh_instance = self.get_comment()
        self.assertEqual(fresh_instance.pk, old_pk)
        self.assertEqual(fresh_instance.title, u'goodbye')
260 |
--------------------------------------------------------------------------------
/tests_app/tests/unit/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tests_app/tests/unit/utils/tests.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | try:
3 | from unittest import mock
4 | except ImportError:
5 | import mock
6 |
7 | from django.test import TestCase
8 |
9 | from rest_framework_extensions.utils import prepare_header_name, get_rest_framework_version
10 |
11 |
@contextlib.contextmanager
def parsed_version(version):
    """Yield ``get_rest_framework_version()`` while DRF reports ``version``.

    ``rest_framework.VERSION`` is patched only for the duration of the
    ``with`` block that uses this context manager.
    """
    version_patch = mock.patch('rest_framework.VERSION', version)
    with version_patch:
        result = get_rest_framework_version()
        yield result
16 |
17 |
class TestPrepareHeaderName(TestCase):
    """Tests for ``prepare_header_name`` and ``get_rest_framework_version``.

    NOTE(review): the last two tests exercise ``get_rest_framework_version``,
    not ``prepare_header_name``; they share this TestCase only by placement.
    """

    def test_upper(self):
        self.assertEqual(prepare_header_name('Accept'), 'HTTP_ACCEPT')

    def test_replace_dash_with_underscores(self):
        self.assertEqual(
            prepare_header_name('Accept-Language'), 'HTTP_ACCEPT_LANGUAGE')

    def test_strips_whitespaces(self):
        self.assertEqual(
            prepare_header_name(' Accept-Language '), 'HTTP_ACCEPT_LANGUAGE')

    def test_adds_http_prefix(self):
        # Use a header with no dash or surrounding whitespace so this test
        # isolates the prefix step instead of duplicating the dash test.
        header_name = prepare_header_name('Host')
        self.assertTrue(header_name.startswith('HTTP_'))
        self.assertEqual(header_name, 'HTTP_HOST')

    def test_get_rest_framework_version_exotic_version(self):
        """A pre-release version string such as '1.2a2' parses to (1, 2)."""
        with parsed_version('1.2a2') as version:
            self.assertEqual(version, (1, 2))

    def test_get_rest_framework_version_normal_version(self):
        """A plain 'X.Y.Z' version string parses to a tuple of three ints."""
        with parsed_version('3.14.16') as version:
            self.assertEqual(version, (3, 14, 16))
43 |
--------------------------------------------------------------------------------
/tests_app/testutils.py:
--------------------------------------------------------------------------------
1 | import base64
2 | try:
3 | from unittest.mock import patch
4 | except ImportError:
5 | from mock import patch
6 |
7 | from rest_framework import HTTP_HEADER_ENCODING
8 |
9 | from rest_framework_extensions.key_constructor import bits
10 | from rest_framework_extensions.key_constructor.constructors import KeyConstructor
11 |
12 |
def get_url_pattern_by_regex_pattern(patterns, pattern_string):
    """Return the first entry whose ``.pattern.regex.pattern`` equals ``pattern_string``.

    ``patterns`` is any iterable of URL-pattern-like objects (e.g. Django
    ``URLPattern`` instances). Returns ``None`` when no entry matches.
    """
    # TODO: add a dedicated unit test for this helper.
    matches = (
        candidate for candidate in patterns
        if candidate.pattern.regex.pattern == pattern_string
    )
    return next(matches, None)
18 |
19 |
def override_extensions_api_settings(**kwargs):
    """Build a ``patch.multiple`` patcher for ``extensions_api_settings``.

    Each keyword names a settings attribute and the value to use while the
    patch is active. The result can be used as a decorator or context manager.
    """
    settings_path = 'rest_framework_extensions.settings.extensions_api_settings'
    return patch.multiple(settings_path, **kwargs)
25 |
26 |
def basic_auth_header(username, password):
    """Return an HTTP Basic ``Authorization`` header value.

    ``username`` and ``password`` are joined with ':' (each via ``%s``),
    encoded with DRF's header encoding, and base64-encoded.
    """
    raw_credentials = ('%s:%s' % (username, password)).encode(HTTP_HEADER_ENCODING)
    token = base64.b64encode(raw_credentials).decode(HTTP_HEADER_ENCODING)
    return 'Basic ' + token
33 |
34 |
class TestFormatKeyBit(bits.KeyBitBase):
    """Key bit that always contributes the constant format 'json'."""

    def get_data(self, **_ignored):
        # Every keyword argument passed by the key constructor is ignored.
        return u'json'
38 |
39 |
class TestLanguageKeyBit(bits.KeyBitBase):
    """Key bit that always contributes the constant language 'ru'."""

    def get_data(self, **_ignored):
        # Every keyword argument passed by the key constructor is ignored.
        return u'ru'
43 |
44 |
class TestUsedKwargsKeyBit(bits.KeyBitBase):
    """Key bit that echoes back whatever keyword arguments it was called with."""

    def get_data(self, **received):
        # Returning the dict unchanged lets tests inspect what the caller passed.
        return received
48 |
49 |
class TestKeyConstructor(KeyConstructor):
    """Key constructor built from TestFormatKeyBit and TestLanguageKeyBit."""
    # Contributes u'json' on every call (TestFormatKeyBit.get_data).
    format = TestFormatKeyBit()
    # Contributes u'ru' on every call (TestLanguageKeyBit.get_data).
    language = TestLanguageKeyBit()
53 |
--------------------------------------------------------------------------------
/tests_app/urls.py:
--------------------------------------------------------------------------------
# Root URLconf of the test project: no routes are registered here.
# The functional test packages each ship their own urls.py; presumably tests
# select those per test case (e.g. via a urls/ROOT_URLCONF override) — confirm.
urlpatterns = []
2 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | envlist = py{38,39}-django{22}-drf{311,312}
3 | py{38,39,310,311,312}-django{32}-drf{312,313,314}
4 | py{38,39,310,311,312}-django{42}-drf{314,315}
5 | py{310,311,312}-django{52}-drf{314,315}
6 |
7 |
8 | [testenv]
9 | deps=
10 | -rtests_app/requirements.txt
11 | django-guardian>=1.4.4
12 | drf311: djangorestframework>=3.11,<3.12
13 | djangorestframework-guardian
14 | drf312: djangorestframework>=3.12,<3.13
15 | djangorestframework-guardian
16 | drf313: djangorestframework>=3.13,<3.14
17 | djangorestframework-guardian
18 | drf314: djangorestframework>=3.14,<3.15
19 | djangorestframework-guardian
20 | drf315: djangorestframework>=3.15,<3.16
21 | djangorestframework-guardian
22 | django22: Django>=2.2,<3.0
23 | django32: Django>=3.2,<4.0
24 | django42: Django>=4.2,<5.0
25 | django52: Django>=5.2,<6.0
26 |
27 |
28 | setenv =
29 | PYTHONPATH = {toxinidir}:{toxinidir}/tests_app
30 | commands =
31 | python --version
32 | pip freeze
33 | python -Wd {envbindir}/django-admin test --settings=settings {posargs}
34 |
--------------------------------------------------------------------------------