├── api ├── __init__.py ├── migrations │ ├── __init__.py │ └── 0001_initial.py ├── admin.py ├── tests.py ├── apps.py ├── models.py ├── filters.py └── views.py ├── web ├── __init__.py ├── templates │ └── list.html ├── migrations │ └── __init__.py ├── admin.py ├── apps.py ├── models.py ├── views.py └── tests.py ├── django_rest_framework_queryset ├── __init__.py ├── wsgi.py ├── urls.py └── settings.py ├── requirements ├── base.txt └── test.txt ├── requirements.txt ├── setup.cfg ├── MANIFEST.in ├── rest_framework_queryset ├── __init__.py ├── views.py ├── pagination.py └── queryset.py ├── tox.ini ├── .travis.yml ├── manage.py ├── LICENSE ├── .gitignore ├── README.md ├── setup.py └── tests.py /api/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /web/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /web/templates/list.html: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /api/migrations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /web/migrations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /django_rest_framework_queryset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements/base.txt: -------------------------------------------------------------------------------- 1 | requests>=2.19 2 | 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements/base.txt 2 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_wheel] 2 | universal = 1 3 | -------------------------------------------------------------------------------- /api/admin.py: -------------------------------------------------------------------------------- 1 | from django.contrib import admin 2 | 3 | # Register your models here. 4 | -------------------------------------------------------------------------------- /api/tests.py: -------------------------------------------------------------------------------- 1 | from django.test import TestCase 2 | 3 | # Create your tests here. 4 | -------------------------------------------------------------------------------- /web/admin.py: -------------------------------------------------------------------------------- 1 | from django.contrib import admin 2 | 3 | # Register your models here. 
4 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | include requirements.txt 4 | include requirements/* 5 | -------------------------------------------------------------------------------- /api/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | 4 | class ApiConfig(AppConfig): 5 | name = 'api' 6 | -------------------------------------------------------------------------------- /web/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | 4 | class WebConfig(AppConfig): 5 | name = 'web' 6 | -------------------------------------------------------------------------------- /requirements/test.txt: -------------------------------------------------------------------------------- 1 | -r base.txt 2 | Django~=2.1 3 | djangorestframework~=3.8.0 4 | django-filter~=2.0.0 5 | mock==2.0.0 6 | -------------------------------------------------------------------------------- /web/models.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | 3 | from django.db import models 4 | 5 | # Create your models here. 
6 | -------------------------------------------------------------------------------- /api/models.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | 3 | from django.db import models 4 | 5 | 6 | class DataModel(models.Model): 7 | value = models.IntegerField() 8 | -------------------------------------------------------------------------------- /rest_framework_queryset/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import unicode_literals 3 | from __future__ import absolute_import 4 | 5 | __version__ = '0.3.4' 6 | 7 | from .queryset import RestFrameworkQuerySet 8 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | # requirements/test.txt pins Django 2.1, which only supports Python 3 (matches .travis.yml) 3 | envlist = py36 4 | skipsdist = True 5 | [testenv] 6 | # manage.py and the project settings live at the repository root 7 | changedir = {toxinidir} 8 | commands = python {toxinidir}/manage.py test 9 | deps = -r{toxinidir}/requirements/test.txt 10 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.6" 4 | env: 5 | - DJANGO_VERSION=2.1 6 | # command to install dependencies 7 | install: 8 | - pip install -r requirements/test.txt 9 | - pip install -I Django==$DJANGO_VERSION 10 | # command to run tests 11 | script: 12 | - cd $TRAVIS_BUILD_DIR && ./manage.py test -------------------------------------------------------------------------------- /api/filters.py: -------------------------------------------------------------------------------- 1 | from .models import DataModel 2 | import django_filters 3 | 4 | 5 | class DataModelFilter(django_filters.FilterSet): 6 | class Meta: 7 | model = DataModel 8 | fields = { 9 | 'id':
['exact'], 10 | 'value': ['exact', 'gt', 'gte', 'lt', 'lte'] 11 | } 12 | -------------------------------------------------------------------------------- /manage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import sys 4 | 5 | if __name__ == "__main__": 6 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "django_rest_framework_queryset.settings") 7 | 8 | from django.core.management import execute_from_command_line 9 | 10 | execute_from_command_line(sys.argv) 11 | -------------------------------------------------------------------------------- /web/views.py: -------------------------------------------------------------------------------- 1 | from django.shortcuts import render 2 | from django.views.generic import ListView 3 | from rest_framework_queryset import RestFrameworkQuerySet 4 | 5 | 6 | class ListDataView(ListView): 7 | paginate_by = 10 8 | template_name = 'list.html' 9 | 10 | def get_queryset(self, *args, **kwargs): 11 | return RestFrameworkQuerySet('{}/api/'.format(self.request.META['SERVER_URL'])).filter(**self.request.GET.dict()) 12 | -------------------------------------------------------------------------------- /django_rest_framework_queryset/wsgi.py: -------------------------------------------------------------------------------- 1 | """ 2 | WSGI config for rest_framework_queryset project. 3 | 4 | It exposes the WSGI callable as a module-level variable named ``application``. 
5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/1.9/howto/deployment/wsgi/ 8 | """ 9 | 10 | import os 11 | 12 | from django.core.wsgi import get_wsgi_application 13 | 14 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "rest_framework_queryset.settings") 15 | 16 | application = get_wsgi_application() 17 | -------------------------------------------------------------------------------- /api/views.py: -------------------------------------------------------------------------------- 1 | from rest_framework import generics 2 | from rest_framework import serializers 3 | from rest_framework import viewsets 4 | from .models import DataModel 5 | from .filters import DataModelFilter 6 | 7 | 8 | class DataModelSerializer(serializers.ModelSerializer): 9 | class Meta(object): 10 | model = DataModel 11 | fields = '__all__' 12 | 13 | 14 | class ListView(viewsets.ModelViewSet): 15 | serializer_class = DataModelSerializer 16 | queryset = DataModel.objects.all() 17 | filter_class = DataModelFilter 18 | -------------------------------------------------------------------------------- /api/migrations/0001_initial.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by Django 1.9 on 2016-06-22 22:34 3 | from __future__ import unicode_literals 4 | 5 | from django.db import migrations, models 6 | 7 | 8 | class Migration(migrations.Migration): 9 | 10 | initial = True 11 | 12 | dependencies = [ 13 | ] 14 | 15 | operations = [ 16 | migrations.CreateModel( 17 | name='DataModel', 18 | fields=[ 19 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 20 | ('value', models.IntegerField()), 21 | ], 22 | ), 23 | ] 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 
James Lin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /rest_framework_queryset/views.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import unicode_literals 4 | from __future__ import absolute_import 5 | 6 | 7 | class APISearchableMixin(object): 8 | # request fields that we are interested in 9 | search_fields = [] 10 | 11 | def __init__(self, *args, **kwargs): 12 | super(APISearchableMixin, self).__init__(*args, **kwargs) 13 | self._search_params = {} 14 | 15 | def get_context_data(self, *args, **kwargs): 16 | """ 17 | put search_fields into ctx 18 | """ 19 | ctx = super(APISearchableMixin, self).get_context_data(*args, **kwargs) 20 | ctx['search_fields'] = self.get_search_params() 21 | return ctx 22 | 23 | def get_search_params(self): 24 | for field in self.search_fields: 25 | if self.request.method == 'POST': 26 | self._search_params[field] = self.request.POST.get('__search_{}'.format(field), '') 27 | if self.request.method == 'GET': 28 | self._search_params[field] = self.request.GET.get(field, '') 29 | return self._search_params 30 | 31 | def post(self, request, *args, **kwargs): 32 | return self.get(request, *args, **kwargs) 33 | -------------------------------------------------------------------------------- /web/tests.py: -------------------------------------------------------------------------------- 1 | from django.test import LiveServerTestCase 2 | from api.models import DataModel 3 | 4 | 5 | class ListTestCase(LiveServerTestCase): 6 | def setUp(self): 7 | # create data 8 | for i in range(100): 9 | DataModel.objects.create(value=i) 10 | 11 | def test_list(self): 12 | resp = self.client.get('/list/', SERVER_URL=self.live_server_url) 13 | self.assertEqual([v['value'] for v in resp.context['object_list']], [i for i in range(10)]) 14 | resp = self.client.get('/list/page/2/', SERVER_URL=self.live_server_url) 15 | self.assertEqual([v['value'] 
for v in resp.context['object_list']], [i for i in range(10, 20)]) 16 | 17 | def test_filter(self): 18 | resp = self.client.get('/list/?value=10', SERVER_URL=self.live_server_url) 19 | self.assertEqual(len(resp.context['object_list']), 1) 20 | self.assertEqual(resp.context['object_list'][0]['value'], 10) 21 | 22 | def test_get(self): 23 | dm = DataModel.objects.first() 24 | resp = self.client.get('/list/?id={}'.format(dm.id), SERVER_URL=self.live_server_url) 25 | self.assertEqual(len(resp.context['object_list']), 1) 26 | self.assertEqual(resp.context['object_list'][0]['value'], dm.value) 27 | -------------------------------------------------------------------------------- /django_rest_framework_queryset/urls.py: -------------------------------------------------------------------------------- 1 | """rest_framework_queryset URL Configuration 2 | 3 | The `urlpatterns` list routes URLs to views. For more information please see: 4 | https://docs.djangoproject.com/en/1.9/topics/http/urls/ 5 | Examples: 6 | Function views 7 | 1. Add an import: from my_app import views 8 | 2. Add a URL to urlpatterns: url(r'^$', views.home, name='home') 9 | Class-based views 10 | 1. Add an import: from other_app.views import Home 11 | 2. Add a URL to urlpatterns: url(r'^$', Home.as_view(), name='home') 12 | Including another URLconf 13 | 1. Add an import: from blog import urls as blog_urls 14 | 2. Import the include() function: from django.conf.urls import url, include 15 | 3. 
Add a URL to urlpatterns: url(r'^blog/', include(blog_urls)) 16 | """ 17 | from django.conf.urls import url 18 | from django.contrib import admin 19 | from rest_framework import routers 20 | from api.views import ListView 21 | from web.views import ListDataView 22 | 23 | router = routers.SimpleRouter() 24 | router.register(r'api', ListView) 25 | 26 | urlpatterns = [ 27 | url(r'^admin/', admin.site.urls), 28 | # url(r'^api/', ListView.as_view()), 29 | url(r'^list/$', ListDataView.as_view()), 30 | url(r'^list/page/(?P\d+)/$', ListDataView.as_view()), 31 | ] + router.urls 32 | -------------------------------------------------------------------------------- /rest_framework_queryset/pagination.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import unicode_literals 3 | from rest_framework.pagination import PageNumberPagination, LimitOffsetPagination 4 | 5 | 6 | class HybridPagination(PageNumberPagination): 7 | """ 8 | Basically allows both pagination method to work within a single pagination class. 
9 | By default it uses the PageNumberPagination 10 | When 'offset' is used in request.GET, it will switch to use LimitOffsetPagination 11 | """ 12 | page_size = 10 13 | default_limit = 1 14 | 15 | def __init__(self, *args, **kwargs): 16 | self.proxy = None 17 | return super(HybridPagination, self).__init__(*args, **kwargs) 18 | 19 | def paginate_queryset(self, queryset, request, view=None): 20 | if 'offset' in request.GET or 'limit' in request.GET: 21 | self.proxy = LimitOffsetPagination() 22 | return self.proxy.paginate_queryset(queryset, request, view) 23 | return super(HybridPagination, self).paginate_queryset(queryset, request, view) 24 | 25 | def __getattribute__(self, item): 26 | if item in ['paginate_queryset']: 27 | return object.__getattribute__(self, item) 28 | try: 29 | proxy = object.__getattribute__(self, "proxy") 30 | return object.__getattribute__(proxy, item) 31 | except AttributeError: 32 | return object.__getattribute__(self, item) 33 | 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | 91 | db.sqlite3 92 | .idea 93 | .idea/*/** -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/variable/django-rest-framework-queryset.svg?branch=master)](https://travis-ci.org/variable/django-rest-framework-queryset) 2 | # Django Rest Framework QuerySet 3 | Mimicking the Django ORM queryset over rest framework api, which does lazy loading. 
4 | 5 | ## Usage: 6 | 7 | ### normal operation 8 | ```python 9 | from rest_framework_queryset import RestFrameworkQuerySet 10 | from django.core.paginator import Paginator 11 | 12 | qs = RestFrameworkQuerySet('http://localhost:8082/api/') 13 | 14 | # filter 15 | boys = qs.filter(gender='boy') 16 | girls = qs.filter(gender='girl') 17 | 18 | # get by id 19 | boy = qs.get(101) 20 | 21 | # filter enforce 1 result 22 | boy = qs.get(name='james', gender='boy') 23 | 24 | # slicing 25 | first_100_boys = boys[:100] 26 | 27 | # iterate all records 28 | for i in qs: 29 | print(i) 30 | 31 | # pagination 32 | p = Paginator(qs, 10) 33 | print(p.count) 34 | print(p.num_pages) 35 | page1 = p.page(1) 36 | ``` 37 | 38 | ### class based view 39 | ```python 40 | from django.views.generic import ListView 41 | from rest_framework_queryset import RestFrameworkQuerySet 42 | 43 | class ListDataView(ListView): 44 | paginate_by = 10 45 | template_name = 'list.html' 46 | 47 | def get_queryset(self, *args, **kwargs): 48 | return RestFrameworkQuerySet('http://localhost:8082/api/').filter(**self.request.GET.dict()) 49 | ``` 50 | 51 | ## Dependencies 52 | The queryset requires the target API to use [LimitOffsetPagination](http://www.django-rest-framework.org/api-guide/pagination/#limitoffsetpagination). 53 | 54 | This project provides a `HybridPagination` class that switches to `LimitOffsetPagination` whenever it sees a `limit` or `offset` query param, 55 | so if you currently use `PageNumberPagination` you can swap it 56 | for `rest_framework_queryset.pagination.HybridPagination` to support both styles. This feature is experimental, so please report any problems.
57 | 58 | ## Compatibility 59 | - Python 2 60 | - Python 3 61 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import codecs 3 | import os 4 | import re 5 | 6 | from setuptools import setup, find_packages 7 | 8 | 9 | def read(*parts): 10 | filename = os.path.join(os.path.dirname(__file__), *parts) 11 | with codecs.open(filename, encoding='utf-8') as fp: 12 | return fp.read() 13 | 14 | 15 | def find_version(*file_paths): 16 | version_file = read(*file_paths) 17 | version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M) 18 | if version_match: 19 | return version_match.group(1) 20 | raise RuntimeError("Unable to find version string.") 21 | 22 | 23 | def get_packages(package): 24 | """ 25 | Return root package and all sub-packages. 26 | """ 27 | return [dirpath 28 | for dirpath, dirnames, filenames in os.walk(package) 29 | if os.path.exists(os.path.join(dirpath, '__init__.py'))] 30 | 31 | 32 | def get_package_data(package): 33 | """ 34 | Return all files under the root package, that are not in a 35 | package themselves. 
36 | """ 37 | walk = [(dirpath.replace(package + os.sep, '', 1), filenames) 38 | for dirpath, dirnames, filenames in os.walk(package) 39 | if not os.path.exists(os.path.join(dirpath, '__init__.py'))] 40 | 41 | filepaths = [] 42 | for base, filenames in walk: 43 | filepaths.extend([os.path.join(base, filename) 44 | for filename in filenames]) 45 | return {package: filepaths} 46 | 47 | 48 | setup( 49 | name='django-rest-framework-queryset', 50 | version=find_version('rest_framework_queryset/__init__.py'), 51 | author='James Lin', 52 | author_email='james@lin.net.nz', 53 | long_description='', 54 | install_requires=['requests>=2.1'], 55 | packages=find_packages(exclude=["tests", "api", "api.*", "web", "web.*", "django_rest_framework_queryset"]), 56 | # packages=get_packages('rest_framework_queryset'), 57 | # package_data=get_package_data('rest_framework_queryset'), 58 | license='MIT', 59 | description="Mimicking the Django ORM queryset over rest framework api", 60 | classifiers=[ 61 | "License :: OSI Approved :: MIT License", 62 | "Intended Audience :: Developers", 63 | "Programming Language :: Python", 64 | "Programming Language :: Python :: 2.7", 65 | "Programming Language :: Python :: 3.6", 66 | "Framework :: Django", 67 | ], 68 | ) 69 | -------------------------------------------------------------------------------- /rest_framework_queryset/queryset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import unicode_literals 3 | from django.core.exceptions import MultipleObjectsReturned 4 | from django.core.paginator import Paginator 5 | import requests 6 | import copy 7 | 8 | 9 | class BaseAPIQuerySet(object): 10 | def __init__(self, url, *args, **kwargs): 11 | _req_session = requests.Session() 12 | self.request_method = _req_session.get 13 | self.url = url 14 | self.args = args 15 | self.kwargs = kwargs 16 | 17 | def _call_api(self): 18 | """ 19 | perform api call 20 | """ 21 | 
return self.request_method(self.url, *self.args, **self.kwargs) 22 | 23 | def __iter__(self): 24 | return iter(self.get_result()) 25 | 26 | def __len__(self): 27 | return self.count() 28 | 29 | def __getitem__(self, index): 30 | if isinstance(index, int): 31 | return self.get_result()[index] 32 | elif isinstance(index, slice): 33 | return self.page_result(index) 34 | 35 | def _clone(self): 36 | kwargs = copy.deepcopy(self.kwargs) 37 | args = copy.deepcopy(self.args) 38 | url = copy.copy(self.url) 39 | cloned = self.__class__(url, *args, **kwargs) 40 | cloned.request_method = self.request_method 41 | return cloned 42 | 43 | def count(self): 44 | raise NotImplementedError() 45 | 46 | def filter(self, **kargs): 47 | raise NotImplementedError() 48 | 49 | def all(self, **kargs): 50 | raise NotImplementedError() 51 | 52 | def get_result(self): 53 | raise NotImplementedError() 54 | 55 | def page_result(self): 56 | raise NotImplementedError() 57 | 58 | 59 | class RestFrameworkQuerySet(BaseAPIQuerySet): 60 | page_size = 100 61 | 62 | def __init__(self, *args, **kwargs): 63 | super(RestFrameworkQuerySet, self).__init__(*args, **kwargs) 64 | self.__id = None 65 | 66 | def __iter__(self): 67 | paginator = Paginator(self, self.page_size) 68 | for page in paginator.page_range: 69 | for item in paginator.page(page).object_list: 70 | yield item 71 | 72 | def _clone(self): 73 | cloned = super(RestFrameworkQuerySet, self)._clone() 74 | cloned.page_size = self.page_size 75 | return cloned 76 | 77 | def _call_api(self): 78 | if self.__id: 79 | self.url = '{}/{}/'.format(self.url.rstrip('/'), self.__id) 80 | return super(RestFrameworkQuerySet, self)._call_api() 81 | 82 | def count(self): 83 | cloned = self._clone() 84 | params = cloned.kwargs.get('params', {}) 85 | params['offset'] = 0 86 | params['limit'] = 0 87 | resp = cloned._call_api() 88 | result = resp.json() 89 | return result['count'] 90 | 91 | def get_result(self): 92 | response = self._call_api() 93 | result = 
response.json() 94 | if 'results' in result: 95 | return result['results'] 96 | return result 97 | 98 | def page_result(self, slicer): 99 | cloned = self._clone() 100 | params = cloned.kwargs.setdefault('params', {}) 101 | params['offset'] = slicer.start 102 | params['limit'] = slicer.stop - slicer.start 103 | return cloned.get_result() 104 | 105 | def filter(self, **kwargs): 106 | cloned = self._clone() 107 | params = cloned.kwargs.setdefault('params', {}) 108 | params.update(kwargs) 109 | return cloned 110 | 111 | def get(self, __id=None, **kwargs): 112 | cloned = self.filter(**kwargs) 113 | cloned.__id = __id 114 | result = cloned.get_result() 115 | if isinstance(result, list): 116 | if len(result) > 1: 117 | raise MultipleObjectsReturned('get() returned more than one result, it returned {}'.format(cloned.count())) 118 | return result[0] 119 | return result 120 | 121 | def all(self): 122 | return self._clone() 123 | -------------------------------------------------------------------------------- /django_rest_framework_queryset/settings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Django settings for rest_framework_queryset project. 3 | 4 | Generated by 'django-admin startproject' using Django 1.9. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/1.9/topics/settings/ 8 | 9 | For the full list of settings and their values, see 10 | https://docs.djangoproject.com/en/1.9/ref/settings/ 11 | """ 12 | 13 | import os 14 | os.environ['DJANGO_LIVE_TEST_SERVER_ADDRESS'] = 'localhost:8082' 15 | 16 | # Build paths inside the project like this: os.path.join(BASE_DIR, ...) 17 | BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 18 | 19 | 20 | # Quick-start development settings - unsuitable for production 21 | # See https://docs.djangoproject.com/en/1.9/howto/deployment/checklist/ 22 | 23 | # SECURITY WARNING: keep the secret key used in production secret! 
24 | SECRET_KEY = '=!eduxyj9rn5nvmnxhvbjsy&0ir-clss&mt=*9)m6^_k2!&b*1' 25 | 26 | # SECURITY WARNING: don't run with debug turned on in production! 27 | DEBUG = True 28 | 29 | ALLOWED_HOSTS = [] 30 | 31 | 32 | # Application definition 33 | 34 | INSTALLED_APPS = [ 35 | 'django.contrib.admin', 36 | 'django.contrib.auth', 37 | 'django.contrib.contenttypes', 38 | 'django.contrib.sessions', 39 | 'django.contrib.messages', 40 | 'django.contrib.staticfiles', 41 | 'rest_framework', 42 | 'api', 43 | 'web', 44 | 'django_filters', 45 | ] 46 | 47 | MIDDLEWARE_CLASSES = [ 48 | 'django.middleware.security.SecurityMiddleware', 49 | 'django.contrib.sessions.middleware.SessionMiddleware', 50 | 'django.middleware.common.CommonMiddleware', 51 | 'django.middleware.csrf.CsrfViewMiddleware', 52 | 'django.contrib.auth.middleware.AuthenticationMiddleware', 53 | 'django.contrib.auth.middleware.SessionAuthenticationMiddleware', 54 | 'django.contrib.messages.middleware.MessageMiddleware', 55 | 'django.middleware.clickjacking.XFrameOptionsMiddleware', 56 | ] 57 | 58 | ROOT_URLCONF = 'django_rest_framework_queryset.urls' 59 | 60 | TEMPLATES = [ 61 | { 62 | 'BACKEND': 'django.template.backends.django.DjangoTemplates', 63 | 'DIRS': [], 64 | 'APP_DIRS': True, 65 | 'OPTIONS': { 66 | 'context_processors': [ 67 | 'django.template.context_processors.debug', 68 | 'django.template.context_processors.request', 69 | 'django.contrib.auth.context_processors.auth', 70 | 'django.contrib.messages.context_processors.messages', 71 | ], 72 | }, 73 | }, 74 | ] 75 | 76 | WSGI_APPLICATION = 'django_rest_framework_queryset.wsgi.application' 77 | 78 | 79 | # Database 80 | # https://docs.djangoproject.com/en/1.9/ref/settings/#databases 81 | 82 | DATABASES = { 83 | 'default': { 84 | 'ENGINE': 'django.db.backends.sqlite3', 85 | 'NAME': os.path.join(BASE_DIR, 'db.sqlite3'), 86 | } 87 | } 88 | 89 | 90 | # Password validation 91 | # https://docs.djangoproject.com/en/1.9/ref/settings/#auth-password-validators 92 | 93 | 
AUTH_PASSWORD_VALIDATORS = [ 94 | { 95 | 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', 96 | }, 97 | { 98 | 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', 99 | }, 100 | { 101 | 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', 102 | }, 103 | { 104 | 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', 105 | }, 106 | ] 107 | 108 | 109 | # Internationalization 110 | # https://docs.djangoproject.com/en/1.9/topics/i18n/ 111 | 112 | LANGUAGE_CODE = 'en-us' 113 | 114 | TIME_ZONE = 'UTC' 115 | 116 | USE_I18N = True 117 | 118 | USE_L10N = True 119 | 120 | USE_TZ = True 121 | 122 | 123 | # Static files (CSS, JavaScript, Images) 124 | # https://docs.djangoproject.com/en/1.9/howto/static-files/ 125 | 126 | STATIC_URL = '/static/' 127 | 128 | REST_FRAMEWORK = { 129 | 'DEFAULT_FILTER_BACKENDS': ('django_filters.rest_framework.DjangoFilterBackend',), 130 | 'DEFAULT_PAGINATION_CLASS': 'rest_framework_queryset.pagination.HybridPagination', 131 | 'PAGE_SIZE': 30 132 | } 133 | -------------------------------------------------------------------------------- /tests.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import unicode_literals 3 | from __future__ import absolute_import 4 | from django.test import TestCase, LiveServerTestCase 5 | from django.core.paginator import Paginator 6 | from django.core.exceptions import MultipleObjectsReturned 7 | from api.models import DataModel 8 | from rest_framework_queryset import RestFrameworkQuerySet 9 | from mock import patch, MagicMock 10 | 11 | 12 | class RestFrameworkQuerySetTestCase(TestCase): 13 | 14 | def test_filter(self): 15 | qs = RestFrameworkQuerySet('/api/') 16 | qs1 = qs.filter(a=123) 17 | self.assertTrue('params' not in qs.kwargs, "qs should not have params set as it's cloned in filter()") 18 | self.assertEqual(qs1.kwargs['params']['a'], 
123) 19 | 20 | def test_filter_chain(self): 21 | qs = RestFrameworkQuerySet('/api/') 22 | qs1 = qs.filter(a=123) 23 | self.assertEqual(qs1.kwargs['params']['a'], 123) 24 | qs2 = qs1.filter(b=234) 25 | self.assertEqual(qs2.kwargs['params']['b'], 234) 26 | 27 | qs = RestFrameworkQuerySet('/api/') 28 | qs1 = qs.filter(a=123).filter(b=234) 29 | self.assertEqual(qs1.kwargs['params']['a'], 123) 30 | self.assertEqual(qs1.kwargs['params']['b'], 234) 31 | 32 | def test_filter_call(self): 33 | fake_response = MagicMock(json=lambda:{'count': 10, 'results': range(10)}) 34 | with patch('rest_framework_queryset.queryset.requests.Session.get', return_value=fake_response) as mock_get: 35 | qs = RestFrameworkQuerySet('/api/') 36 | qs1 = qs.filter(a=123) 37 | self.assertEqual(list(qs1), list(range(10))) 38 | mock_get.assert_any_call('/api/', params={'a': 123, 'offset': 0, 'limit': 10}) 39 | qs2 = qs1.filter(b=234) 40 | list(qs2) # execute 41 | mock_get.assert_any_call('/api/', params={'a': 123, 'b': 234, 'offset': 0, 'limit': 10}) 42 | 43 | def test_get_call(self): 44 | fake_response = MagicMock(json=lambda:{'count': 10, 'results': list(range(10))}) 45 | with patch('rest_framework_queryset.queryset.requests.Session.get', return_value=fake_response) as mock_get: 46 | qs = RestFrameworkQuerySet('/api/') 47 | with self.assertRaises(MultipleObjectsReturned): 48 | qs1 = qs.get(a=123) 49 | self.assertEqual(list(qs), list(range(10))) 50 | mock_get.assert_any_call('/api/', params={'a': 123}) 51 | 52 | def test_get_call_by_id(self): 53 | fake_response = MagicMock(json=lambda:{'a': 123}) 54 | with patch('rest_framework_queryset.queryset.requests.Session.get', return_value=fake_response) as mock_get: 55 | qs = RestFrameworkQuerySet('/api/') 56 | qs1 = qs.get(123) 57 | self.assertEqual(qs1, {'a': 123}) 58 | mock_get.assert_any_call('/api/123/', params={}) 59 | 60 | def test_count_call(self): 61 | fake_response = MagicMock(json=lambda:{'count': 10, 'results': range(10)}) 62 | with 
patch('rest_framework_queryset.queryset.requests.Session.get', return_value=fake_response) as mock_get: 63 | count = RestFrameworkQuerySet('/api/').count() 64 | self.assertEqual(count, 10) 65 | 66 | def test_all(self): 67 | fake_response = MagicMock(json=lambda:{'count': 10, 'results': range(10)}) 68 | with patch('rest_framework_queryset.queryset.requests.Session.get', return_value=fake_response) as mock_get: 69 | qs = RestFrameworkQuerySet('/api/').all() 70 | self.assertEqual(list(qs), list(range(10))) 71 | 72 | 73 | class APILiveServerTestCase(LiveServerTestCase): 74 | def test_pagination(self): 75 | for i in range(100): 76 | DataModel.objects.create(value=i) 77 | qs = RestFrameworkQuerySet('{}/api/'.format(self.live_server_url)) 78 | p = Paginator(qs, 10) 79 | self.assertEqual(p.count, 100) 80 | self.assertEqual(p.num_pages, 10) 81 | page2 = p.page(2) 82 | item_list = [item['value'] for item in page2.object_list] 83 | self.assertEqual(item_list, list(range(10, 20))) 84 | 85 | def test_list_all(self): 86 | DataModel.objects.bulk_create([DataModel(value=i) for i in range(1000)]) 87 | qs = RestFrameworkQuerySet('{}/api/'.format(self.live_server_url)) 88 | item_list = [item['value'] for item in list(qs)] 89 | self.assertEqual(item_list, list(range(1000))) 90 | 91 | def test_list_filter(self): 92 | DataModel.objects.bulk_create([DataModel(value=i) for i in range(1000)]) 93 | qs = RestFrameworkQuerySet('{}/api/'.format(self.live_server_url)) 94 | qs = qs.filter(value__gt=300) 95 | item_list = [item['value'] for item in list(qs)] 96 | self.assertEqual(item_list, list(range(301, 1000))) 97 | 98 | def test_slice(self): 99 | DataModel.objects.bulk_create([DataModel(value=i) for i in range(1000)]) 100 | qs = RestFrameworkQuerySet('{}/api/'.format(self.live_server_url)) 101 | item_list = [item['value'] for item in qs[200:1000]] 102 | self.assertEqual(item_list, list(range(200, 1000))) 103 | 104 | def test_get(self): 105 | DataModel.objects.bulk_create([DataModel(value=i) 
for i in range(100)]) 106 | qs = RestFrameworkQuerySet('{}/api/'.format(self.live_server_url)) 107 | item = qs.get(1) 108 | self.assertEqual(item['id'], 1) 109 | item = qs.get(2) 110 | self.assertEqual(item['id'], 2) 111 | --------------------------------------------------------------------------------