├── .whitesource ├── LICENSE ├── MANIFEST.in ├── README.rst ├── rest_framework_drilldown ├── __init__.py ├── models.py ├── tests.py ├── urls.py └── views.py └── setup.py /.whitesource: -------------------------------------------------------------------------------- 1 | { 2 | "generalSettings": { 3 | "shouldScanRepo": true 4 | }, 5 | "checkRunSettings": { 6 | "vulnerableCheckRunConclusionLevel": "success" 7 | }, 8 | "issueSettings": { 9 | "minSeverityLevel": "LOW" 10 | } 11 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 Peter Hollingsworth 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.rst 3 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | ============= 2 | Drilldown API 3 | ============= 4 | 5 | Extends Django REST Framework to create instant full-featured GET APIs with fields, filters, offset, 6 | limit, etc., including the ability to drill down into chained objects (foreignKey, manyToMany, oneToOne fields). 7 | 8 | You create just one simple view per API; you do not need to create separate serializers for the 9 | chained objects. 10 | 11 | 12 | Quickstart 13 | ---------- 14 | Example adapted from code in tests.py. 15 | 16 | 1. Install the package: 17 | ``pip install git+https://github.com/peterh32/django-rest-framework-drilldown.git`` 18 | 19 | 20 | 2. Create a view that's a subclass of DrillDownAPIView (use your own models; "Invoice" is just an example):: 21 | 22 | from rest_framework_drilldown import DrillDownAPIView 23 | 24 | class InvoiceList(DrillDownAPIView): 25 | """A GET API for Invoice objects""" 26 | # Primary model for the API (required) 27 | model = Invoice 28 | 29 | # Global throttle for the API. Defaults to 1000. If your query hits MAX_RESULTS, 30 | # this is noted in X-Query_Warning in the response header. 31 | MAX_RESULTS = 5000 32 | 33 | # The picky flag defaults to False; if True, then any unidentifiable param in the 34 | # request will result in an error. 35 | #picky = True 36 | 37 | # Optional list of chained foreignKey, manyToMany, and oneToOne objects 38 | # your users can drill down into -- note that you do not need to build 39 | # separate serializers for these; DrilldownAPI builds them dynamically. 
40 | drilldowns = ['client__profile', 'salesperson__profile', 'items'] 41 | 42 | # Optional list of fields to ignore if they appear in the request 43 | ignore = ['fakefield'] 44 | 45 | # Optional list of fields that your users are not allowed to 46 | # see or query 47 | hide = ['salesperson__commission_pct'] 48 | 49 | def get_base_query(self): 50 | # Base query for your class, typically just '.objects.all()' 51 | return Invoice.objects.all() 52 | 53 | 54 | 3. In urls.py, create a URL for the view: 55 | ``url(r'^invoices/$', InvoiceList.as_view(), name='invoices'),`` 56 | 57 | 4. Start running queries! Some of the things you can do: 58 | 59 | * Limit and offset: 60 | ``/invoices/?limit=10&offset=60`` 61 | 62 | Does just what you'd expect. The total number of results is returned in a custom header code: ``X-Total-Count: 2034`` 63 | 64 | * Specify fields to include, including "drilldown" fields: 65 | ``/invoices/?fields=id,client.profile.first_name,client.profile.last_name`` 66 | 67 | Returns invoices showing just the invoice ID and the client's first and last name. 68 | 69 | * Filter on fields: 70 | ``/invoices/?total__gte=100&salesperson.last_name__iexact=smith`` 71 | 72 | Lists invoices where total >= $100 and the salesperson's last name is "Smith" (case-insensitive match). 73 | 74 | * Filter on dates and booleans: 75 | ``/invoices/?paid=false&bill_date__lt=2014-05-01`` 76 | 77 | 78 | Dates are formatted YYYY-MM-DD and booleans may be true, True, false, or False. 79 | 80 | * Use the 'ALL' keyword to return all fields in an object: 81 | ``/invoices/?fields=salesperson.ALL`` 82 | 83 | Lists the salesperson for each invoice; will display all salesperson fields 84 | EXCEPT commission_pct which is in the "hide" list in the API above. 85 | 86 | * Use order_by, including - sign for reverse: 87 | ``/invoices/?order_by=client.profile.last_name,-amount`` 88 | 89 | Returns invoices ordered by the associated client's last name, then by amount from highest to lowest.
90 | 91 | The total number of results for each query (before applying limit and offset) is returned in a custom header code: 92 | ``X-Total-Count: 2034`` 93 | 94 | 95 | Errors and warnings are also returned in a custom header code. Errors get status 400; warnings are status 200. 96 | ``X-Query_Error: error text`` 97 | ``X-Query_Warning: warning text`` 98 | 99 | Also supports format parameter, e.g. ?format=json 100 | 101 | POST requests 102 | ------------- 103 | DrillDownAPIView overrides the Django REST Framework's get() method. It does not affect post() and other methods 104 | at all, so your DrillDownAPIView class may include a standard Django REST Framework post() method. 105 | 106 | Solutions for Common Problems 107 | ----------------------------- 108 | * Access Control: 109 | In your API view, override the get() method and add your access control to it:: 110 | 111 | @method_decorator(accounting_permission_required) 112 | def get(self, request): 113 | return super(InvoiceList, self).get(request) 114 | 115 | 116 | * Custom Queries: 117 | Assume that invoices > $1000 require prior authorization, and you'd like to support that as a simple query: 118 | 119 | ``/invoices/?requires_authorization=True`` 120 | 121 | 1. Add your field to the ignore list in the API view: 122 | ``ignore = ['requires_authorization']`` 123 | 124 | 2. Add the logic for handling the new filter to ``get_base_query()`` in the API view:: 125 | 126 | def get_base_query(self): 127 | qs = Invoice.objects.all() 128 | if self.request.GET.get('requires_authorization'): 129 | requires_authorization = self.request.GET['requires_authorization'] 130 | if requires_authorization == 'True': 131 | qs = qs.filter(total__gt=1000) 132 | elif requires_authorization == 'False': 133 | qs = qs.exclude(total__gt=1000) 134 | return qs 135 | 136 | Now you can query for ``requires_authorization=True`` or ``requires_authorization=False``.
137 | -------------------------------------------------------------------------------- /rest_framework_drilldown/__init__.py: -------------------------------------------------------------------------------- 1 | from .views import DrillDownAPIView 2 | 3 | VERSION = (0, 1, 1) 4 | __version__ = VERSION # alias -------------------------------------------------------------------------------- /rest_framework_drilldown/models.py: -------------------------------------------------------------------------------- 1 | # need empty file to make tests work 2 | -------------------------------------------------------------------------------- /rest_framework_drilldown/tests.py: -------------------------------------------------------------------------------- 1 | from decimal import Decimal 2 | from django.db import models 3 | from django.conf import settings 4 | from django.test import TestCase 5 | from django.test.client import RequestFactory 6 | 7 | from .views import DrillDownAPIView 8 | 9 | 10 | # Create some database models 11 | class test_Profile(models.Model): 12 | first_name = models.CharField(max_length=20) 13 | last_name = models.CharField(max_length=20) 14 | spy_name = models.CharField(max_length=10) 15 | 16 | 17 | class test_Client(models.Model): 18 | wholesale = models.BooleanField(default=False) 19 | profile = models.ForeignKey(test_Profile) 20 | 21 | 22 | class test_Salesperson(models.Model): 23 | commission_pct = models.IntegerField(default=10) 24 | profile = models.ForeignKey(test_Profile) 25 | 26 | 27 | class test_Item(models.Model): 28 | description = models.TextField(max_length=100) 29 | price = models.DecimalField(decimal_places=2, max_digits=8, null=True) 30 | 31 | 32 | class test_Invoice(models.Model): 33 | client = models.ForeignKey(test_Client) 34 | salesperson = models.ForeignKey(test_Salesperson, null=True) 35 | items = models.ManyToManyField(test_Item, related_name='invoice') 36 | total = models.DecimalField(decimal_places=2, max_digits=8, 
default=Decimal('0')) 37 | 38 | def save(self, *args, **kwargs): 39 | if self.id: 40 | self.total = sum([i.price for i in self.items.all()]) 41 | super(test_Invoice, self).save(*args, **kwargs) 42 | 43 | 44 | # Build an API 45 | class DrilldownTestAPI(DrillDownAPIView): 46 | """A GET API for test_Invoice objects""" 47 | authentication_classes = () # turn off authentication for the test 48 | permission_classes = () 49 | 50 | model = test_Invoice 51 | drilldowns = ['client__profile', 'salesperson__profile', 'items'] # objects you are allowed to drill down in 52 | ignore = ['fakefield'] # fields to ignore in the request 53 | hide = ['salesperson__commission_pct'] # fields that you cannot access or use in filters 54 | 55 | def get_base_query(self): 56 | return test_Invoice.objects.all() 57 | 58 | 59 | # Build another, but one that's picky 60 | class PickyTestAPI(DrilldownTestAPI): 61 | """a subclass of DrilldownTestAPI that will fail with bad fields""" 62 | picky = True 63 | 64 | 65 | # Build another with a low limit 66 | class TwoItemMaxTestAPI(DrilldownTestAPI): 67 | """a subclass of DrilldownTestAPI that will fail with bad fields""" 68 | MAX_RESULTS = 2 69 | 70 | 71 | class DrilldownAPITest(TestCase): 72 | def setUp(self): 73 | mary_smith = test_Profile(last_name='Smith', first_name='Mary', spy_name='Mango') 74 | joe_dokes = test_Profile(last_name='Dokes', first_name='Joe', spy_name='Bravo') 75 | bob_dobbs = test_Profile(last_name='Dobbs', first_name='Bob', spy_name='Catgut') 76 | ann_ames = test_Profile(last_name='Ames', first_name='Ann', spy_name='Pegleg') 77 | for profile in [mary_smith, joe_dokes, bob_dobbs, ann_ames]: 78 | profile.save() 79 | 80 | client_mary = test_Client(wholesale=True, profile=mary_smith) 81 | client_joe = test_Client(wholesale=False, profile=joe_dokes) 82 | salesperson_bob = test_Salesperson(commission_pct=8, profile=ann_ames) 83 | salesperson_ann = test_Salesperson(profile=bob_dobbs) 84 | client_mary.save() 85 | client_joe.save() 86 | 
salesperson_bob.save() 87 | salesperson_ann.save() 88 | 89 | items = [('Tape', '2.20'), ('Dog bowl', '4.00'), ('Hat', '10.00'), ('Tire', '30.00'), ('Mouse', '20'), 90 | ('Sandwich', '5.50'), ('Audi', '10000'), ('Coffee', '1.50'), ('Chair', '75'), ('Eggs', '3.50'), 91 | ('Pencils', '0.05'), ('Spoon', '2.00'), ] 92 | for item in items: 93 | i = test_Item(description=item[0], price=Decimal(item[1])) 94 | i.save() 95 | 96 | def create_invoice(client, salesperson, items): 97 | i = test_Invoice(client=client, salesperson=salesperson) 98 | i.save() 99 | i.items.add(*items) 100 | i.save() 101 | 102 | items = test_Item.objects.all() 103 | create_invoice(client_mary, salesperson_bob, items[0:3]) 104 | create_invoice(client_mary, salesperson_bob, items[3:4]) 105 | create_invoice(client_joe, salesperson_ann, items[4:6]) 106 | create_invoice(client_joe, None, items[6:9]) 107 | create_invoice(client_joe, salesperson_ann, items[9:12]) 108 | 109 | self.factory = RequestFactory() 110 | 111 | def test_the_api(self): 112 | my_view = DrilldownTestAPI.as_view() 113 | picky_view = PickyTestAPI.as_view() 114 | two_item_view = TwoItemMaxTestAPI.as_view() 115 | 116 | # set debug true so that API will return X-Query-Count (number of queries run) 117 | saved_debug = settings.DEBUG 118 | settings.DEBUG = True 119 | 120 | def get_response(data): 121 | return my_view(self.factory.get('/url/', data, content_type='application/json')) 122 | 123 | def picky_response(data): 124 | return picky_view(self.factory.get('/url/', data, content_type='application/json')) 125 | 126 | def two_item_response(data): 127 | return two_item_view(self.factory.get('/url/', data, content_type='application/json')) 128 | 129 | # return all results 130 | response = get_response({}) 131 | self.assertEqual(len(response.data), 5) 132 | # test w MAX_RESULTS 133 | response = two_item_response({}) 134 | self.assertEqual(len(response.data), 2) 135 | self.assertTrue('hit global maximum' in response.get('X-Query_Warning')) 136 | 
137 | # a filter 138 | response = get_response({'salesperson.profile.first_name': 'Ann'}) 139 | self.assertEqual(len(response.data), 2) 140 | 141 | # isnull 142 | response = get_response({'salesperson__isnull': 'true'}) 143 | self.assertEqual(len(response.data), 1) 144 | self.assertEqual(int(response.get('X-Query-Count', 0)), 1) # should only need one query 145 | 146 | # fields 147 | response = get_response({ 148 | 'salesperson__isnull': 'false', 149 | 'fields': 'salesperson.profile.first_name,salesperson.profile.last_name' 150 | }) 151 | self.assertEqual(len(response.data[0]['salesperson']['profile']), 2) # each record contains only the two fields 152 | self.assertEqual(int(response.get('X-Query-Count', 0)), 1) # should only need one query 153 | 154 | # complicated 155 | response = get_response({ 156 | 'salesperson__isnull': 'false', 157 | 'fields': 'salesperson.profile.first_name,items.price,client.profile.first_name' 158 | }) 159 | self.assertIsInstance(response.data[0]['items'][0]['price'], Decimal) 160 | self.assertTrue(int(response.get('X-Query-Count', 0)) <= 2) # should only need 1-2 queries 161 | 162 | # on a manytomany field if you don't specify subfields, returns a flat list of ids 163 | response = get_response({'fields': 'items'}) 164 | # data should look something like this: [{'items': [1, 2]}], NOT [{'items': [{'id': 1}, {'id': 2}]}] 165 | self.assertTrue(type(response.data[0]['items'][0]) is int) 166 | 167 | # try with limit 168 | response = get_response({'limit': 1}) 169 | self.assertEqual(len(response.data), 1) 170 | self.assertEqual(int(response.get('X-Total-Count', 0)), 5) 171 | 172 | # try with offset 173 | response = get_response({'offset': 3}) 174 | self.assertEqual(len(response.data), 2) 175 | self.assertEqual(int(response.get('X-Total-Count', 0)), 5) 176 | 177 | # both, with arbitrary high limit 178 | response = get_response({'offset': 2, 'limit': 100}) 179 | self.assertEqual(len(response.data), 3) 180 | 
self.assertEqual(int(response.get('X-Total-Count', 0)), 5) 181 | 182 | # zero results 183 | response = get_response({'salesperson.profile.first_name': 'Fred'}) 184 | self.assertEqual(response.status_code, 200) # not an error 185 | self.assertEqual(len(response.data), 0) 186 | 187 | # a bad filter 188 | data = {'salesperson.profile.dog_name': 'Freddyboy', 'total__gt': 25} 189 | response = get_response(data) 190 | self.assertEqual(response.status_code, 200) # returns results 191 | self.assertEqual(len(response.data), 3) # ignores bad field, just returns list of matching invoices 192 | self.assertIsNone(response.get('X-Query_Error')) # no error 193 | self.assertTrue('dog_name' in response.get('X-Query_Warning')) # but returns warning 194 | # with picky option on 195 | response = picky_response(data) 196 | self.assertEqual(response.status_code, 400) # error; no results 197 | self.assertEqual(len(response.data), 0) 198 | self.assertTrue('dog_name' in response.get('X-Query_Error')) 199 | 200 | # 2 bad filters 201 | data = {'foo': '12', 'bar__lt': 11, 'total__gt': 25} 202 | response = get_response(data) 203 | self.assertEqual(response.status_code, 200) # returns results 204 | self.assertEqual(len(response.data), 3) # ignores bad field, just returns list of matching invoices 205 | self.assertIsNone(response.get('X-Query_Error')) # no error 206 | self.assertTrue('foo' in response.get('X-Query_Warning') and 'bar' in response.get('X-Query_Warning')) 207 | 208 | # a bad field 209 | data = {'fields': 'salesperson.profile.first_name,monkey'} 210 | response = get_response(data) 211 | self.assertEqual(response.status_code, 400) # error! 
212 | self.assertEqual(len(response.data), 0) 213 | 214 | # an ignore field, with picky option 215 | response = picky_response({'fakefield__lt': '3000', 'limit': 3}) 216 | self.assertEqual(response.status_code, 200) # no error, as 'fakefield' is in the ignore list 217 | self.assertIsNone(response.get('X-Query_Warning')) # no warning -- it's in the ignore list 218 | self.assertEqual(len(response.data), 3) 219 | 220 | # ALL selector 221 | response = get_response({'fields': 'client.profile.ALL'}) 222 | self.assertEqual(len(response.data[0]['client']['profile']), 4) # 4 fields including id 223 | 224 | # a hide field -- should not show up in results 225 | response = get_response({'salesperson__isnull': 'false', 'fields': 'salesperson.ALL'}) 226 | self.assertIsNone(response.data[0]['salesperson'].get('commission_pct')) # commission_pct is a hide field 227 | self.assertIsNotNone(response.data[0]['salesperson'].get('profile')) 228 | 229 | settings.DEBUG = saved_debug # revert settings 230 | -------------------------------------------------------------------------------- /rest_framework_drilldown/urls.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clearcare/django-rest-framework-drilldown/7e95c430adb902c673e9ca2a6f398635cce20601/rest_framework_drilldown/urls.py -------------------------------------------------------------------------------- /rest_framework_drilldown/views.py: -------------------------------------------------------------------------------- 1 | from django.conf import settings 2 | from django.db import connection 3 | from django.db.models.fields.related import ForeignKey, OneToOneField, ManyToManyField, ManyToOneRel, OneToOneRel 4 | from django.core.exceptions import FieldError 5 | from rest_framework import serializers 6 | from rest_framework.response import Response 7 | from rest_framework.views import APIView 8 | 9 | 10 | class DrillDownAPIView(APIView): 11 | """ 12 | Subclass this to 
create an instant GET API with fields, filters, etc. 13 | 14 | Supports the following GET request parameters (shown with examples): 15 | format: 16 | format=json 17 | 18 | limit and offset: 19 | limit=10&offset=60 20 | 21 | order_by: 22 | order_by=-client.profile.first_name < order by associated client's first name, in reverse order 23 | 24 | fields: 25 | fields=id,client.profile.first_name < returns the item id and associated client's first name 26 | fields=client < returns the associated client (will just be an id) 27 | fields=client.ALL < ALL keyword, will return the entire record (flat) for the associated client 28 | (Note that ability to drill down to a sub-model is constrained by drilldowns setting in the view.) 29 | 30 | other parameters are treated as filters: 31 | client.profile.first_name__istartswith=pete < associated client's name starts with 'pete' 32 | paid=true < paid is true (True/true and False/false will all work) 33 | amount__gt=100 < amount greater than 100 34 | (Multiple filters can be combined. Filterable objects are also constrained to those in the drilldowns list.) 
35 | 36 | Returns results with header codes: 37 | X-Total-Count: the total match count before applying limit or offset 38 | X-Query_Error: any errors, usually returned with status 400 39 | X-Query_Warning: warning, returned with status 200 40 | """ 41 | drilldowns = None # override this to allow drilldowns into sub-objects 42 | ignore = None # override this with fieldnames that should be ignored in the GET request 43 | hide = None 44 | model = None # override this with the model 45 | picky = False # if true, will error 400 if any bad fields are included 46 | MAX_RESULTS = 1000 # max result count; can override in your api 47 | 48 | def __init__(self, *args, **kwargs): 49 | self.error = '' 50 | self.warning = '' 51 | 52 | # deal with None's that should be arrays 53 | self.ignore = self.ignore or [] 54 | self.drilldowns = self.drilldowns or [] 55 | self.hide = self.hide or [] 56 | 57 | # These will go into the query 58 | self.select_relateds = [] 59 | self.prefetch_relateds = [] 60 | 61 | # Non-model fields that may be in the query 62 | self.hide_fields = [h.replace('__', '.') for h in self.hide] 63 | self.ignore_fields = set(['fields', 'limit', 'offset', 'format', 'order_by'] + self.ignore + self.hide_fields) 64 | 65 | super(DrillDownAPIView, self).__init__(*args, **kwargs) 66 | 67 | def get_base_query(self): # override this to return your base query 68 | return None 69 | 70 | def get(self, request): 71 | """The main method in this object; handles a GET request with filters, fields, etc.""" 72 | if settings.DEBUG: 73 | num_queries = len(connection.queries) # just for testing 74 | 75 | data = {} 76 | headers = {} 77 | 78 | def _result(status=200): 79 | if self.warning: 80 | headers['X-Query_Warning'] = self.warning 81 | return Response(data, 82 | headers=headers, 83 | status=status) 84 | 85 | def _error(msg='Error'): 86 | headers['X-Query_Error'] = msg 87 | return _result(status=400) 88 | 89 | fields = self.request.QUERY_PARAMS.get('fields', []) 90 | fields = fields 
and fields.split(',') # fields parameter is a comma-delimited list 91 | 92 | # get parameters that will be used as filters 93 | filters = {} 94 | request_params = self.request.QUERY_PARAMS 95 | for f in request_params: 96 | if f.split('__')[0] not in self.ignore_fields: # split so you catch things like "invoice.total__lt=10000" 97 | filters[f] = request_params[f] 98 | 99 | qs = self.get_base_query() 100 | if qs is None: 101 | return _error('API error: get_base_query() missing or invalid') 102 | 103 | # Get the complete set of drilldowns 104 | self.drilldowns = self._validate_drilldowns(self.drilldowns) 105 | if self.error: 106 | return _error(self.error) 107 | 108 | # Create the fields_map (a multi-level dictionary describing the fields to be returned) 109 | self.fields_map = self._create_fields_map(fields) 110 | if self.error: 111 | return _error(self.error) 112 | 113 | # Get filters, validate against drilldowns 114 | self.filter_kwargs = self._set_filter_kwargs(filters) 115 | if self.error: 116 | return _error(self.error) 117 | 118 | # Add our relateds to the query 119 | if self.select_relateds: 120 | qs = qs.select_related(*self.select_relateds) 121 | if self.prefetch_relateds: 122 | qs = qs.prefetch_related(*self.prefetch_relateds) 123 | 124 | # Add our filters to the query 125 | try: 126 | qs = qs.filter(**self.filter_kwargs) 127 | except FieldError: 128 | qs = qs.none() 129 | return _error('Bad filter parameter in query.') 130 | except ValueError: 131 | qs = qs.none() 132 | return _error('Bad filter value in query') 133 | queryset_for_count = qs # saving this off so we can use it later before we add limits 134 | 135 | # Deal with ordering 136 | order_by = self.request.QUERY_PARAMS.get('order_by', '').replace('.', '__') 137 | # would be nice to validate order_by fields here, but difficult, so we just trap below 138 | if order_by: 139 | order_by = order_by.split(',') 140 | qs = qs.order_by(*order_by) 141 | 142 | # Deal with offset and limit 143 | self.offset = 
int_or_none(self.request.QUERY_PARAMS.get('offset')) or 0 144 | self.limit = int_or_none(self.request.QUERY_PARAMS.get('limit')) or 0 145 | self.limit = min(self.MAX_RESULTS, self.limit or self.MAX_RESULTS) 146 | if self.limit and self.offset: 147 | qs = qs[self.offset:self.limit + self.offset] 148 | elif self.limit: 149 | qs = qs[:self.limit] 150 | elif self.offset: 151 | qs = qs[self.offset:] 152 | 153 | # create the chained serializer 154 | serializer = DrilldownSerializerFactory(self.model)(fields_map=self.fields_map, instance=qs, many=True) 155 | 156 | # return the response 157 | try: 158 | data = serializer.data 159 | except FieldError: 160 | return _error('Error: May be bad field name in order_by') # typical error 161 | 162 | # get total count if 1) your count = the limit, or 2) the query has an offset. 163 | if self.offset or (len(data) and len(data) == self.limit): 164 | total_count = queryset_for_count.count() 165 | if len(data) == self.MAX_RESULTS: 166 | self.warning += 'Number of results hit global maximum (%s results). ' % self.MAX_RESULTS 167 | else: 168 | total_count = len(data) 169 | headers = {'X-Total-Count': total_count} 170 | if settings.DEBUG: 171 | headers['X-Query-Count'] = len(connection.queries) - num_queries 172 | return _result() 173 | 174 | # Various Methods # 175 | # Validate the list of drilldowns and fill in any gaps; returns array of drilldowns 176 | def _validate_drilldowns(self, drilldowns): 177 | ERROR_STRING = 'Error in drilldowns' 178 | validated_drilldowns = [] 179 | 180 | def validate_me(current_model, dd_string, current_string=''): 181 | pair = dd_string.split('__', 1) 182 | fieldname = (pair[0]).strip() 183 | if not is_field_in(current_model, fieldname): 184 | self.error = ('%s: "%s" is not a valid field in %s. 
Remember __ not .)' % 185 | (ERROR_STRING, fieldname, current_model.__name__)) 186 | return None 187 | new_model = get_model(current_model, fieldname) 188 | if not new_model: 189 | self.error = ('%s: "%s" is not a ForeignKey, ManyToMany, OneToOne, ManyToOneRel, or OneToOneRel.' 190 | % (ERROR_STRING, fieldname)) 191 | return None 192 | current_string = (current_string + '__' + fieldname).strip('__') 193 | 194 | # note that we add missing intermediate models, e.g. 'client' if list included 'client__profile' 195 | if current_string not in validated_drilldowns: 196 | validated_drilldowns.append(current_string) 197 | # if there's more, keep drilling 198 | if len(pair) > 1: 199 | validate_me(new_model, pair[1], current_string) # recursion 200 | 201 | for dd in drilldowns: 202 | validate_me(self.model, dd) 203 | if ERROR_STRING in self.error: 204 | validated_drilldowns = [] 205 | return validated_drilldowns 206 | 207 | def _create_fields_map(self, fields): 208 | """Take the list of fields submitted in the query and turn it into a multi-level tree dict""" 209 | fields_map = {} 210 | ERROR_STRING = 'Error in fields' 211 | 212 | def add_to_fields_map(current_model, current_map, dot_string, current_related=''): 213 | pair = dot_string.split('.', 1) 214 | fieldname = (pair[0]).strip() 215 | there_are_subfields = len(pair) > 1 216 | if not (fieldname == 'ALL' or is_field_in(current_model, fieldname)): # ALL is allowed in fields_map 217 | self.error = ('%s: "%s" is not a valid field' % (ERROR_STRING, dot_string)) 218 | return None 219 | 220 | if fieldname == 'ALL': 221 | # add in all the fields for the model 222 | fname_prefix = current_related.replace('__', '.') + '.' 
223 | for fname in current_model._meta.get_all_field_names(): 224 | if (fname_prefix + fname).strip('.') in self.hide_fields: 225 | continue # skip it 226 | field_type = get_field_type(current_model, fname) 227 | # don't add the field if it's a related field and out of drilldowns range 228 | if field_type in [ManyToManyField]: 229 | temp = (current_related + '__' + fname).strip('__') 230 | if temp not in self.drilldowns: 231 | continue # don't add this one 232 | add_to_fields_map(current_model, current_map, dot_string=fname, current_related=current_related) 233 | else: 234 | # add it to the map 235 | if current_map.get(fieldname) is None: 236 | current_map[fieldname] = {} 237 | # drill down one level in the map 238 | current_map = current_map[fieldname] 239 | # see if the field is a related one 240 | new_model = get_model(current_model, fieldname) 241 | field_type = get_field_type(current_model, fieldname) 242 | if new_model and (field_type == ManyToManyField or there_are_subfields): 243 | # Add field to select_related or prefetch_relateds 244 | current_related = (current_related + '__' + fieldname).strip('__') 245 | if current_related in self.drilldowns: 246 | field_type = get_field_type(current_model, fieldname) 247 | if field_type in [ForeignKey, OneToOneField, ManyToOneRel, OneToOneRel]: 248 | self.select_relateds.append(current_related) 249 | else: 250 | self.prefetch_relateds.append(current_related) 251 | else: 252 | self.error = ('%s: %s is not valid' % (ERROR_STRING, current_related.replace('__', '.'))) 253 | return None 254 | 255 | # Add sub-field to fields_map 256 | if there_are_subfields: 257 | add_to_fields_map(new_model, current_map, pair[1], current_related) # recurse 258 | else: 259 | add_to_fields_map(new_model, current_map, 'id') # defaults to return the id only 260 | elif there_are_subfields: # requested a sub-field for a field that's not a model, e.g. 
amount.profile 261 | self.error = ('%s: %s not valid field' % (ERROR_STRING, dot_string)) 262 | return None 263 | 264 | for fieldname in fields: 265 | add_to_fields_map(self.model, fields_map, fieldname) 266 | 267 | if ERROR_STRING in self.error: 268 | fields_map = {} 269 | return fields_map 270 | 271 | def _set_relateds(self, fields_map): 272 | """Go through the fields_map and see what related objs should be added to the querystring""" 273 | def add_to_relateds(current_model, current_map, fieldname, current_string=''): 274 | # figure out if the field should be a prefetch or select, and add it. Also validate against drilldowns 275 | if current_map[fieldname]: # e.g. if there are sub-fields 276 | field_type = get_field_type(current_model, fieldname) 277 | current_string = (current_string + '__' + fieldname).strip('__') 278 | if field_type in [ForeignKey, OneToOneField, ManyToOneRel, OneToOneRel, ManyToManyField]: 279 | if not current_string in self.drilldowns: 280 | self.error = ('Error: %s not valid' % current_string.replace('__', '.')) 281 | return None 282 | if field_type in [ForeignKey, OneToOneField, ManyToOneRel, OneToOneRel]: 283 | self.select_relateds.append(current_string) 284 | else: 285 | self.prefetch_relateds.append(current_string) 286 | 287 | new_model = get_model(current_model, fieldname) 288 | for f in current_map.get(fieldname, {}): 289 | add_to_relateds(new_model, current_map[fieldname], f, current_string) # recursion 290 | 291 | for fieldname in fields_map: 292 | add_to_relateds(self.model, fields_map, fieldname) 293 | return True 294 | 295 | def _set_filter_kwargs(self, filters): 296 | """Create the kwargs to filter the querystring with""" 297 | filter_kwargs = {} 298 | for p in filters: 299 | pair = p.split('__') 300 | dot_string = pair[0] 301 | if len(pair) > 1: 302 | operation = '__' + pair[1] # 'operation' is something like '__gt', '__isnull', etc. 
303 | else: 304 | operation = '' 305 | 306 | def do_filter(dot_string, filter_string, current_model): 307 | """ 308 | Recursive function that takes 'invoice.client.last_name' 309 | and puts out a string like 'invoice__client__last_name' after validating that all the fields are 310 | valid and accessible to the user 311 | """ 312 | parts = dot_string.split('.', 1) 313 | fieldname = parts[0] 314 | filter_string = (filter_string + '__' + fieldname).strip('__') 315 | if len(parts) > 1: 316 | leftover = parts[1] 317 | else: 318 | leftover = '' 319 | 320 | if not is_field_in(current_model, fieldname): 321 | if self.picky: 322 | self.error = ('"%s" is not a valid filter' % fieldname) 323 | else: 324 | self.warning += '"%s" is not a valid parameter. ' % filter_string.replace('__', '.') 325 | return None 326 | 327 | if leftover: 328 | field_type = get_field_type(current_model, fieldname) 329 | if filter_string not in self.drilldowns: 330 | if self.picky: 331 | self.error = 'Error in filters: %s' % filter_string.replace('__', '.') 332 | else: 333 | self.warning += '"%s" is not a valid parameter. ' % filter_string.replace('__', '.') 334 | return None 335 | if field_type not in [ForeignKey, OneToOneField, ManyToOneRel, OneToOneRel, ManyToManyField]: 336 | if self.picky: 337 | self.error = ('Error: %s has no children' % filter_string) 338 | else: 339 | self.warning += '"%s" is not a valid parameter. ' % filter_string.replace('__', '.') 340 | return None 341 | 342 | # go to the related model 343 | current_model = get_model(current_model, fieldname) 344 | return do_filter(leftover, filter_string, current_model) # recursion 345 | else: 346 | return filter_string 347 | 348 | filter_string = do_filter(dot_string, '', self.model) 349 | 350 | if filter_string: 351 | # __in operator requires a list to work with so making it a special case for now. 
352 | # And adding it in list to support multiple operators that require similar logic 353 | if operation in ["__in"]: 354 | comma_separated_multiple_values = self.request.QUERY_PARAMS[p] 355 | filter_kwargs[filter_string + operation] = comma_separated_multiple_values.split(",") 356 | else: 357 | filter_kwargs[filter_string + operation] = self.request.QUERY_PARAMS[p] 358 | 359 | for k in filter_kwargs: 360 | if filter_kwargs[k] in ['true', 'True']: 361 | filter_kwargs[k] = True 362 | elif filter_kwargs[k] in ['false', 'False']: 363 | filter_kwargs[k] = False 364 | 365 | return filter_kwargs 366 | 367 | 368 | def DrilldownSerializerFactory(the_model): 369 | """Creates a generic model serializer with sub-serializers, based on the fields map """ 370 | class Serializer(serializers.ModelSerializer): 371 | class Meta: 372 | model = the_model 373 | 374 | def __init__(self, *args, **kwargs): 375 | # pull off the fields_map argument; don't pass to superclass 376 | if 'fields_map' in kwargs: 377 | fields_map = kwargs.pop('fields_map') 378 | fields_map = fields_map 379 | else: 380 | fields_map = {} 381 | 382 | super(Serializer, self).__init__(*args, **kwargs) 383 | 384 | if fields_map: 385 | # recurse through the fields dict, setting the fields list for each level and building sub-serializers 386 | def prune_fields(fields_map, model): 387 | # Set the list of fields for this serializer 388 | requested = list(fields_map) # flatten to get fields requested for this specific serializer model 389 | available = list(self.fields) # by default this is all fields for the model 390 | for field_name in set(available) - set(requested): 391 | self.fields.pop(field_name) # delete the ones we don't want from the serializer 392 | if field_name in fields_map: # and from fields_map 393 | del fields_map[field_name] 394 | # Attach sub-serializers for relationship fields 395 | for field_name in fields_map: 396 | sub_fm = fields_map[field_name] 397 | if sub_fm and sub_fm != {'id': {}}: # only do 
this for fields with sub-fields requested 398 | ftype = get_field_type(model, field_name) 399 | if ftype in [ForeignKey, OneToOneField, ManyToOneRel, OneToOneRel, ManyToManyField]: 400 | m = get_model(model, field_name) 401 | self.fields[field_name] = DrilldownSerializerFactory(m)( 402 | fields_map=fields_map[field_name]) # recursively create another serializer 403 | 404 | prune_fields(fields_map=fields_map, model=self.Meta.model) 405 | else: 406 | # if no fields specified, return ids only 407 | for field_name in set(self.fields): 408 | if field_name != 'id': 409 | self.fields.pop(field_name) 410 | return Serializer 411 | 412 | 413 | # Some utilities 414 | def get_model(parent_model, fieldname): 415 | """Get the model of a foreignkey, manytomany, etc. field""" 416 | field_class = parent_model._meta.get_field(fieldname) 417 | field_type = type(field_class) 418 | if field_type in [ForeignKey, ManyToManyField, OneToOneField]: 419 | model = parent_model._meta.get_field(fieldname).rel.to 420 | elif field_type == ManyToOneRel: 421 | model = parent_model._meta.get_field_by_name(fieldname)[0].model 422 | elif field_type == OneToOneRel: 423 | model = field_class.related_model 424 | else: 425 | model = None 426 | return model 427 | 428 | 429 | def get_field_type(model, fieldname): 430 | """Get the type of a field in a model""" 431 | return type(model._meta.get_field_by_name(fieldname)[0]) 432 | 433 | 434 | def is_field_in(model, fieldname): 435 | """Return true if fieldname is a field or relatedobject in model""" 436 | fieldnames = model._meta.get_all_field_names() 437 | return fieldname in fieldnames 438 | 439 | 440 | def int_or_none(value): 441 | """Convenience method to return None if int fails""" 442 | try: 443 | result = int(value) 444 | except (ValueError, TypeError): 445 | result = None 446 | return result 447 | -------------------------------------------------------------------------------- /setup.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | 4 | README = open(os.path.join(os.path.dirname(__file__), 'README.rst')).read() 5 | 6 | # allow setup.py to be run from any path 7 | os.chdir(os.path.normpath(os.path.join(os.path.abspath(__file__), os.pardir))) 8 | 9 | setup( 10 | name='djangorestframework-drilldown', 11 | version='0.1.1', 12 | url='http://github.com/peterh32/django-rest-framework-drilldown', 13 | license='MIT', 14 | packages=['rest_framework_drilldown'], 15 | include_package_data=True, 16 | description='Django REST API extension enables chained relations, filters, field selectors, limit, offset, etc., via a single view.', 17 | long_description=README, 18 | author='Peter Hollingsworth', 19 | author_email='peter@hollingsworth.net', 20 | install_requires=['djangorestframework<3.9.4'], 21 | ) 22 | --------------------------------------------------------------------------------