├── test_project ├── __init__.py ├── testapp │ ├── __init__.py │ ├── tests │ │ ├── __init__.py │ │ ├── test_utils.py │ │ ├── test_validators.py │ │ ├── test_params.py │ │ ├── test_typed_action.py │ │ ├── test_decorators.py │ │ └── test_typed_api_view.py │ ├── management │ │ ├── __init__.py │ │ └── commands │ │ │ ├── __init__.py │ │ │ └── create_test_user.py │ ├── migrations │ │ ├── __init__.py │ │ └── 0001_initial.py │ ├── serializers.py │ ├── view_sets.py │ ├── models.py │ └── views.py ├── wsgi.py ├── urls.py └── settings.py ├── pypi_submit.py ├── .gitignore ├── rest_typed_views ├── validators │ ├── __init__.py │ ├── default_validator.py │ ├── type_system_validator.py │ ├── pydantic_validator.py │ ├── marshmallow_validator.py │ ├── current_user_validator.py │ └── validator_factory.py ├── __init__.py ├── utils.py ├── params.py ├── param_settings.py └── decorators.py ├── requirements.txt ├── manage.py ├── LICENSE.md ├── setup.py └── README.md /test_project/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test_project/testapp/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test_project/testapp/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test_project/testapp/management/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test_project/testapp/migrations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/test_project/testapp/tests/test_utils.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test_project/testapp/tests/test_validators.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test_project/testapp/management/commands/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pypi_submit.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.system("python setup.py sdist --verbose") 4 | os.system("twine upload dist/*") 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | MANIFEST 3 | README.txt 4 | dist/ 5 | .mypy_cache/ 6 | .vscode/ 7 | db.sqlite3 8 | drf_typed_views.egg-info/ -------------------------------------------------------------------------------- /test_project/wsgi.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from django.core.wsgi import get_wsgi_application 4 | 5 | application = get_wsgi_application() 6 | -------------------------------------------------------------------------------- /test_project/testapp/serializers.py: -------------------------------------------------------------------------------- 1 | from test_project.testapp.models import Movie 2 | from rest_framework.serializers import ModelSerializer 3 | 4 | 5 | class MovieSerializer(ModelSerializer): 6 | class Meta: 7 | model = Movie 8 | fields = ["id", "title", "rating", "genre"] 9 | 10 | -------------------------------------------------------------------------------- 
/rest_typed_views/validators/__init__.py: -------------------------------------------------------------------------------- 1 | from .pydantic_validator import PydanticValidator 2 | from .type_system_validator import TypeSystemValidator 3 | from .default_validator import DefaultValidator 4 | from .current_user_validator import CurrentUserValidator 5 | from .marshmallow_validator import MarshMallowValidator 6 | from .validator_factory import ValidatorFactory 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | appdirs==1.4.3 2 | asgiref==3.2.3 3 | attrs==19.1.0 4 | black==19.3b0 5 | Click==7.0 6 | dataclasses==0.6 7 | Django==3.1.12 8 | djangorestframework==3.11.0 9 | entrypoints==0.3 10 | flake8==3.7.8 11 | marshmallow==3.2.0 12 | mccabe==0.6.1 13 | mypy==0.720 14 | mypy-extensions==0.4.1 15 | pycodestyle==2.5.0 16 | pydantic==1.6.2 17 | pyflakes==2.1.1 18 | pytz==2019.1 19 | sqlparse==0.3.0 20 | toml==0.10.0 21 | typed-ast==1.4.0 22 | typesystem==0.2.4 23 | typing-extensions==3.7.4 24 | -------------------------------------------------------------------------------- /rest_typed_views/validators/default_validator.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from rest_framework.exceptions import ValidationError 4 | from rest_framework.fields import empty 5 | 6 | 7 | class DefaultValidator(object): 8 | def __init__(self, default: Any): 9 | self.default = default 10 | 11 | def run_validation(self, data: Any): 12 | value = self.default if data is empty else data 13 | 14 | if value is empty: 15 | raise ValidationError("A value for this parameter is required") 16 | 17 | return value 18 | -------------------------------------------------------------------------------- /manage.py: -------------------------------------------------------------------------------- 1 | 
#!/usr/bin/env python 2 | import os 3 | import sys 4 | 5 | if __name__ == "__main__": 6 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "test_project.settings") 7 | try: 8 | from django.core.management import execute_from_command_line 9 | except ImportError as exc: 10 | raise ImportError( 11 | "Couldn't import Django. Are you sure it's installed and " 12 | "available on your PYTHONPATH environment variable? Did you " 13 | "forget to activate a virtual environment?" 14 | ) 15 | execute_from_command_line(sys.argv) 16 | -------------------------------------------------------------------------------- /rest_typed_views/__init__.py: -------------------------------------------------------------------------------- 1 | from .decorators import typed_action, typed_api_view 2 | from .param_settings import ParamSettings 3 | 4 | 5 | def Query(*args, **kwargs): 6 | return ParamSettings("query_param", *args, **kwargs) 7 | 8 | 9 | def Path(*args, **kwargs): 10 | return ParamSettings("path", *args, **kwargs) 11 | 12 | 13 | def CurrentUser(*args, **kwargs): 14 | return ParamSettings("current_user", *args, **kwargs) 15 | 16 | 17 | def Body(*args, **kwargs): 18 | return ParamSettings("body", *args, **kwargs) 19 | 20 | 21 | def Header(*args, **kwargs): 22 | return ParamSettings("header", *args, **kwargs) 23 | 24 | 25 | def Param(*args, **kwargs): 26 | return ParamSettings(*args, **kwargs) 27 | -------------------------------------------------------------------------------- /rest_typed_views/validators/type_system_validator.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from django.http import QueryDict 4 | from rest_framework.exceptions import ValidationError 5 | 6 | 7 | class TypeSystemValidator(object): 8 | def __init__(self, TypeSystemSchemaClass): 9 | self.TypeSystemSchemaClass = TypeSystemSchemaClass 10 | 11 | def run_validation(self, data: Union[dict, QueryDict]): 12 | if isinstance(data, QueryDict): 13 | 
# Note that QueryDict is subclass of dict 14 | data = data.dict() 15 | instance, errors = self.TypeSystemSchemaClass.validate_or_error(data) 16 | 17 | if errors: 18 | raise ValidationError(dict(errors)) 19 | 20 | return instance 21 | -------------------------------------------------------------------------------- /rest_typed_views/validators/pydantic_validator.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from django.http import QueryDict 4 | from rest_framework.exceptions import ValidationError 5 | 6 | 7 | class PydanticValidator(object): 8 | def __init__(self, PydanticModelClass): 9 | self.PydanticModelClass = PydanticModelClass 10 | 11 | def run_validation(self, data: Union[dict, QueryDict]): 12 | from pydantic import ValidationError as PydanticValidationError 13 | 14 | try: 15 | if isinstance(data, QueryDict): 16 | # Note that QueryDict is subclass of dict 17 | return self.PydanticModelClass(**data.dict()) 18 | return self.PydanticModelClass(**data) 19 | except PydanticValidationError as e: 20 | raise ValidationError(e.errors()) 21 | -------------------------------------------------------------------------------- /rest_typed_views/validators/marshmallow_validator.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from django.http import QueryDict 4 | from rest_framework.exceptions import ValidationError 5 | 6 | 7 | class MarshMallowValidator(object): 8 | def __init__(self, MashMallowSchemaClass): 9 | self.MashMallowSchemaClass = MashMallowSchemaClass 10 | 11 | def run_validation(self, data: Union[dict, QueryDict]): 12 | from marshmallow import ValidationError as MarshMallowValidationError 13 | 14 | try: 15 | if isinstance(data, QueryDict): 16 | # Note that QueryDict is subclass of dict 17 | return self.MashMallowSchemaClass().load(data.dict()) 18 | return self.MashMallowSchemaClass().load(data) 19 | except 
MarshMallowValidationError as err: 20 | raise ValidationError(err.messages) 21 | -------------------------------------------------------------------------------- /test_project/testapp/management/commands/create_test_user.py: -------------------------------------------------------------------------------- 1 | from django.contrib.auth.models import User, Group 2 | from django.core.management.base import BaseCommand 3 | from rest_framework.authtoken.models import Token 4 | 5 | 6 | class Command(BaseCommand): 7 | def handle(self, *args, **options): 8 | user = User.objects.filter(username="robert").first() 9 | admin = Group.objects.filter(name="admin").first() 10 | 11 | if user: 12 | user.delete() 13 | 14 | if not admin: 15 | admin = Group.objects.create(name="admin") 16 | 17 | user = User.objects.create_user( 18 | "robert", 19 | email="robertgsinger@gmail.com", 20 | first_name="Robert", 21 | last_name="Singer", 22 | ) 23 | 24 | user.groups.add(admin) 25 | 26 | token = Token.objects.create(user=user) 27 | print("Token is: ", token.key) 28 | -------------------------------------------------------------------------------- /test_project/urls.py: -------------------------------------------------------------------------------- 1 | from django.conf.urls import include, url 2 | from rest_framework import routers 3 | 4 | from test_project.testapp.views import ( 5 | create_booking, 6 | create_user, 7 | get_logs, 8 | create_band_member, 9 | get_cache_header, 10 | test_view, 11 | ) 12 | from test_project.testapp.view_sets import MovieViewSet 13 | 14 | router = routers.SimpleRouter() 15 | 16 | router.register(r"movies", MovieViewSet, basename="movie") 17 | 18 | urlpatterns = [ 19 | url(r"^logs/(?P[0-9])/", get_logs, name="get-log-entry"), 20 | url(r"^users/", create_user, name="create-user"), 21 | url(r"^test/", test_view, name="test-view"), 22 | url(r"^bookings/", create_booking, name="create-booking"), 23 | url(r"^test/", test_view, name="test-view"), 24 | url(r"^band-members/", 
create_band_member, name="create-band-member"), 25 | url(r"^get-cache-header/", get_cache_header, name="get-cache-header"), 26 | url(r"^", include(router.urls)), 27 | ] 28 | -------------------------------------------------------------------------------- /test_project/testapp/view_sets.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from pydantic import BaseModel 4 | from rest_framework import viewsets 5 | from rest_framework.decorators import action 6 | from rest_framework.response import Response 7 | 8 | from rest_typed_views import typed_action 9 | from test_project.testapp.models import Movie 10 | from test_project.testapp.serializers import MovieSerializer 11 | 12 | 13 | class Actor(BaseModel): 14 | id: int 15 | name: str 16 | movies: List[int] = [] 17 | 18 | 19 | class MovieViewSet(viewsets.ModelViewSet): 20 | serializer_class = MovieSerializer 21 | 22 | def get_queryset(self): 23 | return Movie.objects.all() 24 | 25 | @typed_action(detail=True, methods=["get"]) 26 | def reviews(self, request, pk: int, test_qp: str, title: str = "My default title"): 27 | obj = self.get_object() 28 | return Response({"id": obj.id, "test_qp": test_qp, "title": title}) 29 | 30 | @typed_action(detail=False, methods=["POST"]) 31 | def actors(self, actor: Actor): 32 | return Response(dict(actor)) 33 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 rsinger86 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the 
Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /test_project/testapp/models.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | from django.utils import timezone 4 | from django.core import mail 5 | from django.db import models 6 | from django.utils.functional import cached_property 7 | 8 | 9 | class UserAccount(models.Model): 10 | username = models.CharField(max_length=100) 11 | first_name = models.CharField(max_length=100) 12 | last_name = models.CharField(max_length=100) 13 | password = models.CharField(max_length=200) 14 | email = models.EmailField(null=True) 15 | password_updated_at = models.DateTimeField(null=True) 16 | joined_at = models.DateTimeField(null=True) 17 | has_trial = models.BooleanField(default=False) 18 | 19 | status = models.CharField( 20 | default="active", 21 | max_length=30, 22 | choices=(("active", "Active"), ("banned", "Banned"), ("inactive", "Inactive")), 23 | ) 24 | 25 | 26 | class Movie(models.Model): 27 | title = models.CharField(max_length=100) 28 | rating = models.FloatField(default=0) 29 | 30 | genre = models.CharField( 31 | max_length=30, choices=(("comedy", "Comedy"), ("drama", "Drama")) 32 | ) 33 | 
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from codecs import open 3 | 4 | from setuptools import find_packages, setup 5 | 6 | 7 | def readme(): 8 | with open("README.md", "r") as infile: 9 | return infile.read() 10 | 11 | 12 | classifiers = [ 13 | # Pick your license as you wish (should match "license" above) 14 | "Development Status :: 4 - Beta", 15 | "License :: OSI Approved :: MIT License", 16 | "Programming Language :: Python :: 3.3", 17 | "Programming Language :: Python :: 3.4", 18 | "Programming Language :: Python :: 3.5", 19 | "Programming Language :: Python :: 3.6", 20 | ] 21 | setup( 22 | name="drf-typed-views", 23 | version="0.3.0", 24 | description="Use type annotations for automatic request validation in Django REST Framework", 25 | author="Robert Singer", 26 | author_email="robertgsinger@gmail.com", 27 | packages=find_packages(exclude=["test_project*"]), 28 | url="https://github.com/rsinger86/drf-typed-views", 29 | license="MIT", 30 | keywords="django rest type annotations automatic validation validate", 31 | long_description=readme(), 32 | classifiers=classifiers, 33 | long_description_content_type="text/markdown", 34 | ) 35 | -------------------------------------------------------------------------------- /rest_typed_views/validators/current_user_validator.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | from django.contrib.auth.models import User 3 | from rest_framework.exceptions import ValidationError 4 | 5 | if TYPE_CHECKING: 6 | from rest_typed_views import ParamSettings 7 | 8 | 9 | class CurrentUserValidator(object): 10 | def __init__(self, settings: "ParamSettings"): 11 | self.settings = settings 12 | 13 | def run_validation(self, user: User) -> User: 14 | if self.settings.member_of is not None: 15 | queryset 
= user.groups.all().filter(name=self.settings.member_of) 16 | 17 | if queryset.count() == 0: 18 | raise ValidationError( 19 | f"User must be a member of the '{self.settings.member_of}' group" 20 | ) 21 | 22 | if len(self.settings.member_of_any) > 0: 23 | queryset = user.groups.all().filter(name__in=self.settings.member_of_any) 24 | 25 | if queryset.count() == 0: 26 | raise ValidationError( 27 | f"User must be a member of at least one of these groups: " 28 | f"'{self.settings.member_of_any}'" 29 | ) 30 | return user 31 | -------------------------------------------------------------------------------- /test_project/testapp/tests/test_params.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock 2 | 3 | from rest_framework.fields import empty 4 | from rest_framework.test import APITestCase 5 | 6 | from rest_typed_views import ParamSettings 7 | from rest_typed_views.params import BodyParam 8 | 9 | 10 | class ParamsTests(APITestCase): 11 | def fake_request(self, data={}, query_params={}): 12 | return MagicMock(data=data, query_params=query_params) 13 | 14 | def test_body_raw_value_should_be_request_data_when_not_set(self): 15 | body_param = BodyParam( 16 | MagicMock(), self.fake_request(data={"a": "b"}), ParamSettings() 17 | ) 18 | 19 | self.assertEqual(body_param._get_raw_value(), {"a": "b"}) 20 | 21 | def test_body_raw_value_should_be_request_data_when_wildcard_set(self): 22 | body_param = BodyParam( 23 | MagicMock(), self.fake_request(data={"a": "b"}), ParamSettings(source="*") 24 | ) 25 | 26 | self.assertEqual(body_param._get_raw_value(), {"a": "b"}) 27 | 28 | def test_body_raw_value_should_be_empty_when_src_specified_but_not_found(self): 29 | body_param = BodyParam( 30 | MagicMock(), self.fake_request(data={"a": "b"}), ParamSettings(source="c") 31 | ) 32 | 33 | self.assertEqual(body_param._get_raw_value(), empty) 34 | 35 | def test_body_raw_value_should_be_found_when_src_specified(self): 36 | 
body_param = BodyParam( 37 | MagicMock(), self.fake_request(data={"a": "b"}), ParamSettings(source="a") 38 | ) 39 | 40 | self.assertEqual(body_param._get_raw_value(), "b") 41 | -------------------------------------------------------------------------------- /test_project/testapp/migrations/0001_initial.py: -------------------------------------------------------------------------------- 1 | # Generated by Django 2.2.1 on 2019-09-22 22:34 2 | 3 | from django.db import migrations, models 4 | 5 | 6 | class Migration(migrations.Migration): 7 | 8 | initial = True 9 | 10 | dependencies = [ 11 | ] 12 | 13 | operations = [ 14 | migrations.CreateModel( 15 | name='Movie', 16 | fields=[ 17 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 18 | ('title', models.CharField(max_length=100)), 19 | ('rating', models.FloatField(default=0)), 20 | ('genre', models.CharField(choices=[('comedy', 'Comedy'), ('drama', 'Drama')], max_length=30)), 21 | ], 22 | ), 23 | migrations.CreateModel( 24 | name='UserAccount', 25 | fields=[ 26 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 27 | ('username', models.CharField(max_length=100)), 28 | ('first_name', models.CharField(max_length=100)), 29 | ('last_name', models.CharField(max_length=100)), 30 | ('password', models.CharField(max_length=200)), 31 | ('email', models.EmailField(max_length=254, null=True)), 32 | ('password_updated_at', models.DateTimeField(null=True)), 33 | ('joined_at', models.DateTimeField(null=True)), 34 | ('has_trial', models.BooleanField(default=False)), 35 | ('status', models.CharField(choices=[('active', 'Active'), ('banned', 'Banned'), ('inactive', 'Inactive')], default='active', max_length=30)), 36 | ], 37 | ), 38 | ] 39 | -------------------------------------------------------------------------------- /test_project/testapp/tests/test_typed_action.py: 
-------------------------------------------------------------------------------- 1 | from rest_framework.reverse import reverse 2 | from rest_framework.test import APITestCase 3 | 4 | from test_project.testapp.models import Movie 5 | 6 | 7 | class TypedActionTests(APITestCase): 8 | def setUp(self): 9 | Movie.objects.all().delete() 10 | 11 | def test_get_reviews_ok(self): 12 | movie = Movie.objects.create(title="My movie", rating=5.0, genre="comedy") 13 | url = reverse("movie-reviews", args=[movie.id]) 14 | 15 | response = self.client.get(url, {"test_qp": "cats"}, format="json") 16 | 17 | self.assertEqual(response.status_code, 200) 18 | 19 | self.assertEqual( 20 | response.data, {"id": 1, "test_qp": "cats", "title": "My default title"} 21 | ) 22 | 23 | def test_get_reviews_error(self): 24 | movie = Movie.objects.create(title="My movie", rating=5.0, genre="comedy") 25 | url = reverse("movie-reviews", args=[movie.id]) 26 | response = self.client.get(url, format="json") 27 | self.assertEqual(response.status_code, 400) 28 | self.assertEqual(response.json(), {"test_qp": ["This field is required."]}) 29 | 30 | def test_create_actor_ok(self): 31 | url = reverse("movie-actors") 32 | 33 | response = self.client.post( 34 | url, {"id": 123, "name": "Tom Cruze"}, format="json" 35 | ) 36 | 37 | self.assertEqual(response.status_code, 200) 38 | self.assertEqual(response.data, {"id": 123, "name": "Tom Cruze", "movies": []}) 39 | 40 | def test_create_actor_error(self): 41 | url = reverse("movie-actors") 42 | 43 | response = self.client.post(url, {"id": 123}, format="json") 44 | 45 | self.assertEqual(response.status_code, 400) 46 | 47 | self.assertEqual( 48 | response.json(), 49 | { 50 | "actor": [ 51 | { 52 | "loc": "('name',)", 53 | "msg": "field required", 54 | "type": "value_error.missing", 55 | } 56 | ] 57 | }, 58 | ) 59 | -------------------------------------------------------------------------------- /rest_typed_views/utils.py: 
-------------------------------------------------------------------------------- 1 | import inspect 2 | import operator 3 | from enum import Enum 4 | from functools import reduce 5 | from typing import Any, List, Optional, Tuple 6 | 7 | from django.conf import settings 8 | from rest_framework.fields import empty 9 | from rest_framework.request import Request 10 | 11 | from .param_settings import ParamSettings 12 | 13 | 14 | def parse_list_annotation(annotation) -> Tuple[bool, Any]: 15 | if "List[" in str(annotation): 16 | return True, annotation.__args__[0] 17 | return False, None 18 | 19 | 20 | def parse_enum_annotation(annotation) -> Tuple[bool, List[Any]]: 21 | if inspect.isclass(annotation) and issubclass(annotation, Enum): 22 | return True, [_.value for _ in annotation] 23 | return False, [] 24 | 25 | 26 | def parse_complex_type(annotation) -> Tuple[bool, Optional[str]]: 27 | if hasattr(settings, "DRF_TYPED_VIEWS"): 28 | enabled = settings.DRF_TYPED_VIEWS.get("schema_packages", []) 29 | else: 30 | enabled = [] 31 | 32 | if "pydantic" in enabled: 33 | from pydantic import BaseModel as PydanticBaseModel 34 | 35 | if inspect.isclass(annotation) and issubclass(annotation, PydanticBaseModel): 36 | return True, "pydantic" 37 | 38 | if "typesystem" in enabled: 39 | from typesystem import Schema as TypeSystemSchema 40 | 41 | if inspect.isclass(annotation) and issubclass(annotation, TypeSystemSchema): 42 | return True, "typesystem" 43 | 44 | if "marshmallow" in enabled: 45 | from marshmallow import Schema as MarshmallowSchema 46 | 47 | if inspect.isclass(annotation) and issubclass(annotation, MarshmallowSchema): 48 | return True, "marshmallow" 49 | return False, None 50 | 51 | 52 | def get_nested_value(dic: dict, path: str, fallback=None) -> Any: 53 | try: 54 | return reduce(operator.getitem, path.split("."), dic) 55 | except (TypeError, KeyError, ValueError): 56 | return fallback 57 | 58 | 59 | def get_default_value(param: inspect.Parameter) -> Any: 60 | if ( 61 | not 
is_default_used_to_pass_settings(param) 62 | and param.default is not inspect.Parameter.empty 63 | ): 64 | return param.default 65 | return empty 66 | 67 | 68 | def is_default_used_to_pass_settings(param: inspect.Parameter) -> bool: 69 | return get_explicit_param_settings(param) is not None 70 | 71 | 72 | def get_explicit_param_settings(param: inspect.Parameter) -> Optional[ParamSettings]: 73 | try: 74 | param_type = param.default.param_type 75 | return param.default 76 | except AttributeError: 77 | return None 78 | 79 | 80 | def is_implicit_body_param(param: inspect.Parameter) -> bool: 81 | is_complex_type, package = parse_complex_type(param.annotation) 82 | return is_complex_type 83 | 84 | 85 | def is_explicit_request_param(param: inspect.Parameter) -> bool: 86 | return param.annotation is Request 87 | 88 | 89 | def is_implicit_request_param(param: inspect.Parameter) -> bool: 90 | return param.name == "request" and param.annotation is inspect.Parameter.empty 91 | 92 | 93 | def find_request(original_args: list) -> Request: 94 | for arg in original_args: 95 | if isinstance(arg, Request): 96 | return arg 97 | raise Exception("Could not find request in args:" + str(original_args)) 98 | -------------------------------------------------------------------------------- /test_project/testapp/views.py: -------------------------------------------------------------------------------- 1 | from datetime import date, datetime, time, timedelta 2 | from decimal import Decimal 3 | from enum import Enum 4 | from typing import List, Optional 5 | 6 | import marshmallow 7 | import typesystem 8 | from django.contrib.auth.models import User 9 | from pydantic import BaseModel 10 | from rest_framework.response import Response 11 | 12 | from rest_typed_views import ( 13 | Body, 14 | CurrentUser, 15 | Param, 16 | Header, 17 | Path, 18 | Query, 19 | typed_api_view, 20 | ) 21 | 22 | """ 23 | 
http://localhost:8000/logs/2/?title=1231234&price=33.43&latitude=3.333333333333333&is_pretty=no&email=robert@hotmail.com&upper_alpha_string=CAT&identifier=cat&website=https://www.nytimes.com/&identity=e028aa46-8411-4c83-b970-76be868c9413&file=/tmp/test.html&ip=162.254.168.185×tamp=2019-04-03T10:10&start_date=1200-05-05&start_time=20:19&duration=3%205555:45&bag=paper&numbers=1,2,3 24 | 25 | """ 26 | 27 | 28 | class BagOptions(str, Enum): 29 | paper = "paper" 30 | plastic = "plastic" 31 | 32 | 33 | class SuperUser(BaseModel): 34 | id: int 35 | name = "John Doe" 36 | signup_ts: datetime = None 37 | friends: List[int] = [] 38 | 39 | 40 | class Booking(typesystem.Schema): 41 | start_date = typesystem.Date() 42 | end_date = typesystem.Date() 43 | room = typesystem.Choice( 44 | choices=[ 45 | ("double", "Double room"), 46 | ("twin", "Twin room"), 47 | ("single", "Single room"), 48 | ] 49 | ) 50 | include_breakfast = typesystem.Boolean(title="Include breakfast", default=False) 51 | 52 | 53 | @typed_api_view(["POST"]) 54 | def create_user(user: SuperUser): 55 | return Response(dict(user)) 56 | 57 | 58 | @typed_api_view(["POST"]) 59 | def create_booking(booking: Booking = Body(source="_data.item")): 60 | return Response(dict(booking)) 61 | 62 | 63 | @typed_api_view(["GET"]) 64 | def get_cache_header(cache: str = Header()): 65 | return Response(cache) 66 | 67 | 68 | class BandMemberSchema(marshmallow.Schema): 69 | name = marshmallow.fields.String(required=True) 70 | email = marshmallow.fields.Email() 71 | 72 | 73 | @typed_api_view(["POST"]) 74 | def create_band_member(band_member: BandMemberSchema): 75 | return Response(dict(band_member)) 76 | 77 | 78 | @typed_api_view(["GET"]) 79 | def get_logs( 80 | myid: int = Path(source="id"), 81 | latitude: Decimal = Query(decimal_places=20), 82 | title: str = Query(min_length=6), 83 | price: float = Query(min_value=6), 84 | user: User = CurrentUser(member_of_any=[]), 85 | is_pretty: bool = Query(), 86 | email: str = 
Query(format="email"), 87 | upper_alpha_string: str = Query(regex=r"^[A-Z]+$"), 88 | identifier: str = Query(format="slug"), 89 | website: str = Query(format="url"), 90 | identity: str = Query(format="uuid"), 91 | # file: str = Query(format="file_path", path="/tmp/"), 92 | ip: str = Query(format="ipv4"), 93 | timestamp: datetime = Query(), 94 | start_date: date = Query(), 95 | start_time: time = Query(), 96 | duration: timedelta = Query(), 97 | bag_type: BagOptions = Query(source="bag"), 98 | numbers: List[int] = Query(child=Param(min_value=0)), 99 | ): 100 | 101 | return Response( 102 | data={ 103 | "id": myid, 104 | "title": title, 105 | "price": price, 106 | "latitude": latitude, 107 | "is_pretty": is_pretty, 108 | "email": email, 109 | "upper_alpha_string": upper_alpha_string, 110 | "identifier": identifier, 111 | "website": website, 112 | "identity": identity, 113 | # "file": file, 114 | "ip": ip, 115 | "timestamp": timestamp, 116 | "start_date": start_date, 117 | "start_time": start_time, 118 | "duration": duration, 119 | "bag_type": bag_type, 120 | "numbers": numbers, 121 | }, 122 | status=200, 123 | ) 124 | 125 | 126 | @typed_api_view(["GET"]) 127 | def test_view(c: str = Header(source="connection")): 128 | return Response({"v": c}) 129 | -------------------------------------------------------------------------------- /rest_typed_views/params.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import Any, Tuple 3 | 4 | from rest_framework.exceptions import ValidationError 5 | from rest_framework.fields import Field, empty 6 | from rest_framework.request import Request 7 | 8 | from rest_typed_views.param_settings import ParamSettings 9 | from rest_typed_views.utils import get_nested_value, parse_list_annotation 10 | from rest_typed_views.validators import CurrentUserValidator, ValidatorFactory 11 | 12 | 13 | class Param(object): 14 | def __init__( 15 | self, 16 | param: inspect.Parameter, 17 
| request: Request, 18 | settings: ParamSettings, 19 | raw_value: Any = empty, 20 | ): 21 | self.param = param 22 | self.request = request 23 | self.settings = settings 24 | self.raw_value = raw_value 25 | 26 | def _get_validator(self) -> Field: 27 | return ValidatorFactory.make(self.param.annotation, self.settings) 28 | 29 | def _get_raw_value(self): 30 | raise Exception("Must implement in concrete class!") 31 | 32 | @property 33 | def _source(self) -> str: 34 | return self.settings.source or self.param.name 35 | 36 | def validate_or_error(self) -> Tuple[Any, Any]: 37 | validator = self._get_validator() 38 | 39 | try: 40 | value = validator.run_validation(self._get_raw_value()) 41 | return value, None 42 | except ValidationError as e: 43 | return None, {self._source: e.detail} 44 | 45 | 46 | class QueryParam(Param): 47 | def _get_raw_value(self): 48 | if self.settings.source == "*": 49 | raw = self.request.query_params.dict() 50 | else: 51 | key = self.settings.source or self.param.name 52 | raw = self.request.query_params.get(key, empty) 53 | raw = empty if raw == "" else raw 54 | is_list_type, item_type = parse_list_annotation(self.param.annotation) 55 | 56 | if raw is not empty and is_list_type: 57 | raw = raw.split(self.settings.delimiter) 58 | 59 | return raw 60 | 61 | 62 | class PathParam(Param): 63 | def _get_raw_value(self): 64 | raw = self.raw_value 65 | return raw 66 | 67 | 68 | class PassThruParam(object): 69 | def __init__(self, value: Any): 70 | self.value = value 71 | 72 | def validate_or_error(self) -> Tuple[Any, Any]: 73 | return self.value, None 74 | 75 | 76 | class BodyParam(Param): 77 | def _get_raw_value(self): 78 | if self.settings.source in ("*", None): 79 | return self.request.data 80 | return get_nested_value(self.request.data, self.settings.source, fallback=empty) 81 | 82 | 83 | class HeaderParam(Param): 84 | def _get_raw_value(self): 85 | headers = { 86 | str(key).lower(): value for key, value in self.request.headers.items() 87 | } 88 | 
89 | if self.settings.source == "*": 90 | raw = headers 91 | else: 92 | if self.settings.source: 93 | key = self.settings.source 94 | else: 95 | key = self.param.name.replace("_", "-").lower() 96 | 97 | raw = headers.get(key) 98 | 99 | return raw 100 | 101 | 102 | class CurrentUserParam(Param): 103 | def _get_raw_value(self): 104 | if self.settings.source in ("*", None): 105 | return self.request.user 106 | 107 | obj = self.request.user 108 | 109 | for path in self.settings.source.split("."): 110 | if hasattr(obj, path): 111 | obj = getattr(obj, path) 112 | else: 113 | obj = None 114 | break 115 | 116 | return obj 117 | 118 | def validate_or_error(self) -> Tuple[Any, Any]: 119 | value = self._get_raw_value() 120 | generic_validator = self._get_validator() 121 | current_user_validator = CurrentUserValidator(self.settings) 122 | 123 | try: 124 | value = generic_validator.run_validation(value) 125 | current_user_validator.run_validation(self.request.user) 126 | return value, None 127 | except ValidationError as e: 128 | return None, {self._source: e.detail} 129 | -------------------------------------------------------------------------------- /test_project/settings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Django settings for coremodel project. 3 | 4 | Generated by 'django-admin startproject' using Django 2.0.3. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/2.0/topics/settings/ 8 | 9 | For the full list of settings and their values, see 10 | https://docs.djangoproject.com/en/2.0/ref/settings/ 11 | """ 12 | 13 | import os 14 | import django 15 | 16 | # Build paths inside the project like this: os.path.join(BASE_DIR, ...) 
17 | BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 18 | 19 | 20 | # Quick-start development settings - unsuitable for production 21 | # See https://docs.djangoproject.com/en/2.0/howto/deployment/checklist/ 22 | 23 | # SECURITY WARNING: keep the secret key used in production secret! 24 | SECRET_KEY = "o)04)%_us9ed1l7*cv&5@t(2*r#$^r7o(q^4p@y9@b20_ay_jv" 25 | 26 | # SECURITY WARNING: don't run with debug turned on in production! 27 | DEBUG = True 28 | 29 | ALLOWED_HOSTS = [] 30 | 31 | 32 | # Application definition 33 | 34 | INSTALLED_APPS = [ 35 | "django.contrib.staticfiles", 36 | "django.contrib.admin", 37 | "django.contrib.auth", 38 | "django.contrib.messages", 39 | "django.contrib.contenttypes", 40 | "django.contrib.sessions", 41 | "rest_framework", 42 | "test_project.testapp", 43 | "rest_framework.authtoken", 44 | ] 45 | 46 | 47 | MIDDLEWARE = [ 48 | "django.middleware.security.SecurityMiddleware", 49 | "django.contrib.sessions.middleware.SessionMiddleware", 50 | "django.middleware.common.CommonMiddleware", 51 | "django.middleware.csrf.CsrfViewMiddleware", 52 | "django.contrib.auth.middleware.AuthenticationMiddleware", 53 | "django.contrib.messages.middleware.MessageMiddleware", 54 | "django.middleware.clickjacking.XFrameOptionsMiddleware", 55 | ] 56 | 57 | ROOT_URLCONF = "test_project.urls" 58 | 59 | TEMPLATES = [ 60 | { 61 | "BACKEND": "django.template.backends.django.DjangoTemplates", 62 | "DIRS": [], 63 | "APP_DIRS": True, 64 | "OPTIONS": { 65 | "context_processors": [ 66 | "django.template.context_processors.debug", 67 | "django.template.context_processors.request", 68 | "django.contrib.auth.context_processors.auth", 69 | "django.contrib.messages.context_processors.messages", 70 | ] 71 | }, 72 | } 73 | ] 74 | 75 | WSGI_APPLICATION = "test_project.wsgi.application" 76 | 77 | 78 | # Database 79 | # https://docs.djangoproject.com/en/2.0/ref/settings/#databases 80 | 81 | DATABASES = { 82 | "default": { 83 | "ENGINE": 
"django.db.backends.sqlite3", 84 | "NAME": os.path.join(BASE_DIR, "db.sqlite3"), 85 | } 86 | } 87 | 88 | 89 | # Internationalization 90 | # https://docs.djangoproject.com/en/2.0/topics/i18n/ 91 | 92 | LANGUAGE_CODE = "en-us" 93 | 94 | TIME_ZONE = "UTC" 95 | 96 | USE_I18N = True 97 | 98 | USE_L10N = True 99 | 100 | USE_TZ = True 101 | 102 | 103 | # Static files (CSS, JavaScript, Images) 104 | # https://docs.djangoproject.com/en/2.0/howto/static-files/ 105 | 106 | STATIC_URL = "/static/" 107 | STATIC_ROOT = BASE_DIR + "/static/" 108 | 109 | REST_FRAMEWORK = { 110 | "PAGE_SIZE": 100, 111 | "URL_FIELD_NAME": "self_link", 112 | "DEFAULT_AUTHENTICATION_CLASSES": [ 113 | "rest_framework.authentication.TokenAuthentication" 114 | ], 115 | "DEFAULT_RENDERER_CLASSES": [ 116 | "rest_framework.renderers.JSONRenderer", 117 | "rest_framework.renderers.BrowsableAPIRenderer", 118 | ], 119 | } 120 | 121 | 122 | LOGGING = { 123 | "disable_existing_loggers": False, 124 | "version": 1, 125 | "formatters": {"standard": {"format": "%(asctime)s %(levelname)s %(message)s"}}, 126 | "handlers": { 127 | "console": { 128 | # logging handler that outputs log messages to terminal 129 | "class": "logging.StreamHandler", 130 | "level": "DEBUG", # message level to be written to console 131 | } 132 | }, 133 | "loggers": { 134 | "": { 135 | # this sets root level logger to log debug and higher level 136 | # logs to console. All other loggers inherit settings from 137 | # root level logger. 
138 | "handlers": ["console"], 139 | "level": "DEBUG", 140 | "propagate": False, # this tells logger to send logging message 141 | # to its parent (will send if set to True) 142 | }, 143 | "django.db": {"level": "DEBUG", "handers": ["console"]}, 144 | "requests": {"handlers": ["console"], "level": "INFO"}, 145 | }, 146 | } 147 | 148 | DRF_TYPED_VIEWS = {"schema_packages": ["pydantic", "marshmallow", "typesystem"]} 149 | -------------------------------------------------------------------------------- /rest_typed_views/param_settings.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional 2 | 3 | from rest_framework.fields import empty 4 | 5 | 6 | class ParamSettings(object): 7 | param_type: Optional[str] 8 | default: Any 9 | source: Optional[str] 10 | min_value: Optional[int] 11 | max_value: Optional[int] 12 | input_formats: Optional[List[str]] 13 | format: Optional[str] 14 | regex: Optional[str] 15 | min_length: Optional[int] 16 | max_length: Optional[int] 17 | trim_whitespace: bool 18 | allow_blank: bool 19 | default_timezone: Optional[Any] 20 | choices: Optional[List[Any]] 21 | delimiter: str 22 | max_digits: Optional[int] 23 | decimal_places: Optional[int] 24 | rounding: Optional[str] 25 | coerce_to_string: bool 26 | localize: bool 27 | path: Optional[str] 28 | match: Optional[str] 29 | recursive: bool 30 | allow_files: bool 31 | allow_folders: bool 32 | protocol: str 33 | child: Optional["ParamSettings"] 34 | allow_empty: Optional[bool] 35 | member_of: Optional[str] 36 | member_of_any: List[str] 37 | 38 | def __init__( 39 | self, 40 | param_type: Optional[str] = None, 41 | default: Any = empty, 42 | source: str = None, 43 | min_value: int = None, 44 | max_value: int = None, 45 | input_formats: List[str] = None, 46 | format: str = None, 47 | regex: str = None, 48 | min_length: int = None, 49 | max_length: int = None, 50 | trim_whitespace: bool = True, 51 | allow_blank: bool = False, 52 | 
default_timezone=None, 53 | choices: List[Any] = None, 54 | delimiter: str = ",", 55 | # DecimalField args 56 | max_digits: int = None, 57 | decimal_places: int = None, 58 | rounding: str = None, 59 | coerce_to_string: bool = False, 60 | localize: bool = False, 61 | # FilePathField args 62 | path: str = None, 63 | match: str = None, 64 | recursive: bool = False, 65 | allow_files: bool = True, 66 | allow_folders: bool = False, 67 | # IPAddressField args 68 | protocol: str = "both", 69 | # ListField arg 70 | child: "ParamSettings" = None, 71 | allow_empty: bool = True, 72 | # Current user validator arg 73 | member_of: str = None, 74 | member_of_any: List[str] = [], 75 | ): 76 | self.param_type = param_type 77 | self.default = default 78 | self.source = source 79 | self.min_value = min_value 80 | self.max_value = max_value 81 | self.input_formats = input_formats 82 | self.format = format 83 | self.regex = regex 84 | self.min_length = min_length 85 | self.max_length = max_length 86 | self.trim_whitespace = trim_whitespace 87 | self.allow_blank = allow_blank 88 | self.default_timezone = default_timezone 89 | self.choices = choices 90 | self.delimiter = delimiter 91 | self.max_digits = max_digits 92 | self.decimal_places = decimal_places 93 | self.rounding = rounding 94 | self.coerce_to_string = coerce_to_string 95 | self.localize = localize 96 | self.path = path 97 | self.match = match 98 | self.recursive = recursive 99 | self.allow_files = allow_files 100 | self.allow_folders = allow_folders 101 | self.protocol = protocol 102 | self.child = child 103 | self.allow_empty = allow_empty 104 | self.member_of = member_of 105 | self.member_of_any = member_of_any 106 | 107 | if self.regex and self.format: 108 | raise Exception("Cannot set both 'regex' and 'format'") 109 | 110 | if self.protocol not in ("both", "IPv4", "IPv6"): 111 | raise Exception( 112 | "'protocol' (for validating IP addresses) must be one of: both, IPv4, IPv6" 113 | ) 114 | 115 | if self.format is not None 
and self.format not in ( 116 | "uuid", 117 | "email", 118 | "slug", 119 | "url", 120 | "ipv4", 121 | "ipv6", 122 | "file_path", 123 | ): 124 | raise Exception( 125 | "'format' must be one of: uuid, email, slug, url, ip_address, file_path" 126 | ) 127 | 128 | if self.param_type and self.param_type not in ( 129 | "body", 130 | "query_param", 131 | "path", 132 | "current_user", 133 | "header", 134 | # "cookie", 135 | ): 136 | raise Exception( 137 | "'param_type' must be one of: body, query_param, path, current_user, header" 138 | ) 139 | -------------------------------------------------------------------------------- /test_project/testapp/tests/test_decorators.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from unittest.mock import MagicMock, patch 3 | 4 | from pydantic import BaseModel 5 | from rest_framework.exceptions import ValidationError 6 | from rest_framework.request import Request 7 | from rest_framework.test import APITestCase 8 | 9 | from rest_typed_views import Body, CurrentUser, ParamSettings, Path, Query 10 | from rest_typed_views.decorators import ( 11 | build_explicit_param, 12 | get_view_param, 13 | transform_view_params, 14 | ) 15 | from rest_typed_views.params import ( 16 | BodyParam, 17 | CurrentUserParam, 18 | PassThruParam, 19 | PathParam, 20 | QueryParam, 21 | ) 22 | 23 | 24 | class DecoratorTests(APITestCase): 25 | def fake_request(data={}, query_params={}): 26 | return MagicMock(data=data, query_params=query_params) 27 | 28 | def get_params(self, func): 29 | return list(inspect.signature(func).parameters.values()) 30 | 31 | def test_transform_view_params_succeeds(self): 32 | def example_function(id: int, q: str): 33 | return 34 | 35 | request = self.fake_request(query_params={"q": "cats"}) 36 | typed_params = inspect.signature(example_function).parameters.values() 37 | result = transform_view_params(typed_params, request, {"id": "1"}) 38 | self.assertEqual(result, [1, "cats"]) 39 | 
40 | def test_transform_view_params_throws_error(self): 41 | def example_function(id: int, q: str): 42 | return 43 | 44 | request = self.fake_request(query_params={}) 45 | typed_params = self.get_params(example_function) 46 | 47 | with self.assertRaises(ValidationError) as context: 48 | transform_view_params(typed_params, request, {"id": "one"}) 49 | 50 | self.assertTrue("A valid integer is required" in str(context.exception)) 51 | self.assertTrue("This field is required" in str(context.exception)) 52 | 53 | @patch("rest_typed_views.decorators.build_explicit_param") 54 | def test_get_view_param_if_explicit_settings(self, mock_build_explicit_param): 55 | def example_function(body: str = Body(source="name")): 56 | return 57 | 58 | get_view_param(self.get_params(example_function)[0], self.fake_request(), {}) 59 | mock_build_explicit_param.assert_called_once() 60 | 61 | def test_get_view_param_if_explicit_request_param(self): 62 | def example_function(request: Request): 63 | return 64 | 65 | result = get_view_param( 66 | self.get_params(example_function)[0], self.fake_request(), {} 67 | ) 68 | 69 | self.assertTrue(isinstance(result, PassThruParam)) 70 | 71 | def test_get_view_param_if_implicit_path_param(self): 72 | def example_function(pk: int): 73 | return 74 | 75 | result = get_view_param( 76 | self.get_params(example_function)[0], self.fake_request(), {"pk": 1} 77 | ) 78 | 79 | self.assertTrue(isinstance(result, PathParam)) 80 | 81 | def test_get_view_param_if_implicit_body_param(self): 82 | class User(BaseModel): 83 | id: int 84 | name = "John Doe" 85 | 86 | def example_function(user: User): 87 | return 88 | 89 | result = get_view_param( 90 | self.get_params(example_function)[0], self.fake_request(), {} 91 | ) 92 | 93 | self.assertTrue(isinstance(result, BodyParam)) 94 | 95 | def test_get_view_param_if_implicit_query_param(self): 96 | def example_function(q: str): 97 | return 98 | 99 | result = get_view_param( 100 | self.get_params(example_function)[0], 
self.fake_request(), {} 101 | ) 102 | 103 | self.assertTrue(isinstance(result, QueryParam)) 104 | 105 | def test_build_explicit_param_for_query(self): 106 | def example_function(q: str = Query()): 107 | return 108 | 109 | result = build_explicit_param( 110 | self.get_params(example_function)[0], 111 | self.fake_request(), 112 | ParamSettings(param_type="query_param"), 113 | {}, 114 | ) 115 | 116 | self.assertTrue(isinstance(result, QueryParam)) 117 | 118 | def test_build_explicit_param_for_path(self): 119 | def example_function(q: str = Path()): 120 | return 121 | 122 | result = build_explicit_param( 123 | self.get_params(example_function)[0], 124 | self.fake_request(), 125 | ParamSettings(param_type="path"), 126 | {}, 127 | ) 128 | 129 | self.assertTrue(isinstance(result, PathParam)) 130 | 131 | def test_build_explicit_param_for_body(self): 132 | def example_function(q: str = Body()): 133 | return 134 | 135 | result = build_explicit_param( 136 | self.get_params(example_function)[0], 137 | self.fake_request(), 138 | ParamSettings(param_type="body"), 139 | {}, 140 | ) 141 | 142 | self.assertTrue(isinstance(result, BodyParam)) 143 | 144 | def test_build_explicit_param_for_current_user(self): 145 | def example_function(q: str = CurrentUser()): 146 | return 147 | 148 | result = build_explicit_param( 149 | self.get_params(example_function)[0], 150 | self.fake_request(), 151 | ParamSettings(param_type="current_user"), 152 | {}, 153 | ) 154 | 155 | self.assertTrue(isinstance(result, CurrentUserParam)) 156 | -------------------------------------------------------------------------------- /rest_typed_views/decorators.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import Any, Dict, List 3 | 4 | from rest_framework.views import APIView 5 | from rest_framework.decorators import action, api_view 6 | from rest_framework.exceptions import ValidationError 7 | from rest_framework.fields import empty 8 | from 
rest_framework.request import Request 9 | 10 | from rest_typed_views.utils import ( 11 | find_request, 12 | get_default_value, 13 | get_explicit_param_settings, 14 | is_explicit_request_param, 15 | is_implicit_body_param, 16 | is_implicit_request_param, 17 | ) 18 | 19 | from .param_settings import ParamSettings 20 | from .params import BodyParam, CurrentUserParam, PassThruParam, PathParam, QueryParam, HeaderParam 21 | 22 | 23 | 24 | def wraps_drf(view): 25 | def _wraps_drf(func): 26 | def wrapper(*args, **kwargs): 27 | return func(*args, **kwargs) 28 | 29 | wrapper.__name__ = view.__name__ 30 | wrapper.__module__ = view.__module__ 31 | wrapper.renderer_classes = getattr( 32 | view, "renderer_classes", APIView.renderer_classes 33 | ) 34 | wrapper.parser_classes = getattr(view, "parser_classes", APIView.parser_classes) 35 | wrapper.authentication_classes = getattr( 36 | view, "authentication_classes", APIView.authentication_classes 37 | ) 38 | wrapper.throttle_classes = getattr( 39 | view, "throttle_classes", APIView.throttle_classes 40 | ) 41 | wrapper.permission_classes = getattr( 42 | view, "permission_classes", APIView.permission_classes 43 | ) 44 | return wrapper 45 | 46 | return _wraps_drf 47 | 48 | 49 | def build_explicit_param( 50 | param: inspect.Parameter, request: Request, settings: ParamSettings, path_args: dict 51 | ): 52 | if settings.param_type == "path": 53 | key = settings.source or param.name 54 | raw_value = path_args.get(key, empty) 55 | return PathParam(param, request, settings=settings, raw_value=raw_value) 56 | elif settings.param_type == "body": 57 | return BodyParam(param, request, settings=settings) 58 | elif settings.param_type == "header": 59 | return HeaderParam(param, request, settings=settings) 60 | elif settings.param_type == "current_user": 61 | return CurrentUserParam(param, request, settings=settings) 62 | elif settings.param_type == "query_param": 63 | return QueryParam(param, request, settings=settings) 64 | 65 | 66 | def 
get_view_param(param: inspect.Parameter, request: Request, path_args: dict): 67 | explicit_settings = get_explicit_param_settings(param) 68 | default = get_default_value(param) 69 | 70 | if explicit_settings: 71 | return build_explicit_param(param, request, explicit_settings, path_args) 72 | elif is_explicit_request_param(param): 73 | return PassThruParam(request) 74 | elif param.name in path_args: 75 | return PathParam( 76 | param, 77 | request, 78 | settings=ParamSettings(param_type="path", default=default), 79 | raw_value=path_args.get(param.name), 80 | ) 81 | elif is_implicit_body_param(param): 82 | return BodyParam( 83 | param, request, settings=ParamSettings(param_type="body", default=default) 84 | ) 85 | elif is_implicit_request_param(param): 86 | return PassThruParam(request) 87 | else: 88 | return QueryParam( 89 | param, 90 | request, 91 | settings=ParamSettings(param_type="query_param", default=default), 92 | ) 93 | 94 | 95 | def transform_view_params( 96 | typed_params: List[inspect.Parameter], request: Request, path_args: dict 97 | ): 98 | validated_params = [] 99 | errors: Dict[str, Any] = {} 100 | 101 | for param in typed_params: 102 | p = get_view_param(param, request, path_args) 103 | value, error = p.validate_or_error() 104 | 105 | if error: 106 | errors.update(error) 107 | else: 108 | validated_params.append(value) 109 | 110 | if len(errors) > 0: 111 | raise ValidationError(errors) 112 | 113 | return validated_params 114 | 115 | 116 | def prevalidate(view_func, for_method: bool = False): 117 | arg_info = inspect.getfullargspec(view_func) 118 | 119 | if arg_info.varargs is not None or arg_info.varkw is not None: 120 | raise Exception( 121 | f"{view_func.__name__}: variable-length argument lists and dictionaries cannot be used with typed views" 122 | ) 123 | 124 | if for_method: 125 | error_msg = "For typed methods, 'self' must be passed as the first arg with no annotation" 126 | 127 | if ( 128 | len(arg_info.args) < 1 129 | or arg_info.args[0] != 
"self" 130 | or "self" in arg_info.annotations 131 | ): 132 | raise Exception(error_msg) 133 | 134 | 135 | def typed_api_view(methods): 136 | def wrap_validate_and_render(view): 137 | prevalidate(view) 138 | 139 | @api_view(methods) 140 | @wraps_drf(view) 141 | def wrapper(*original_args, **original_kwargs): 142 | original_args = list(original_args) 143 | request = find_request(original_args) 144 | transformed = transform_view_params( 145 | inspect.signature(view).parameters.values(), request, original_kwargs 146 | ) 147 | return view(*transformed) 148 | 149 | return wrapper 150 | 151 | return wrap_validate_and_render 152 | 153 | 154 | def typed_action(**action_kwargs): 155 | def wrap_validate_and_render(view): 156 | prevalidate(view, for_method=True) 157 | 158 | @action(**action_kwargs) 159 | @wraps_drf(view) 160 | def wrapper(*original_args, **original_kwargs): 161 | original_args = list(original_args) 162 | request = find_request(original_args) 163 | _self = original_args.pop(0) 164 | 165 | typed_params = [ 166 | p for n, p in inspect.signature(view).parameters.items() if n != "self" 167 | ] 168 | 169 | transformed = transform_view_params(typed_params, request, original_kwargs) 170 | return view(_self, *transformed) 171 | 172 | return wrapper 173 | 174 | return wrap_validate_and_render 175 | -------------------------------------------------------------------------------- /rest_typed_views/validators/validator_factory.py: -------------------------------------------------------------------------------- 1 | from datetime import date, datetime, time, timedelta 2 | from decimal import Decimal 3 | from typing import Any 4 | 5 | from rest_framework import serializers 6 | 7 | from rest_typed_views.param_settings import ParamSettings 8 | from rest_typed_views.utils import ( 9 | parse_complex_type, 10 | parse_enum_annotation, 11 | parse_list_annotation, 12 | ) 13 | from rest_typed_views.validators import ( 14 | DefaultValidator, 15 | PydanticValidator, 16 | 
TypeSystemValidator, 17 | MarshMallowValidator, 18 | ) 19 | 20 | 21 | class ValidatorFactory(object): 22 | @classmethod 23 | def make_string_validator(cls, settings: ParamSettings): 24 | if settings.regex: 25 | return serializers.RegexField( 26 | settings.regex, 27 | default=settings.default, 28 | max_length=settings.max_length, 29 | min_length=settings.min_length, 30 | ) 31 | 32 | if settings.format is None: 33 | return serializers.CharField( 34 | default=settings.default, 35 | max_length=settings.max_length, 36 | min_length=settings.min_length, 37 | trim_whitespace=settings.trim_whitespace, 38 | ) 39 | 40 | if settings.format == "email": 41 | return serializers.EmailField( 42 | default=settings.default, 43 | max_length=settings.max_length, 44 | min_length=settings.min_length, 45 | ) 46 | 47 | if settings.format == "slug": 48 | return serializers.SlugField( 49 | default=settings.default, 50 | max_length=settings.max_length, 51 | min_length=settings.min_length, 52 | ) 53 | 54 | if settings.format == "url": 55 | return serializers.URLField( 56 | default=settings.default, 57 | max_length=settings.max_length, 58 | min_length=settings.min_length, 59 | ) 60 | 61 | if settings.format == "uuid": 62 | return serializers.UUIDField(default=settings.default) 63 | 64 | if settings.format == "file_path": 65 | return serializers.FilePathField( 66 | default=settings.default, 67 | path=settings.path, 68 | match=settings.match, 69 | recursive=settings.recursive, 70 | allow_files=settings.allow_files, 71 | allow_folders=settings.allow_folders, 72 | ) 73 | 74 | if settings.format == "ipv6": 75 | return serializers.IPAddressField(default=settings.default, protocol="IPv6") 76 | 77 | if settings.format == "ipv4": 78 | return serializers.IPAddressField(default=settings.default, protocol="IPv4") 79 | 80 | if settings.format == "ip": 81 | return serializers.IPAddressField(default=settings.default, protocol="both") 82 | 83 | @classmethod 84 | def make_list_validator(cls, item_type: Any, 
settings: ParamSettings): 85 | options = { 86 | "min_length": settings.min_length, 87 | "max_length": settings.max_length, 88 | "allow_empty": settings.allow_empty, 89 | "default": settings.default, 90 | } 91 | if item_type is not Any: 92 | options["child"] = ValidatorFactory.make( 93 | item_type, settings.child or ParamSettings() 94 | ) 95 | 96 | return serializers.ListField(**options) 97 | 98 | @classmethod 99 | def make(cls, annotation: Any, settings: ParamSettings): 100 | if annotation is bool: 101 | return serializers.BooleanField(default=settings.default) 102 | 103 | if annotation is str: 104 | return cls.make_string_validator(settings) 105 | 106 | if annotation is int: 107 | return serializers.IntegerField( 108 | default=settings.default, 109 | max_value=settings.max_value, 110 | min_value=settings.min_value, 111 | ) 112 | 113 | if annotation is float: 114 | return serializers.FloatField( 115 | default=settings.default, 116 | max_value=settings.max_value, 117 | min_value=settings.min_value, 118 | ) 119 | 120 | if annotation is Decimal: 121 | return serializers.DecimalField( 122 | default=settings.default, 123 | max_digits=settings.max_digits, 124 | decimal_places=settings.decimal_places, 125 | coerce_to_string=settings.coerce_to_string, 126 | localize=settings.localize, 127 | rounding=settings.rounding, 128 | max_value=settings.max_value, 129 | min_value=settings.min_value, 130 | ) 131 | 132 | if annotation is datetime: 133 | return serializers.DateTimeField( 134 | default=settings.default, 135 | input_formats=settings.input_formats, 136 | default_timezone=settings.default_timezone, 137 | ) 138 | 139 | if annotation is date: 140 | return serializers.DateField( 141 | default=settings.default, input_formats=settings.input_formats 142 | ) 143 | 144 | if annotation is time: 145 | return serializers.TimeField( 146 | default=settings.default, input_formats=settings.input_formats 147 | ) 148 | 149 | if annotation is timedelta: 150 | return 
serializers.DurationField(default=settings.default) 151 | 152 | is_enum_type, values = parse_enum_annotation(annotation) 153 | 154 | if is_enum_type: 155 | return serializers.ChoiceField(choices=values) 156 | 157 | is_list_type, item_type = parse_list_annotation(annotation) 158 | 159 | if is_list_type: 160 | return cls.make_list_validator(item_type, settings) 161 | 162 | is_complex_type, package = parse_complex_type(annotation) 163 | 164 | if is_complex_type and package == "pydantic": 165 | return PydanticValidator(annotation) 166 | 167 | if is_complex_type and package == "typesystem": 168 | return TypeSystemValidator(annotation) 169 | 170 | if is_complex_type and package == "marshmallow": 171 | return MarshMallowValidator(annotation) 172 | 173 | return DefaultValidator(default=settings.default) 174 | -------------------------------------------------------------------------------- /test_project/testapp/tests/test_typed_api_view.py: -------------------------------------------------------------------------------- 1 | from rest_framework.test import APITestCase 2 | from rest_framework.reverse import reverse 3 | from decimal import Decimal 4 | from uuid import UUID 5 | import datetime 6 | 7 | 8 | class TypedAPIViewTests(APITestCase): 9 | def test_get_logs__ok(self): 10 | url = reverse("get-log-entry", args=[7]) 11 | 12 | response = self.client.get( 13 | url, 14 | { 15 | "latitude": "63.44", 16 | "title": "My title", 17 | "price": 7.5, 18 | "is_pretty": "yes", 19 | "email": "homer@hotmail.com", 20 | "upper_alpha_string": "NBA", 21 | "identifier": "24adfads", 22 | "website": "http://bloomgerg.com", 23 | "identity": "a1e77325-8429-480e-a990-8764f33db2d8", 24 | "ip": "162.254.168.185", 25 | "timestamp": "2013-07-16T19:23:00Z", 26 | "start_date": "2013-07-16", 27 | "start_time": "10:00", 28 | "duration": "12 2:23", 29 | "bag": "paper", 30 | "numbers": "1,2,3", 31 | }, 32 | format="json", 33 | ) 34 | 35 | self.assertEqual(response.status_code, 200) 36 | 37 | 
self.assertEqual( 38 | response.data, 39 | { 40 | "id": 7, 41 | "title": "My title", 42 | "price": 7.5, 43 | "latitude": Decimal("63.44000000000000000000"), 44 | "is_pretty": True, 45 | "email": "homer@hotmail.com", 46 | "upper_alpha_string": "NBA", 47 | "identifier": "24adfads", 48 | "website": "http://bloomgerg.com", 49 | "identity": UUID("a1e77325-8429-480e-a990-8764f33db2d8"), 50 | "ip": "162.254.168.185", 51 | "timestamp": datetime.datetime( 52 | 2013, 7, 16, 19, 23, tzinfo=datetime.timezone.utc 53 | ), 54 | "start_date": datetime.date(2013, 7, 16), 55 | "start_time": datetime.time(10, 0), 56 | "duration": datetime.timedelta(12, 143), 57 | "bag_type": "paper", 58 | "numbers": [1, 2, 3], 59 | }, 60 | ) 61 | 62 | def test_get_logs_error(self): 63 | url = reverse("get-log-entry", args=[7]) 64 | 65 | response = self.client.get( 66 | url, 67 | { 68 | "is_pretty": "maybe", 69 | "email": "homerathotmail.com", 70 | "upper_alpha_string": "mma", 71 | "identifier": "**()", 72 | "identity": "a1e77325-8429-480e-a990-8764f33db2d8", 73 | "ip": "162.254.168.185", 74 | "timestamp": "700BC-07-16T19:23:00Z", 75 | "start_date": "i'll get to it", 76 | "start_time": "when i wake up", 77 | "duration": "forever", 78 | "bag": "scrotum", 79 | "numbers": "fiver", 80 | }, 81 | format="json", 82 | ) 83 | 84 | self.assertEqual(response.status_code, 400) 85 | 86 | self.assertEqual( 87 | response.json(), 88 | { 89 | "latitude": ["This field is required."], 90 | "title": ["This field is required."], 91 | "price": ["This field is required."], 92 | "is_pretty": ["Must be a valid boolean."], 93 | "email": ["Enter a valid email address."], 94 | "upper_alpha_string": [ 95 | "This value does not match the required pattern." 96 | ], 97 | "identifier": [ 98 | 'Enter a valid "slug" consisting of letters, numbers, underscores or hyphens.' 99 | ], 100 | "website": ["This field is required."], 101 | "timestamp": [ 102 | "Datetime has wrong format. 
Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]." 103 | ], 104 | "start_date": [ 105 | "Date has wrong format. Use one of these formats instead: YYYY-MM-DD." 106 | ], 107 | "start_time": [ 108 | "Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]]." 109 | ], 110 | "duration": [ 111 | "Duration has wrong format. Use one of these formats instead: [DD] [HH:[MM:]]ss[.uuuuuu]." 112 | ], 113 | "bag": ['"scrotum" is not a valid choice.'], 114 | "numbers": {"0": ["A valid integer is required."]}, 115 | }, 116 | ) 117 | 118 | def test_create_user_ok(self): 119 | url = reverse("create-user") 120 | response = self.client.post( 121 | url, 122 | { 123 | "id": 12, 124 | "name": "Robert", 125 | "signup_ts": "2013-07-16T19:23:00Z", 126 | "friends": [3], 127 | }, 128 | format="json", 129 | ) 130 | 131 | self.assertEqual(response.status_code, 200) 132 | 133 | self.assertEqual( 134 | response.data, 135 | { 136 | "id": 12, 137 | "name": "Robert", 138 | "signup_ts": datetime.datetime( 139 | 2013, 7, 16, 19, 23, tzinfo=datetime.timezone.utc 140 | ), 141 | "friends": [3], 142 | }, 143 | ) 144 | 145 | def test_create_user_error(self): 146 | url = reverse("create-user") 147 | response = self.client.post( 148 | url, 149 | {"name": "Robert", "signup_ts": "2013-07-16T19:23:00Z", "friends": [3]}, 150 | format="json", 151 | ) 152 | 153 | self.assertEqual(response.status_code, 400) 154 | 155 | self.assertEqual( 156 | response.json(), 157 | { 158 | "user": [ 159 | { 160 | "loc": "('id',)", 161 | "msg": "field required", 162 | "type": "value_error.missing", 163 | } 164 | ] 165 | }, 166 | ) 167 | 168 | def test_cache_header_ok(self): 169 | url = reverse("get-cache-header") 170 | response = self.client.get( 171 | url, HTTP_CACHE="no" 172 | ) 173 | self.assertEqual(response.status_code, 200) 174 | 175 | self.assertEqual( 176 | response.data, 177 | "no", 178 | ) 179 | 180 | def test_cache_header_error(self): 181 | url = 
reverse("get-cache-header") 182 | response = self.client.get( 183 | url, 184 | ) 185 | self.assertEqual(response.status_code, 400) 186 | 187 | self.assertEqual( 188 | response.json(), 189 | {'cache': ['This field may not be null.']}, 190 | ) 191 | 192 | def test_create_booking_ok(self): 193 | url = reverse("create-booking") 194 | response = self.client.post( 195 | url, 196 | { 197 | "_data": { 198 | "item": { 199 | "start_date": "2019-11-11", 200 | "end_date": "2019-11-13", 201 | "include_breakfast": True, 202 | "room": "twin", 203 | } 204 | } 205 | }, 206 | format="json", 207 | ) 208 | self.assertEqual(response.status_code, 200) 209 | self.assertEqual( 210 | response.data, 211 | { 212 | "start_date": "2019-11-11", 213 | "end_date": "2019-11-13", 214 | "room": "twin", 215 | "include_breakfast": True, 216 | }, 217 | ) 218 | 219 | def test_create_booking_error(self): 220 | url = reverse("create-booking") 221 | response = self.client.post( 222 | url, 223 | { 224 | "_data": { 225 | "item": { 226 | "start_date": "2019-11-11", 227 | "end_date": "2019-11-13", 228 | "include_breakfast": True, 229 | } 230 | } 231 | }, 232 | format="json", 233 | ) 234 | 235 | self.assertEqual(response.status_code, 400) 236 | 237 | self.assertEqual( 238 | response.json(), {"_data.item": {"room": "This field is required."}} 239 | ) 240 | 241 | def test_create_band_member_ok(self): 242 | url = reverse("create-band-member") 243 | response = self.client.post( 244 | url, {"name": "Homer", "email": "homer@hotmail.com"}, format="json" 245 | ) 246 | self.assertEqual(response.status_code, 200) 247 | self.assertEqual(response.data, {"name": "Homer", "email": "homer@hotmail.com"}) 248 | 249 | def test_create_band_member_error(self): 250 | url = reverse("create-band-member") 251 | response = self.client.post(url, {"email": "homer@hotmail.com"}, format="json") 252 | self.assertEqual(response.status_code, 400) 253 | self.assertEqual( 254 | response.json(), 255 | {"band_member": {"name": ["Missing data for 
required field."]}}, 256 | ) 257 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NOT MAINTAINED 2 | 3 | Use this project instead: [drf-typed](https://github.com/rsinger86/drf-typed). 4 | 5 | It includes everything this project does, plus it includes typed features for serializers and helpful type stubs. 6 | 7 | ## Django REST Framework - Typed Views 8 | 9 | [![Package version](https://badge.fury.io/py/drf-typed-views.svg)](https://pypi.python.org/pypi/drf-typed-views) 10 | [![Python versions](https://img.shields.io/pypi/status/drf-typed-views.svg)](https://img.shields.io/pypi/status/drf-typed-views.svg/) 11 | 12 | This project extends [Django Rest Framework](https://www.django-rest-framework.org/) to allow use of Python's type annotations for automatically validating and casting view parameters. This pattern makes for code that is easier to read and write. View inputs are individually declared, not buried inside all-encompassing `request` objects. Meanwhile, you get even more out of type annotations: they can replace repetitive validation/sanitization code. 13 | 14 | More features: 15 | - [Pydantic](https://pydantic-docs.helpmanual.io/) models and [Marshmallow](https://marshmallow.readthedocs.io) schemas are compatible types for view parameters. Annotate your POST/PUT functions with them to automatically validate incoming request bodies. 
16 | - Advanced validators for more than just the type: `min_value`/`max_value` for numbers 17 | - Validate string formats: `email`, `uuid` and `ipv4/6`; use Python's native `Enum` for "choices" validation 18 | 19 | Quick example: 20 | ```python 21 | from rest_typed_views import typed_api_view 22 | 23 | @typed_api_view(["GET"]) 24 | def get_users(registered_on: date = None, groups: List[int] = None, is_staff: bool = None): 25 | print(registered_on, groups, is_staff) 26 | ``` 27 | 28 | GET `/users/registered/?registered_on=2019-03-03&groups=4,5&is_staff=yes`
29 | Status Code: 200 30 | ``` 31 | date(2019, 3, 3) [4, 5] True 32 | ``` 33 | 34 | GET `/users/?registered_on=9999&groups=admin&is_staff=maybe`
35 | :no_entry_sign: Status Code: 400 *ValidationError raised* 36 | ```json 37 | { 38 | "registered_on": "'9999' is not a valid date", 39 | "groups": "'admin' is not a valid integer", 40 | "is_staff": "'maybe' is not a valid boolean" 41 | } 42 | ``` 43 | ## Table of Contents 44 | * [Install & Decorators](#install--decorators) 45 | * [How It Works: Simple Usage](#how-it-works-simple-usage) 46 | * [Basic GET Request](#basic-get-request) 47 | * [Basic POST Request](#basic-post-request) 48 | * [How It Works: Advanced Usage](#how-it-works-advanced-usage) 49 | * [Additional Validation Rules](#additional-validation-rules) 50 | * [Nested Body Fields](#nested-body-fields) 51 | * [List Validation](#list-validation) 52 | * [Accessing the Request Object](#accessing-the-request-object) 53 | * [Interdependent Query Parameter Validation](#interdependent-query-parameter-validation) 54 | * [(Simple) Access Control](#simple-access-control) 55 | * [Enabling Marshmallow, Pydantic Schemas](#enabling-marshmallow-pydantic-schemas) 56 | * [Request Element Classes](#request-element-classes) 57 | * [Query](#query) 58 | * [Body](#body) 59 | * [Path](#path) 60 | * [Header](#header) 61 | * [CurrentUser](#currentuser) 62 | * [Supported Types/Validator Rules](#supported-types-and-validator-rules) 63 | * [int](#int) 64 | * [float](#float) 65 | * [Decimal](#decimal) 66 | * [str](#str) 67 | * [bool](#bool) 68 | * [datetime](#datetime) 69 | * [date](#date) 70 | * [time](#time) 71 | * [timedelta](#timedelta) 72 | * [List](#list) 73 | * [Enum](#enum) 74 | * [marshmallow.Schema](#marshmallowschema) 75 | * [pydantic.BaseModel](#pydanticbasemodel) 76 | * [Change Log](#change-log) 77 | * [Motivation & Inspiration](#motivation) 78 | 79 | ## Install & Decorators 80 | 81 | ``` 82 | pip install drf-typed-views 83 | ``` 84 | 85 | You can add type annotation-enabled features to either `ViewSet` methods or function-based views using the `typed_action` and `typed_api_view` decorators.
They take the exact same arguments as Django REST's [`api_view`](https://www.django-rest-framework.org/api-guide/views/#api_view) and [`action`](https://www.django-rest-framework.org/api-guide/viewsets/#marking-extra-actions-for-routing) decorators. 86 | 87 | ## How It Works: Simple Usage 88 | 89 | For many cases, you can rely on implicit behavior for how different parts of the request (URL path variables, query parameters, body) map to the parameters of a view function/method. 90 | 91 | The value of a view parameter will come from... 92 | - the URL path if the path variable and the view argument have the same name, *or*: 93 | - the request body if the view argument is annotated using a class from a supported library for complex object validation (Pydantic, Marshmallow), *or*: 94 | - a query parameter with the same name 95 | 96 | Unless a default value is given, the parameter is **required** and a [`ValidationError`](https://www.django-rest-framework.org/api-guide/exceptions/#validationerror) will be raised if not set. 97 | 98 | ### Basic GET Request 99 | ```python 100 | urlpatterns = [ 101 | url(r"^(?P<city>\w+)/restaurants/", search_restaurants) 102 | ] 103 | 104 | from rest_typed_views import typed_api_view 105 | 106 | # Example request: /chicago/restaurants/?offers_delivery=yes 107 | @typed_api_view(["GET"]) 108 | def search_restaurants(city: str, rating: float = None, offers_delivery: bool = None): 109 | restaurants = Restaurant.objects.filter(city=city) 110 | 111 | if rating is not None: 112 | restaurants = restaurants.filter(rating__gte=rating) 113 | 114 | if offers_delivery is not None: 115 | restaurants = restaurants.filter(delivery=offers_delivery) 116 | ``` 117 | 118 | In this example, `city` is required and must be a string. Its value comes from the URL path variable with the same name. The other parameters, `rating` and `offers_delivery`, are not part of the path parameters and are assumed to be query parameters.
They both have a default value, so they are optional. 119 | 120 | ### Basic POST Request 121 | ```python 122 | # urls.py 123 | urlpatterns = [url(r"^(?P<city>\w+)/bookings/", create_booking)] 124 | 125 | # settings.py 126 | DRF_TYPED_VIEWS = {"schema_packages": ["pydantic"]} 127 | 128 | # views.py 129 | from pydantic import BaseModel 130 | from rest_typed_views import typed_api_view 131 | 132 | 133 | class RoomEnum(str, Enum): 134 | double = 'double' 135 | twin = 'twin' 136 | single = 'single' 137 | 138 | 139 | class BookingSchema(BaseModel): 140 | start_date: date 141 | end_date: date 142 | room: RoomEnum = RoomEnum.double 143 | include_breakfast: bool = False 144 | 145 | # Example request: /chicago/bookings/ 146 | @typed_api_view(["POST"]) 147 | def create_booking(city: str, booking: BookingSchema): 148 | # do something with the validated booking... 149 | ``` 150 | 151 | In this example, `city` will again be populated using the URL path variable. The `booking` parameter is annotated using a supported complex schema class (Pydantic), so it's assumed to come from the request body, which will be read in as JSON, used to hydrate the Pydantic `BookingSchema` and then validated. If validation fails, a `ValidationError` will be raised. 152 | 153 | ## How It Works: Advanced Usage 154 | 155 | For more advanced use cases, you can explicitly declare how each parameter's value is sourced from the request -- from the query parameters, path, body or headers -- as well as define additional validation rules. You import a class named after the request element that is expected to hold the value and assign it to the parameter's default. 156 | 157 | ```python 158 | from rest_typed_views import typed_api_view, Query, Path 159 | 160 | @typed_api_view(["GET"]) 161 | def list_documents(year: date = Path(), title: str = Query(default=None)): 162 | # ORM logic here...
163 | ``` 164 | 165 | In this example, `year` is required and must come from the URL path, and `title` is an optional query parameter because the `default` is set. This is similar to Django REST's [serializer fields](https://www.django-rest-framework.org/api-guide/fields/#core-arguments): passing a default implies that the field is not required. 166 | 167 | ```python 168 | from rest_typed_views import typed_api_view, Header 169 | 170 | @typed_api_view(["GET"]) 171 | def get_cache_header(cache: str = Header()): 172 | # ORM logic here... 173 | ``` 174 | 175 | In this example, `cache` is required and must come from the headers. 176 | 177 | ### Additional Validation Rules 178 | 179 | You can use the request element class (`Query`, `Path`, `Body`, `Header`) to set additional validation constraints. You'll find that these keywords are consistent with Django REST's serializer fields. 180 | 181 | ```python 182 | from rest_typed_views import typed_api_view, Query, Path 183 | 184 | @typed_api_view(["GET"]) 185 | def search_restaurants( 186 | year: date = Path(), 187 | rating: int = Query(default=None, min_value=1, max_value=5) 188 | ): 189 | # ORM logic here... 190 | 191 | 192 | @typed_api_view(["GET"]) 193 | def get_document(id: str = Path(format="uuid")): 194 | # ORM logic here... 195 | 196 | 197 | @typed_api_view(["GET"]) 198 | def search_users( 199 | email: str = Query(default=None, format="email"), 200 | ip_address: str = Query(default=None, format="ip"), 201 | ): 202 | # ORM logic here... 203 | ``` 204 | 205 | View a [full list](#supported-types-and-validator-rules) of supported types and additional validation rules. 206 | 207 | ### Nested Body Fields 208 | 209 | Similar to how `source` is used in Django REST to control field mappings during serialization, you can use it to specify the exact path to the request data.
210 | 211 | ```python 212 | from pydantic import BaseModel 213 | from rest_typed_views import typed_api_view, Body 214 | 215 | class Document(BaseModel): 216 | title: str 217 | body: str 218 | 219 | """ 220 | POST 221 | { 222 | "strict": false, 223 | "data": { 224 | "title": "A Dark and Stormy Night", 225 | "body": "Once upon a time" 226 | } 227 | } 228 | """ 229 | @typed_api_view(["POST"]) 230 | def create_document( 231 | strict_mode: bool = Body(source="strict"), 232 | item: Document = Body(source="data") 233 | ): 234 | # ORM logic here... 235 | ``` 236 | You can also use dot-notation to source data multiple levels deep in the JSON payload. 237 | 238 | ### List Validation 239 | 240 | For the basic case of list validation - validating types within a comma-delimited string - declare the type to get automatic validation/coercion: 241 | 242 | ```python 243 | from rest_typed_views import typed_api_view, Query 244 | 245 | @typed_api_view(["GET"]) 246 | def search_movies(item_ids: List[int] = []): 247 | print(item_ids) 248 | 249 | # GET /movies?item_ids=41,64,3 250 | # [41, 64, 3] 251 | ``` 252 | 253 | But you can also specify `min_length` and `max_length`, as well as the `delimiter`, and specify additional rules for the child items -- think Django REST's [ListField](https://www.django-rest-framework.org/api-guide/fields/#listfield). 254 | 255 | Import the generic `Param` class and use it to set the rules for the `child` elements: 256 | 257 | ```python 258 | from rest_typed_views import typed_api_view, Query, Param 259 | 260 | @typed_api_view(["GET"]) 261 | def search_outcomes( 262 | scores: List[int] = Query(delimiter="|", child=Param(min_value=0, max_value=100)) 263 | ): 264 | # ORM logic ... 265 | 266 | @typed_api_view(["GET"]) 267 | def search_message( 268 | recipients: List[str] = Query(min_length=1, max_length=10, child=Param(format="email")) 269 | ): 270 | # ORM logic ...
271 | ``` 272 | 273 | ### Accessing the Request Object 274 | 275 | You probably won't need to access the `request` object directly, as this package will provide its relevant properties as view arguments. However, you can include it as a parameter annotated with its type and it will be injected: 276 | 277 | ```python 278 | from rest_framework.request import Request 279 | from rest_typed_views import typed_api_view 280 | 281 | @typed_api_view(["GET"]) 282 | def search_documents(request: Request, q: str = None): 283 | # ORM logic ... 284 | ``` 285 | 286 | ### Interdependent Query Parameter Validation 287 | Often, it's useful to validate a combination of query parameters - for instance, a `start_date` shouldn't come after an `end_date`. You can use a complex schema object (Pydantic or Marshmallow) for this scenario. In the example below, `Query(source="*")` is instructing an instance of `SearchParamsSchema` to be populated/validated using all of the query parameters together: `request.query_params.dict()`. 288 | 289 | ```python 290 | from marshmallow import Schema, fields, validates_schema, ValidationError 291 | from rest_typed_views import typed_api_view, Query 292 | 293 | class SearchParamsSchema(Schema): 294 | start_date = fields.Date() 295 | end_date = fields.Date() 296 | 297 | @validates_schema 298 | def validate_numbers(self, data, **kwargs): 299 | if data["start_date"] >= data["end_date"]: 300 | raise ValidationError("end_date must come after start_date") 301 | 302 | @typed_api_view(["GET"]) 303 | def search_documents(search_params: SearchParamsSchema = Query(source="*")): 304 | # ORM logic ... 305 | ``` 306 | 307 | ### (Simple) Access Control 308 | 309 | You can apply some very basic access control by applying some validation rules to a view parameter sourced from the `CurrentUser` request element class. In the example below, a `ValidationError` will be raised if the `request.user` is not a member of either `super_users` or `admins`.
310 | 311 | ```python 312 | from my_pydantic_schemas import BookingSchema 313 | from rest_typed_views import typed_api_view, CurrentUser 314 | 315 | @typed_api_view(["POST"]) 316 | def create_booking( 317 | booking: BookingSchema, 318 | user: User = CurrentUser(member_of_any=["super_users", "admins"]) 319 | ): 320 | # Do something with the request.user 321 | ``` 322 | 323 | Read more about the [`CurrentUser` request element class](#currentuser). 324 | 325 | ## Enabling Marshmallow, Pydantic Schemas 326 | 327 | As an alternative to Django REST's serializers, you can annotate views with [Pydantic](https://pydantic-docs.helpmanual.io/) models or [Marshmallow](https://marshmallow.readthedocs.io/en/stable/) schemas to have their parameters automatically validated and pass an instance of the Pydantic/Marshmallow class to your method/function. 328 | 329 | To enable support for third-party libraries for complex object validation, modify your settings: 330 | 331 | ```python 332 | DRF_TYPED_VIEWS = { 333 | "schema_packages": ["pydantic", "marshmallow"] 334 | } 335 | ``` 336 | 337 | These third-party packages must be installed in your virtual environment/runtime. 338 | 339 | ## Request Element Classes 340 | 341 | You can specify the part of the request that holds each view parameter by using default function arguments, for example: 342 | ```python 343 | from rest_typed_views import typed_api_view, Body, Query 344 | 345 | @typed_api_view(["PUT"]) 346 | def update_user( 347 | user: UserSchema = Body(), 348 | optimistic_update: bool = Query(default=False) 349 | ): 350 | ``` 351 | 352 | The `user` parameter will come from the request body and is required because no default is provided. Meanwhile, `optimistic_update` is not required and will be populated from a query parameter with the same name.
353 | 354 | The core keyword arguments to these classes are: 355 | - `default` the default value for the parameter; the parameter is required unless this is set 356 | - `source` the key to read from the request, if it differs from the view parameter's name 357 | 358 | Passing keywords for additional validation constraints is a *powerful capability* that gets you *almost the same feature set* as Django REST's flexible [serializer fields](https://www.django-rest-framework.org/api-guide/fields/). See a [complete list](#supported-types-and-validator-rules) of validation keywords. 359 | 360 | 361 | ### Query 362 | Use the `source` argument to alias the parameter value and pass keywords to set additional constraints. For example, your query parameters can have dashes but still be mapped to a parameter that has underscores: 363 | 364 | ```python 365 | from rest_typed_views import typed_api_view, Query 366 | 367 | @typed_api_view(["GET"]) 368 | def search_events( 369 | starting_after: date = Query(source="starting-after"), 370 | available_tickets: int = Query(default=0, min_value=0) 371 | ): 372 | # ORM logic here... 373 | ``` 374 | 375 | ### Body 376 | By default, the entire request body is used to populate parameters marked with this class (`source="*"`): 377 | 378 | ```python 379 | from rest_typed_views import typed_api_view, Body 380 | from my_pydantic_schemas import ResidenceListing 381 | 382 | @typed_api_view(["POST"]) 383 | def create_listing(residence: ResidenceListing = Body()): 384 | # ORM logic ... 385 | ``` 386 | 387 | However, you can also specify nested fields in the request body, with support for dot notation.
388 | 389 | ```python 390 | """ 391 | POST /users/ 392 | { 393 | "first_name": "Homer", 394 | "last_name": "Simpson", 395 | "contact": { 396 | "phone" : "800-123-456", 397 | "fax": "13235551234" 398 | } 399 | } 400 | """ 401 | from rest_typed_views import typed_api_view, Body 402 | 403 | @typed_api_view(["POST"]) 404 | def create_user( 405 | first_name: str = Body(source="first_name"), 406 | last_name: str = Body(source="last_name"), 407 | phone: str = Body(source="contact.phone", min_length=10, max_length=20) 408 | ): 409 | # ORM logic ... 410 | ``` 411 | 412 | ### Path 413 | Use the `source` argument to alias a view parameter name. More commonly, though, you can set additional validation rules for parameters coming from the URL path. 414 | 415 | ```python 416 | from rest_typed_views import typed_api_view, Path 417 | 418 | @typed_api_view(["GET"]) 419 | def retrieve_event(id: int = Path(min_value=0, max_value=1000)): 420 | # ORM logic here... 421 | ``` 422 | 423 | ### Header 424 | Use the `Header` request element class to automatically retrieve a value from a header. Underscores in variable names are automatically converted to dashes. 425 | 426 | ```python 427 | from rest_typed_views import typed_api_view, Header 428 | 429 | @typed_api_view(["GET"]) 430 | def retrieve_event(id: int, cache_control: str = Header(default="no-cache")): 431 | # ORM logic here... 432 | ``` 433 | 434 | If you prefer, you can explicitly specify the exact header key: 435 | ```python 436 | from rest_typed_views import typed_api_view, Header 437 | 438 | @typed_api_view(["GET"]) 439 | def retrieve_event(id: int, cache_control: str = Header(source="cache-control", default="no-cache")): 440 | # ORM logic here... 441 | ``` 442 | 443 | ### CurrentUser 444 | 445 | Use this class to have a view parameter populated with the current user of the request. You can even extract fields from the current user using the `source` option.
446 | 447 | ```python 448 | from my_pydantic_schemas import BookingSchema 449 | from rest_typed_views import typed_api_view, CurrentUser 450 | 451 | @typed_api_view(["POST"]) 452 | def create_booking(booking: BookingSchema, user: User = CurrentUser()): 453 | # Do something with the request.user 454 | 455 | @typed_api_view(["GET"]) 456 | def retrieve_something(first_name: str = CurrentUser(source="first_name")): 457 | # Do something with the request.user's first name 458 | ``` 459 | You can also pass some additional parameters to the `CurrentUser` request element class to implement simple access control: 460 | - `member_of` (str) Validates that the current `request.user` is a member of a group with this name 461 | - `member_of_any` (List[str]) Validates that the current `request.user` is a member of one of these groups 462 | 463 | *Using these keyword validators assumes that your `User` model has a many-to-many relationship with `django.contrib.auth.models.Group` via `user.groups`.* 464 | 465 | An example: 466 | 467 | ```python 468 | from django.contrib.auth.models import User 469 | from rest_typed_views import typed_api_view, CurrentUser 470 | 471 | @typed_api_view(["GET"]) 472 | def do_something(user: User = CurrentUser(member_of="admin")): 473 | # now have a user instance (assuming ValidationError wasn't raised) 474 | ``` 475 | ## Supported Types and Validator Rules 476 | 477 | The following native Python types are supported. Depending on the type, you can pass additional validation rules to the request element class (`Query`, `Path`, `Body`). You can think of the type combining with the validation rules to create a Django REST serializer field on the fly -- in fact, that's what happens behind the scenes. 478 | 479 | ### str 480 | Additional arguments: 481 | - `max_length` Validates that the input contains no more than this number of characters. 482 | - `min_length` Validates that the input contains no fewer than this number of characters. 
483 | - `trim_whitespace` (bool; default `True`) Whether to trim leading and trailing white space. 484 | - `format` Validates that the string matches a common format; supported values: 485 | - `email` validates the text to be a valid e-mail address. 486 | - `slug` validates the input against the pattern `[a-zA-Z0-9_-]+`. 487 | - `uuid` validates that the input is a valid UUID string 488 | - `url` validates fully qualified URLs of the form `http://<host>/<path>` 489 | - `ip` validates that the input is a valid IPv4 or IPv6 string 490 | - `ipv4` validates that the input is a valid IPv4 string 491 | - `ipv6` validates that the input is a valid IPv6 string 492 | - `file_path` validates that the input corresponds to filenames in a certain directory on the filesystem; allows all the same keyword arguments as Django REST's [`FilePathField`](https://www.django-rest-framework.org/api-guide/fields/#filepathfield) 493 | 494 | Some examples: 495 | 496 | ```python 497 | from rest_typed_views import typed_api_view, Query 498 | 499 | @typed_api_view(["GET"]) 500 | def search_users(email: str = Query(format='email')): 501 | # ORM logic here... 502 | return Response(data) 503 | 504 | @typed_api_view(["GET"]) 505 | def search_shared_links(url: str = Query(default=None, format='url')): 506 | # ORM logic here... 507 | return Response(data) 508 | 509 | @typed_api_view(["GET"]) 510 | def search_request_logs(ip_address: str = Query(default=None, format='ip')): 511 | # ORM logic here... 512 | return Response(data) 513 | ``` 514 | 515 | ### int 516 | Additional arguments: 517 | - `max_value` Validate that the number provided is no greater than this value. 518 | - `min_value` Validate that the number provided is no less than this value. 519 | 520 | An example: 521 | ```python 522 | from rest_typed_views import typed_api_view, Query 523 | 524 | @typed_api_view(["GET"]) 525 | def search_products(inventory: int = Query(min_value=0)): 526 | # ORM logic here...
527 | ``` 528 | 529 | ### float 530 | Additional arguments: 531 | - `max_value` Validate that the number provided is no greater than this value. 532 | - `min_value` Validate that the number provided is no less than this value. 533 | 534 | An example: 535 | ```python 536 | from rest_typed_views import typed_api_view, Query 537 | 538 | @typed_api_view(["GET"]) 539 | def search_products(price: float = Query(min_value=0)): 540 | # ORM logic here... 541 | ``` 542 | 543 | ### Decimal 544 | Additional arguments: 545 | - `max_value` Validate that the number provided is no greater than this value. 546 | - `min_value` Validate that the number provided is no less than this value. 547 | - .. even more ... accepts the same arguments as [Django REST's `DecimalField`](https://www.django-rest-framework.org/api-guide/fields/#decimalfield) 548 | 549 | ### bool 550 | View parameters annotated with this type will validate and coerce the same values as Django REST's `BooleanField`, including but not limited to the following: 551 | ```python 552 | true_values = ["yes", 1, "on", "y", "true"] 553 | false_values = ["no", 0, "off", "n", "false"] 554 | ``` 555 | 556 | ### datetime 557 | Additional arguments: 558 | - `input_formats` A list of input formats which may be used to parse the date-time, defaults to Django's `DATETIME_INPUT_FORMATS` settings, which defaults to `['iso-8601']` 559 | - `default_timezone` A `pytz.timezone` of the timezone. If not specified, falls back to Django's `USE_TZ` setting. 
560 | 561 | ### date 562 | Additional arguments: 563 | - `input_formats` A list of input formats which may be used to parse the date, defaults to Django's `DATE_INPUT_FORMATS` settings, which defaults to `['iso-8601']` 564 | 565 | ### time 566 | Additional arguments: 567 | - `input_formats` A list of input formats which may be used to parse the time, defaults to Django's `TIME_INPUT_FORMATS` settings, which defaults to `['iso-8601']` 568 | 569 | ### timedelta 570 | Validates strings of the format `'[DD] [HH:[MM:]]ss[.uuuuuu]'` and converts them to a `datetime.timedelta` instance. 571 | 572 | Additional arguments: 573 | - `max_value` Validate that the input duration is no greater than this value. 574 | - `min_value` Validate that the input duration is no less than this value. 575 | 576 | ### List 577 | Validates a delimited string (comma-separated unless a different `delimiter` is given) and converts it to a list, validating and casting each element to the annotated item type. 578 | 579 | Additional arguments: 580 | - `min_length` Validates that the list contains no fewer than this number of elements. 581 | - `max_length` Validates that the list contains no more than this number of elements. 582 | - `child` Pass keyword constraints via a `Param` instance to validate the members of the list. 583 | 584 | An example: 585 | ```python 586 | from rest_typed_views import typed_api_view, Param, Query 587 | 588 | @typed_api_view(["GET"]) 589 | def search_contacts(emails: List[str] = Query(max_length=10, child=Param(format="email"))): 590 | # ORM logic here... 591 | ``` 592 | 593 | ### Enum 594 | Validates that the value of the input is one of a limited set of choices. Think of this as mapping to a Django REST [`ChoiceField`](https://www.django-rest-framework.org/api-guide/fields/#choicefield).
595 | 596 | An example: 597 | ```python 598 | from rest_typed_views import typed_api_view, Query 599 | 600 | class Straws(str, Enum): 601 | paper = "paper" 602 | plastic = "plastic" 603 | 604 | @typed_api_view(["GET"]) 605 | def search_straws(type: Straws = None): 606 | # ORM logic here... 607 | ``` 608 | 609 | ### marshmallow.Schema 610 | You can annotate view parameters with [Marshmallow schemas](https://marshmallow.readthedocs.io/en/stable/) to validate request data and pass an instance of the schema to the view. 611 | 612 | ```python 613 | from marshmallow import Schema, fields 614 | from rest_typed_views import typed_api_view, Query 615 | 616 | class ArtistSchema(Schema): 617 | name = fields.Str() 618 | 619 | class AlbumSchema(Schema): 620 | title = fields.Str() 621 | release_date = fields.Date() 622 | artist = fields.Nested(ArtistSchema()) 623 | 624 | """ 625 | POST 626 | { 627 | "title": "Michael Scott's Greatest Hits", 628 | "release_date": "2019-03-03", 629 | "artist": { 630 | "name": "Michael Scott" 631 | } 632 | } 633 | """ 634 | @typed_api_view(["POST"]) 635 | def create_album(album: AlbumSchema): 636 | # now have an album instance (assuming ValidationError wasn't raised) 637 | ``` 638 | 639 | ### pydantic.BaseModel 640 | You can annotate view parameters with [Pydantic models](https://pydantic-docs.helpmanual.io/) to validate request data and pass an instance of the model to the view. 
641 | 642 | ```python 643 | from pydantic import BaseModel 644 | from rest_typed_views import typed_api_view, Query 645 | 646 | class User(BaseModel): 647 | id: int 648 | name: str 649 | signup_ts: datetime = None 650 | friends: List[int] = [] 651 | 652 | """ 653 | POST 654 | { 655 | "id": 24529782, 656 | "name": "Michael Scott", 657 | "friends": [24529782] 658 | } 659 | """ 660 | @typed_api_view(["POST"]) 661 | def create_user(user: User): 662 | # now have a user instance (assuming ValidationError wasn't raised) 663 | ``` 664 | 665 | ## Change Log 666 | 667 | * June 7, 2020 668 | * Fixes compatibility with DRF decorator. Thanks @sjquant! 669 | * Makes Django's QueryDict work with Marshmallow and Pydantic validators. Thanks @filwaline! 670 | * February 2, 2020: Adds support for `Header` request parameter. Thanks @bbkgh! 671 | 672 | ## Motivation 673 | 674 | While REST Framework's ModelViewSets and ModelSerializers are very productive when building out CRUD resources, I've felt less productive in the framework when developing other types of operations. Serializers are a powerful and flexible way to validate incoming request data, but are not as self-documenting as type annotations. Furthermore, the Django ecosystem is hugely productive and I see no reason why REST Framework cannot take advantage of more Python 3 features. 675 | 676 | ## Inspiration 677 | 678 | I first came across type annotations for validation in [API Star](https://github.com/encode/apistar), which has since evolved into an OpenAPI toolkit. This pattern has also been offered by [Hug](https://hugapi.github.io/hug/) and [Molten](https://github.com/Bogdanp/molten) (I believe in that order). Furthermore, I've borrowed ideas from [FastAPI](https://github.com/tiangolo/fastapi), specifically its use of default values to declare additional validation rules.
Finally, this [blog post](https://instagram-engineering.com/types-for-python-http-apis-an-instagram-story-d3c3a207fdb7) from Instagram's engineering team showed me how decorators can be used to implement these features on view functions. 679 | 680 | --------------------------------------------------------------------------------