├── .circleci └── config.yml ├── .gitignore ├── AUTHORS ├── LICENSE ├── README.md ├── example ├── __init__.py ├── app.py ├── documents.py └── schemas.py ├── flask_mongorest ├── __init__.py ├── authentication.py ├── exceptions.py ├── methods.py ├── mongorest.py ├── operators.py ├── resources.py ├── templates │ └── mongorest │ │ └── debug.html ├── utils.py └── views.py ├── pyproject.toml ├── requirements.txt ├── setup.cfg ├── setup.py ├── tests └── __init__.py └── tox.ini /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | workflows: 4 | version: 2 5 | workflow: 6 | jobs: 7 | - test-3.7 8 | - test-3.8 9 | - test-3.9 10 | - static-code-analysis 11 | 12 | defaults: &defaults 13 | working_directory: ~/code 14 | steps: 15 | - checkout 16 | - run: 17 | name: Install dependencies 18 | command: pip install --user -r requirements.txt nose 19 | - run: 20 | name: Test 21 | command: nosetests 22 | 23 | jobs: 24 | static-code-analysis: 25 | docker: 26 | - image: circleci/python:3.8 27 | working_directory: ~/code 28 | steps: 29 | - checkout 30 | 31 | - run: 32 | name: Prepare Environment 33 | command: pip install --user -r requirements.txt lintlizard 34 | 35 | - run: 36 | name: lintlizard 37 | command: lintlizard 38 | 39 | test-3.7: 40 | <<: *defaults 41 | docker: 42 | - image: circleci/python:3.7 43 | - image: mongo:3.2.19 44 | test-3.8: 45 | <<: *defaults 46 | docker: 47 | - image: circleci/python:3.8 48 | - image: mongo:3.2.19 49 | test-3.9: 50 | <<: *defaults 51 | docker: 52 | - image: circleci/python:3.9 53 | - image: mongo:3.2.19 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.py[co] 2 | 3 | # Packages 4 | *.egg 5 | *.egg-info 6 | dist 7 | build 8 | eggs 9 | parts 10 | bin 11 | var 12 | sdist 13 | develop-eggs 14 | .installed.cfg 15 | .eggs 16 | 17 | # Installer 
logs 18 | pip-log.txt 19 | 20 | # Unit test / coverage reports 21 | .coverage 22 | .tox 23 | 24 | #Translations 25 | *.mo 26 | 27 | #Mr Developer 28 | .mr.developer.cfg 29 | 30 | #sphinx docs 31 | docs/_build 32 | 33 | venv 34 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | Anthony Nemitz -- http://github.com/anemitz 2 | Thomas Steinacher -- http://thomasst.ch 3 | Phil Freo -- http://philfreo.com 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2010-2012 See AUTHORS. 2 | 3 | Some rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are 7 | met: 8 | 9 | * Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above 13 | copyright notice, this list of conditions and the following 14 | disclaimer in the documentation and/or other materials provided 15 | with the distribution. 16 | 17 | * The names of the contributors may not be used to endorse or 18 | promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 25 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 26 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 27 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 28 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 29 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 30 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | *This project isn't actively maintained. We do not recommend it for production use.* 2 | 3 | Flask-MongoRest [![Build Status](https://circleci.com/gh/closeio/flask-mongorest.png?branch=master&style=shield)](https://circleci.com/gh/closeio/flask-mongorest) 4 | =============== 5 | A Restful API framework wrapped around MongoEngine. 
6 | 7 | Setup 8 | ===== 9 | 10 | ``` python 11 | from flask import Flask 12 | from flask_mongoengine import MongoEngine 13 | from flask_mongorest import MongoRest 14 | from flask_mongorest.views import ResourceView 15 | from flask_mongorest.resources import Resource 16 | from flask_mongorest import operators as ops 17 | from flask_mongorest import methods 18 | 19 | 20 | app = Flask(__name__) 21 | 22 | app.config.update( 23 | MONGODB_HOST = 'localhost', 24 | MONGODB_PORT = '27017', 25 | MONGODB_DB = 'mongorest_example_app', 26 | ) 27 | 28 | db = MongoEngine(app) 29 | api = MongoRest(app) 30 | 31 | class User(db.Document): 32 | email = db.EmailField(unique=True, required=True) 33 | 34 | class Content(db.EmbeddedDocument): 35 | text = db.StringField() 36 | 37 | class ContentResource(Resource): 38 | document = Content 39 | 40 | class Post(db.Document): 41 | title = db.StringField(max_length=120, required=True) 42 | author = db.ReferenceField(User) 43 | content = db.EmbeddedDocumentField(Content) 44 | 45 | class PostResource(Resource): 46 | document = Post 47 | related_resources = { 48 | 'content': ContentResource, 49 | } 50 | filters = { 51 | 'title': [ops.Exact, ops.Startswith], 52 | 'author_id': [ops.Exact], 53 | } 54 | rename_fields = { 55 | 'author': 'author_id', 56 | } 57 | 58 | @api.register(name='posts', url='/posts/') 59 | class PostView(ResourceView): 60 | resource = PostResource 61 | methods = [methods.Create, methods.Update, methods.Fetch, methods.List] 62 | ``` 63 | 64 | With this app, the following cURL commands can be used: 65 | Create a Post: 66 | ``` 67 | curl -H "Content-Type: application/json" -X POST -d \ 68 | '{"title": "First post!", "author_id": "author_id_from_a_previous_api_call", "content": {"text": "this is our test post content"}}' http://0.0.0.0:5000/posts/ 69 | { 70 | "id": "1", 71 | "title": "First post!", 72 | "author_id": "author_id_from_a_previous_api_call", 73 | "content": { 74 | "text": "this is our test post content" 75 | } 76 | } 77 |
``` 78 | Get a Post: 79 | ``` 80 | curl http://0.0.0.0:5000/posts/1/ 81 | { 82 | "id": "1", 83 | "title": "First post!", 84 | "author_id": "author_id_from_a_previous_api_call", 85 | "content": { 86 | "text": "this is our test post content" 87 | } 88 | } 89 | ``` 90 | List all Posts or filter by the title: 91 | ``` 92 | curl http://0.0.0.0:5000/posts/ or curl http://0.0.0.0:5000/posts/?title__startswith=First%20post 93 | { 94 | "data": [ 95 | { 96 | "id": "1", 97 | "title": "First post!", 98 | "author_id": "author_id_from_a_previous_api_call", 99 | "content": { 100 | "text": "this is our test post content" 101 | } 102 | }, 103 | ... other posts 104 | ] 105 | } 106 | ``` 107 | Delete a Post: 108 | ``` 109 | curl -X DELETE http://0.0.0.0:5000/posts/1/ 110 | # Fails since PostView.methods does not allow Delete 111 | ``` 112 | 113 | Request Params 114 | ============== 115 | 116 | **_skip** and **_limit** => utilize the built-in functions of mongodb. 117 | 118 | **_fields** => limit the response's fields to those named here (comma separated). 119 | 120 | **_order_by** => order results if this string is present in the Resource.allowed_ordering list. 121 | 122 | 123 | Resource Configuration 124 | ====================== 125 | 126 | **rename_fields** => dict of renaming rules. Useful for mapping _id fields such as "organization": "organization_id" 127 | 128 | **filters** => filter results of a List request using the allowed filters which are used like `/user/?id__gt=2` or `/user/?email__exact=a@b.com` 129 | 130 | **related_resources** => nested resource serialization for reference/embedded fields of a document 131 | 132 | **child_document_resources** => Suppose you have a Person base class which has Male and Female subclasses. These subclasses and their respective resources share the same MongoDB collection, but have different fields and serialization characteristics. 
This dictionary allows you to map class instances to their respective resources to be used during serialization. 133 | 134 | Authentication 135 | ============== 136 | The AuthenticationBase class allows applications to implement their own API authentication. Two common patterns are shown below along with a BaseResourceView which can be used as the parent View of all of your app's resources. 137 | ``` python 138 | class SessionAuthentication(AuthenticationBase): 139 | def authorized(self): 140 | return current_user.is_authenticated 141 | 142 | class ApiKeyAuthentication(AuthenticationBase): 143 | """ 144 | @TODO ApiKey document and key generation left to the specific implementation 145 | """ 146 | def authorized(self): 147 | if 'AUTHORIZATION' in request.headers: 148 | authorization = request.headers['AUTHORIZATION'].split() 149 | if len(authorization) == 2 and authorization[0].lower() == 'basic': 150 | try: 151 | authorization_parts = base64.b64decode(authorization[1]).decode('utf-8').partition(':') 152 | key = authorization_parts[0] 153 | api_key = ApiKey.objects.get(key__exact=key) 154 | if api_key.user: 155 | login_user(api_key.user) 156 | setattr(current_user, 'api_key', api_key) 157 | return True 158 | except (TypeError, ValueError, ApiKey.DoesNotExist): 159 | pass 160 | return False 161 | 162 | class BaseResourceView(ResourceView): 163 | authentication_methods = [SessionAuthentication, ApiKeyAuthentication] 164 | ``` 165 | 166 | Running the test suite 167 | ====================== 168 | This package uses nosetests for automated testing. A MongoDB server must be reachable on localhost:27017 (the database the example app is configured to use; the CI configuration starts one alongside the test container). With that running, just run `python setup.py nosetests` to run the tests. 169 | 170 | Contributing 171 | ============ 172 | Pull requests are greatly appreciated!
173 | -------------------------------------------------------------------------------- /example/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/closeio/flask-mongorest/992258d5da59947d39bb45f68e3acb04d020975c/example/__init__.py -------------------------------------------------------------------------------- /example/app.py: -------------------------------------------------------------------------------- 1 | import os 2 | # Example Flask app wiring the MongoEngine documents in example/documents.py to flask_mongorest resources and views. 3 | from flask import Flask, request 4 | from flask_mongoengine import MongoEngine 5 | 6 | from example import documents, schemas 7 | from flask_mongorest import MongoRest, operators as ops 8 | from flask_mongorest.authentication import AuthenticationBase 9 | from flask_mongorest.methods import *  # brings in Create, Update, BulkUpdate, Fetch, List, Delete (methods.py defines no __all__) 10 | from flask_mongorest.resources import Resource 11 | from flask_mongorest.views import ResourceView 12 | 13 | app = Flask(__name__) 14 | 15 | app.url_map.strict_slashes = False  # match routes with or without a trailing slash 16 | 17 | app.config.update( 18 | DEBUG=True, 19 | TESTING=True, 20 | MONGODB_SETTINGS={ 21 | "HOST": "localhost", 22 | "PORT": 27017, 23 | "DB": "mongorest_example_app", 24 | "TZ_AWARE": False, 25 | }, 26 | ) 27 | 28 | db = MongoEngine(app) 29 | api = MongoRest(app)  # default url_prefix="", so register_class adds no prefix 30 | 31 | 32 | class UserResource(Resource): 33 | document = documents.User 34 | schema = schemas.User 35 | filters = {"datetime": [ops.Exact]} 36 | 37 | 38 | @api.register()  # no name/url: endpoint "UserView", URL "/user/" (lower-cased document name) 39 | class UserView(ResourceView): 40 | resource = UserResource 41 | methods = [Create, Update, Fetch, List, Delete] 42 | 43 | 44 | class ContentResource(Resource): 45 | document = documents.Content 46 | 47 | 48 | class PostResource(Resource): 49 | document = documents.Post 50 | schema = schemas.Post 51 | related_resources = { 52 | "content": ContentResource, 53 | "sections": ContentResource, # nested complex objects 54 | #'author': UserResource, 55 | #'editor': UserResource, 56 | #'user_lists': UserResource, 57 | "primary_user": UserResource, 58 | } 59 | filters =
{ 60 | "title": [ops.Exact, ops.Startswith, ops.In(allow_negation=True)], 61 | "author_id": [ops.Exact], 62 | "is_published": [ops.Boolean], 63 | } 64 | rename_fields = {"author": "author_id"} 65 | bulk_update_limit = 10 66 | 67 | def get_objects(self, **kwargs): 68 | qs, has_more = super(PostResource, self).get_objects(**kwargs) 69 | return qs, has_more, {"more": "stuff"}  # NOTE(review): the third item looks like extra response metadata; confirm how Resource consumes it 70 | 71 | def get_fields(self): 72 | fields = super(PostResource, self).get_fields() 73 | if "_include_primary_user" in request.args:  # opt in to the computed Post.primary_user via the query string 74 | fields = set(fields) | {"primary_user"} 75 | return fields 76 | 77 | def update_object(self, obj, data=None, save=True, parent_resources=None): 78 | data = data or self.data  # NOTE(review): an explicitly passed empty dict also falls back to self.data 79 | if data.get("author"): 80 | author = data["author"] 81 | if author.email == "vincent@vangogh.com":  # demo hook: tag the post "art" when this is the author 82 | obj.tags.append("art") 83 | return super(PostResource, self).update_object( 84 | obj, data, save, parent_resources 85 | ) 86 | 87 | 88 | @api.register(name="posts", url="/posts/") 89 | class PostView(ResourceView): 90 | resource = PostResource 91 | methods = [Create, Update, BulkUpdate, Fetch, List, Delete] 92 | 93 | 94 | class LimitedPostResource(Resource): 95 | document = documents.Post 96 | related_resources = {"content": ContentResource} 97 | 98 | 99 | @api.register(name="limited_posts", url="/limited_posts/") 100 | class LimitedPostView(ResourceView): 101 | resource = LimitedPostResource 102 | methods = [Create, Update, Fetch, List] 103 | 104 | 105 | class DummyAuthenication(AuthenticationBase):  # NOTE(review): "Authenication" is misspelled; name kept in case it is referenced elsewhere 106 | def authorized(self): 107 | return False  # always denies; presumably every /auth/ request is rejected 108 | 109 | 110 | @api.register(name="auth", url="/auth/") 111 | class DummyAuthView(ResourceView): 112 | resource = PostResource 113 | methods = [Create, Update, Fetch, List, Delete] 114 | authentication_methods = [DummyAuthenication] 115 | 116 | 117 | @api.register(name="restricted", url="/restricted/") 118 | class RestrictedPostView(ResourceView): 119 | """This class allows us to put restrictions in place regarding 120 | who/what can be
read, changed, added or deleted. 121 | """ 122 | 123 | resource = PostResource 124 | methods = [Create, Update, Fetch, List, Delete] 125 | 126 | # Can't read a post if it isn't published 127 | def has_read_permission(self, request, qs): 128 | return qs.filter(is_published=True)  # narrows the queryset instead of returning a bool 129 | 130 | # Can't add a post in a published state 131 | def has_add_permission(self, request, obj): 132 | return not obj.is_published 133 | 134 | # Can't change a post if it is published 135 | def has_change_permission(self, request, obj): 136 | return not obj.is_published 137 | 138 | # Can't delete a post if it is published 139 | def has_delete_permission(self, request, obj): 140 | return not obj.is_published 141 | 142 | 143 | class TestDocument(db.Document): 144 | name = db.StringField() 145 | other = db.StringField() 146 | dictfield = db.DictField() 147 | is_new = db.BooleanField() 148 | email = db.EmailField() 149 | 150 | 151 | class TestResource(Resource): 152 | document = TestDocument 153 | 154 | 155 | class TestFieldsResource(Resource): 156 | document = TestDocument 157 | fields = ["id", "name", "upper_name"]  # presumably "upper_name" is served by the method below 158 | 159 | def upper_name(self, obj): 160 | return obj.name.upper()  # NOTE(review): raises AttributeError when name is None 161 | 162 | 163 | @api.register(name="test", url="/test/") 164 | class TestView(ResourceView): 165 | resource = TestResource 166 | methods = [Create, Update, Fetch, List] 167 | 168 | 169 | @api.register(name="testfields", url="/testfields/") 170 | class TestFieldsResource(ResourceView):  # NOTE(review): reuses the Resource's name; `resource = TestFieldsResource` below still sees the Resource (this class name is bound only after its body runs), but afterwards the module-level name refers to this view 171 | resource = TestFieldsResource 172 | methods = [Create, Update, Fetch, List] 173 | 174 | 175 | class LanguageResource(Resource): 176 | document = documents.Language 177 | 178 | 179 | class PersonResource(Resource): 180 | document = documents.Person 181 | schema = schemas.Person 182 | related_resources = {"languages": LanguageResource} 183 | save_related_fields = ["languages"] 184 | 185 | 186 | @api.register(name="person", url="/person/") 187 | class PersonView(ResourceView): 188 | resource = PersonResource 189 | methods =
[Create, Update, Fetch, List] 190 | 191 | 192 | # extra resources for testing max_limit 193 | class Post10Resource(PostResource): 194 | max_limit = 10  # caps the _limit a client may request 195 | 196 | 197 | class Post250Resource(PostResource): 198 | max_limit = 250 199 | 200 | 201 | @api.register(name="posts10", url="/posts10/") 202 | class Post10View(ResourceView): 203 | resource = Post10Resource 204 | methods = [Create, Update, BulkUpdate, Fetch, List, Delete] 205 | 206 | 207 | @api.register(name="posts250", url="/posts250/") 208 | class Post250View(ResourceView): 209 | resource = Post250Resource 210 | methods = [Create, Update, BulkUpdate, Fetch, List, Delete] 211 | 212 | 213 | # Documents, resources, and views for testing differences between db refs and object ids 214 | class A(db.Document): 215 | txt = db.StringField() 216 | 217 | 218 | class B(db.Document): 219 | ref = db.ReferenceField(A, dbref=True)  # stored as a DBRef 220 | txt = db.StringField() 221 | 222 | 223 | class C(db.Document): 224 | ref = db.ReferenceField(A)  # stored as a plain ObjectId (dbref not enabled) 225 | txt = db.StringField() 226 | 227 | 228 | class AResource(Resource): 229 | document = A 230 | 231 | 232 | class BResource(Resource): 233 | document = B 234 | 235 | 236 | class CResource(Resource): 237 | document = C 238 | 239 | 240 | @api.register(url="/a/") 241 | class AView(ResourceView): 242 | resource = AResource 243 | methods = [Create, Update, BulkUpdate, Fetch, List, Delete] 244 | 245 | 246 | @api.register(url="/b/") 247 | class BView(ResourceView): 248 | resource = BResource 249 | methods = [Create, Update, BulkUpdate, Fetch, List, Delete] 250 | 251 | 252 | @api.register(url="/c/") 253 | class CView(ResourceView): 254 | resource = CResource 255 | methods = [Create, Update, BulkUpdate, Fetch, List, Delete] 256 | 257 | 258 | # Documents, resources, and views for testing method permissions 259 | class MethodTestDoc(db.Document): 260 | txt = db.StringField() 261 | 262 | 263 | class MethodTestResource(Resource): 264 | document = MethodTestDoc 265 | 266 | 267 |
@api.register(url="/create_only/") 268 | class CreateOnlyView(ResourceView): 269 | resource = MethodTestResource 270 | methods = [Create]  # register_class adds a POST collection rule plus a POST-only detail rule 271 | 272 | 273 | @api.register(url="/update_only/") 274 | class UpdateOnlyView(ResourceView): 275 | resource = MethodTestResource 276 | methods = [Update] 277 | 278 | 279 | @api.register(url="/bulk_update_only/") 280 | class BulkUpdateOnlyView(ResourceView): 281 | resource = MethodTestResource 282 | methods = [BulkUpdate] 283 | 284 | 285 | @api.register(url="/fetch_only/") 286 | class FetchOnlyView(ResourceView): 287 | resource = MethodTestResource 288 | methods = [Fetch] 289 | 290 | 291 | @api.register(url="/list_only/") 292 | class ListOnlyView(ResourceView): 293 | resource = MethodTestResource 294 | methods = [List] 295 | 296 | 297 | @api.register(url="/delete_only/") 298 | class DeleteOnlyView(ResourceView): 299 | resource = MethodTestResource 300 | methods = [Delete] 301 | 302 | 303 | class ViewMethodTestDoc(db.Document): 304 | txt = db.StringField() 305 | 306 | 307 | class ViewMethodTestResource(Resource): 308 | document = ViewMethodTestDoc 309 | 310 | 311 | @api.register(url="/test_view_method/") 312 | class TestViewMethodView(ResourceView): 313 | resource = ViewMethodTestResource 314 | methods = [Create, Update, BulkUpdate, Fetch, List, Delete] 315 | 316 | def _dispatch_request(self, *args, **kwargs): 317 | super(TestViewMethodView, self)._dispatch_request(*args, **kwargs)  # run the normal dispatch but discard its response 318 | return {"method": self._resource.view_method.__name__}  # report which view method handled the request instead 319 | 320 | 321 | class DateTimeResource(Resource): 322 | document = documents.DateTime 323 | schema = schemas.DateTime 324 | 325 | 326 | @api.register(name="datetime", url="/datetime/") 327 | class DateTimeView(ResourceView): 328 | resource = DateTimeResource 329 | methods = [Create, Update, Fetch, List] 330 | 331 | 332 | # Document, resource, and view for testing invalid JSON 333 | class DictDoc(db.Document): 334 | dict = db.DictField() 335 | 336 | 337 | class DictDocResource(Resource): 338
| document = DictDoc 339 | 340 | 341 | @api.register(url="/dict_doc/") 342 | class DictDocView(ResourceView): 343 | resource = DictDocResource 344 | methods = [Fetch, List, Create, Update] 345 | 346 | 347 | if __name__ == "__main__": 348 | port = int(os.environ.get("PORT", 8000))  # the PORT env var overrides the default 8000 349 | app.run(host="0.0.0.0", port=port) 350 | -------------------------------------------------------------------------------- /example/documents.py: -------------------------------------------------------------------------------- 1 | from mongoengine import * 2 | # MongoEngine documents used by the example app. 3 | 4 | class DateTime(Document): 5 | datetime = DateTimeField() 6 | 7 | 8 | class Language(Document): 9 | name = StringField() 10 | 11 | 12 | class Person(Document): 13 | name = StringField() 14 | languages = ListField(ReferenceField(Language)) 15 | 16 | 17 | class User(Document): 18 | email = EmailField(unique=True, required=True) 19 | first_name = StringField(max_length=50) 20 | last_name = StringField(max_length=50) 21 | emails = ListField(EmailField()) 22 | datetime = DateTimeField() 23 | datetime_local = DateTimeField() 24 | balance = IntField() # in cents 25 | 26 | 27 | class Content(EmbeddedDocument): 28 | text = StringField() 29 | lang = StringField(max_length=3) 30 | 31 | 32 | class Post(Document): 33 | title = StringField(max_length=120, required=True) 34 | description = StringField(max_length=120, required=False) 35 | author = ReferenceField(User) 36 | editor = ReferenceField(User) 37 | tags = ListField(StringField(max_length=30)) 38 | try:  # SafeReferenceField exists only in the closeio/mongoengine fork; otherwise fall back to ReferenceField 39 | user_lists = ListField(SafeReferenceField(User)) 40 | except NameError: 41 | user_lists = ListField(ReferenceField(User)) 42 | sections = ListField(EmbeddedDocumentField(Content)) 43 | content = EmbeddedDocumentField(Content) 44 | is_published = BooleanField() 45 | 46 | def primary_user(self): 47 | return self.user_lists[0] if self.user_lists else None  # first listed user or None; PostResource exposes it as "primary_user" 48 | -------------------------------------------------------------------------------- /example/schemas.py:
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # cleancat validation schemas for the example documents. 3 | from cleancat import * 4 | from cleancat.mongo import MongoEmbeddedReference, MongoReference 5 | 6 | from example import documents 7 | 8 | 9 | class User(Schema): 10 | email = Email(required=False) 11 | first_name = String(required=False) 12 | last_name = String(required=False) 13 | emails = List(Email(), required=False) 14 | datetime = DateTime(regex_message="Invalid date 💩", required=False) 15 | datetime_local = DateTime(required=False) 16 | balance = Integer(required=False) 17 | 18 | 19 | class Content(Schema): 20 | text = String() 21 | lang = String() 22 | 23 | 24 | class Post(Schema): 25 | title = String() 26 | description = String(required=False) 27 | author = MongoReference(documents.User, required=False) 28 | editor = MongoReference(documents.User, required=False) 29 | tags = List(String(), required=False) 30 | user_lists = List(MongoReference(documents.User), required=False) 31 | sections = List(MongoEmbeddedReference(documents.Content, Content), required=False) 32 | content = MongoEmbeddedReference(documents.Content, Content, required=False) 33 | is_published = Bool() 34 | 35 | 36 | class Language(Schema): 37 | name = String() 38 | 39 | 40 | class Person(Schema): 41 | name = String() 42 | languages = List(  # NOTE(review): validated as inline documents although documents.Person stores references; presumably persisted via PersonResource.save_related_fields — confirm 43 | MongoEmbeddedReference(documents.Language, Language), required=False 44 | ) 45 | 46 | 47 | class DateTime(Schema):  # NOTE(review): from here on, the module-level name DateTime is this schema, not cleancat's field 48 | datetime = DateTime()  # still cleancat's DateTime field: the class name is bound only after its body runs 49 | -------------------------------------------------------------------------------- /flask_mongorest/__init__.py: -------------------------------------------------------------------------------- 1 | from .methods import BulkUpdate, Create, List 2 | from .mongorest import MongoRest 3 | # .methods is imported first because .mongorest re-imports these names from this package 4 | __all__ = [ 5 | "MongoRest", 6 | # TODO these methods probably shouldn't be exposed here?
7 | "BulkUpdate", 8 | "Create", 9 | "List", 10 | ] 11 | -------------------------------------------------------------------------------- /flask_mongorest/authentication.py: -------------------------------------------------------------------------------- 1 | class AuthenticationBase: 2 | def authorized(self):  # override in subclasses; the README examples return True for an authenticated request 3 | return False  # the base implementation denies every request 4 | -------------------------------------------------------------------------------- /flask_mongorest/exceptions.py: -------------------------------------------------------------------------------- 1 | class MongoRestException(Exception): 2 | pass 3 | 4 | 5 | class OperatorNotAllowed(MongoRestException): 6 | """Signals that a filter operator name is not valid.""" 7 | def __init__(self, operator_name): 8 | # Forward the name to Exception so args, repr() and pickling carry it. 9 | super().__init__(operator_name) 10 | self.op_name = operator_name 11 | 12 | def __str__(self): 13 | # Python 3 calls __str__; with only a __unicode__ hook, str(exc) was empty. 14 | return f'"{self.op_name}" is not a valid operator name.' 15 | 16 | __unicode__ = __str__  # kept for backward compatibility 17 | 18 | 19 | class InvalidFilter(MongoRestException): 20 | pass 21 | 22 | 23 | class ValidationError(MongoRestException): 24 | pass 25 | 26 | 27 | class UnknownFieldError(Exception):  # NOTE(review): unlike the others, not a MongoRestException subclass 28 | pass 29 | -------------------------------------------------------------------------------- /flask_mongorest/methods.py: -------------------------------------------------------------------------------- 1 | import typing 2 | # Marker classes: each names one view action; `method` is the HTTP verb register_class routes it under. 3 | 4 | class Create: 5 | method = "POST" 6 | 7 | 8 | class Update: 9 | method = "PUT" 10 | 11 | 12 | class BulkUpdate: 13 | method = "PUT" 14 | 15 | 16 | class Fetch: 17 | method = "GET" 18 | 19 | 20 | class List: 21 | method = "GET" 22 | 23 | 24 | class Delete: 25 | method = "DELETE" 26 | 27 | 28 | # type alias 29 | METHODS_TYPE = typing.Union[ 30 | typing.Type[Create], 31 | typing.Type[Update], 32 | typing.Type[BulkUpdate], 33 | typing.Type[Fetch], 34 | typing.Type[List], 35 | typing.Type[Delete], 36 | ] 37 | -------------------------------------------------------------------------------- /flask_mongorest/mongorest.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from flask import Blueprint, Flask 4 | 5
| from flask_mongorest import BulkUpdate, Create, List 6 | 7 | 8 | class DelayedApp: 9 | """ 10 | Store URL rules for later merging with an application URL map. 11 | """ 12 | 13 | def __init__(self): 14 | self.url_rules = [] 15 | 16 | def add_url_rule(self, *args, **kwargs): 17 | self.url_rules.append((args, kwargs)) 18 | 19 | 20 | def register_class(app: Union[DelayedApp, Flask], klass, *, url_prefix, **kwargs): 21 | # The endpoint name comes from the 'name' kwarg (default: the view's class name) and must be unique; 22 | # the URL comes from the 'url' kwarg (default: "/<lower-cased document name>/"). 23 | name = kwargs.pop("name", klass.__name__) 24 | url = kwargs.pop("url", None) 25 | if not url: 26 | document_name = klass.resource.document.__name__.lower() 27 | url = f"/{document_name}/" 28 | 29 | # Insert the url prefix, if it exists 30 | if url_prefix: 31 | url = f"{url_prefix}{url}" 32 | 33 | # Add url rules 34 | pk_type = kwargs.pop("pk_type", "string") 35 | view_func = klass.as_view(name) 36 | if List in klass.methods:  # collection route; List requests receive pk=None 37 | app.add_url_rule( 38 | url, 39 | defaults={"pk": None}, 40 | view_func=view_func, 41 | methods=[List.method], 42 | **kwargs, 43 | ) 44 | if Create in klass.methods or BulkUpdate in klass.methods:  # collection route for POST (Create) and/or PUT (BulkUpdate) 45 | app.add_url_rule( 46 | url, 47 | view_func=view_func, 48 | methods=[x.method for x in klass.methods if x in (Create, BulkUpdate)], 49 | **kwargs, 50 | ) 51 | app.add_url_rule(  # detail route "<url><pk>/" for every enabled method except List and BulkUpdate 52 | f"{url}<{pk_type}:pk>/", 53 | view_func=view_func, 54 | methods=[x.method for x in klass.methods if x not in (List, BulkUpdate)], 55 | **kwargs, 56 | ) 57 | 58 | 59 | class MongoRest: 60 | def __init__(self, app=None, url_prefix="", template_folder="templates"): 61 | self.url_prefix = url_prefix 62 | self.template_folder = template_folder 63 | self._delayed_app = DelayedApp() 64 | self._registered_apps = [] 65 | 66 | if app is not None: 67 | self.init_app(app) 68 | 69 | def init_app(self, app): 70 | """ 71 | Provide delayed application instance initialization to support 72 | Flask application factory pattern.
For further details on application 73 | factories see: https://flask.palletsprojects.com/en/2.0.x/extensiondev/ 74 | """ 75 | # Recent Flask versions reject empty blueprint names and names containing 76 | # ".", so derive a safe name from the prefix (the blueprint only carries the template folder). 77 | blueprint_name = (self.url_prefix or "").replace(".", "_") or "mongorest" 78 | app.register_blueprint( 79 | Blueprint(blueprint_name, __name__, template_folder=self.template_folder) 80 | ) 81 | 82 | for args, kwargs in self._delayed_app.url_rules:  # replay rules registered before this app was initialised 83 | app.add_url_rule(*args, **kwargs) 84 | 85 | self._registered_apps.append(app) 86 | 87 | def register(self, **kwargs): 88 | def decorator(klass): 89 | for app in [self._delayed_app] + self._registered_apps:  # the delayed app keeps rules for apps initialised later 90 | register_class(app, klass, url_prefix=self.url_prefix, **kwargs) 91 | return klass 92 | 93 | return decorator 94 | -------------------------------------------------------------------------------- /flask_mongorest/operators.py: -------------------------------------------------------------------------------- 1 | """ 2 | Flask-MongoRest operators. 3 | 4 | Operators are the building blocks that Resource filters are built upon. 5 | Their role is to generate and apply the right filters to a provided 6 | queryset. For example: 7 | 8 | GET /post/?title__startswith=John 9 | 10 | Such a request would result in calling this module's `Startswith` operator 11 | like so: 12 | 13 | new_queryset = Startswith().apply(queryset, 'title', 'John') 14 | 15 | Where the original queryset would be `BlogPost.objects.all()` and the 16 | new queryset would be equivalent to: 17 | 18 | BlogPost.objects.filter(title__startswith='John') 19 | 20 | It's also easy to create your own Operator subclass and use it in your 21 | Resource. For example, if you have an endpoint listing students and you 22 | want to filter them by the range of their scores like so: 23 | 24 | GET /student/?score__range=0,10 25 | 26 | Then you can create a Range Operator: 27 | 28 | class Range(Operator): 29 | op = 'range' 30 | def prepare_queryset_kwargs(self, field, value, negate=False): 31 | # For the sake of simplicity, we won't support negate here, 32 | # i.e. /student/?score__not__range=0,10 won't work.
33 | lower, upper = value.split(',') 34 | return { 35 | field + '__lte': upper, 36 | field + '__gte': lower 37 | } 38 | 39 | Then you include it in your Resource's filters: 40 | 41 | class StudentResource(Resource): 42 | document = documents.Student 43 | filters = { 44 | 'score': [Range] 45 | } 46 | 47 | And this way, the request we mentioned above would result in: 48 | 49 | Student.objects.filter(score__lte=upper, score__gte=lower) 50 | """ 51 | 52 | 53 | class Operator: 54 | """Base class that all the other operators should inherit from.""" 55 | 56 | op = "exact" 57 | 58 | # Can be overridden via constructor. 59 | allow_negation = False 60 | 61 | def __init__(self, allow_negation=False): 62 | self.allow_negation = allow_negation 63 | 64 | # Lets us specify filters as an instance if we want to override the 65 | # default arguments (in addition to specifying them as a class). 66 | def __call__(self): 67 | return self 68 | 69 | def prepare_queryset_kwargs(self, field, value, negate):  # e.g. {"title__startswith": v}, or {"title__not__startswith": v} when negated 70 | if negate: 71 | return {"__".join(filter(None, [field, "not", self.op])): value} 72 | else: 73 | return {"__".join(filter(None, [field, self.op])): value} 74 | 75 | def apply(self, queryset, field, value, negate=False): 76 | kwargs = self.prepare_queryset_kwargs(field, value, negate) 77 | return queryset.filter(**kwargs) 78 | 79 | 80 | class Ne(Operator): 81 | op = "ne" 82 | 83 | 84 | class Lt(Operator): 85 | op = "lt" 86 | 87 | 88 | class Lte(Operator): 89 | op = "lte" 90 | 91 | 92 | class Gt(Operator): 93 | op = "gt" 94 | 95 | 96 | class Gte(Operator): 97 | op = "gte" 98 | 99 | 100 | class Exact(Operator): 101 | op = "exact" 102 | 103 | def prepare_queryset_kwargs(self, field, value, negate): 104 | # Using __exact causes mongoengine to generate a regular 105 | # expression query, which we'd like to avoid.
106 | if negate: 107 | return {f"{field}__ne": value} 108 | else: 109 | return {field: value} 110 | 111 | 112 | class IExact(Operator): 113 | op = "iexact" 114 | 115 | 116 | class In(Operator): 117 | op = "in" 118 | 119 | def prepare_queryset_kwargs(self, field, value, negate): 120 | # this is null if the user submits an empty in expression (like 121 | # "user__in=") 122 | value = value or []  # NOTE(review): [] then yields {field: []}, an equality match on an empty list 123 | 124 | # only use 'in' or 'nin' if multiple values are specified 125 | if "," in value: 126 | value = value.split(",") 127 | op = "nin" if negate else self.op 128 | else: 129 | op = "ne" if negate else "" 130 | return {"__".join(filter(None, [field, op])): value} 131 | 132 | 133 | class Contains(Operator): 134 | op = "contains" 135 | 136 | 137 | class IContains(Operator): 138 | op = "icontains" 139 | 140 | 141 | class Startswith(Operator): 142 | op = "startswith" 143 | 144 | 145 | class IStartswith(Operator): 146 | op = "istartswith" 147 | 148 | 149 | class Endswith(Operator): 150 | op = "endswith" 151 | 152 | 153 | class IEndswith(Operator): 154 | op = "iendswith" 155 | 156 | 157 | class Boolean(Operator): 158 | op = "exact" 159 | 160 | def prepare_queryset_kwargs(self, field, value, negate): 161 | # Only the exact string "false" maps to False; any other value 162 | # (including "False", "0" or "") is treated as True. 163 | bool_value = value != "false" 164 | 165 | if negate: 166 | bool_value = not bool_value 167 | 168 | return {field: bool_value} 169 | -------------------------------------------------------------------------------- /flask_mongorest/resources.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import json 3 | from typing import Dict, List, Type 4 | from urllib.parse import urlparse 5 | 6 | import mongoengine 7 | from bson.dbref import DBRef 8 | from bson.objectid import ObjectId 9 | from flask import has_request_context, request, url_for 10 | 11 | try: # closeio/mongoengine 12 | from mongoengine.base.proxy import DocumentProxy 13 | from mongoengine.fields import
SafeReferenceField 14 | except ImportError: 15 | DocumentProxy = None 16 | SafeReferenceField = None 17 | 18 | from cleancat import ValidationError as SchemaValidationError 19 | from mongoengine.fields import ( 20 | DictField, 21 | EmbeddedDocumentField, 22 | GenericReferenceField, 23 | ListField, 24 | ReferenceField, 25 | ) 26 | 27 | from flask_mongorest import methods 28 | from flask_mongorest.exceptions import UnknownFieldError, ValidationError 29 | from flask_mongorest.utils import cmp_fields, equal, isbound, isint 30 | 31 | 32 | class ResourceMeta(type): 33 | def __init__(cls, name, bases, classdict): 34 | if classdict.get("__metaclass__") is not ResourceMeta: 35 | for document, resource in cls.child_document_resources.items(): 36 | if resource == name: 37 | cls.child_document_resources[document] = cls 38 | type.__init__(cls, name, bases, classdict) 39 | 40 | 41 | class Resource(metaclass=ResourceMeta): 42 | # MongoEngine Document class related to this resource (required) 43 | document = None 44 | 45 | # List of fields that can (and should by default) be included in the 46 | # response 47 | fields = None 48 | 49 | # Dict of original field names (as seen in `fields`) and what they should 50 | # be renamed to in the API response 51 | rename_fields: Dict[str, str] = {} 52 | 53 | # CleanCat Schema class (used for validation) 54 | schema = None 55 | 56 | # List of fields that the objects can be ordered by 57 | allowed_ordering: List[str] = [] 58 | 59 | # Define whether or not this resource supports pagination 60 | paginate = True 61 | 62 | # Default limit if no _limit is specified in the request. Only relevant 63 | # if pagination is enabled. 64 | default_limit = 100 65 | 66 | # Maximum value of _limit that can be requested (avoids DDoS'ing the API). 67 | # Only relevant if pagination is enabled. 
68 | max_limit = 100 69 | 70 | # Maximum number of objects which can be bulk-updated by a single request 71 | bulk_update_limit = 1000 72 | 73 | # Map of field names and Resource classes that should be used to handle 74 | # these fields (for serialization, saving, etc.). 75 | related_resources: Dict[str, "Resource"] = {} 76 | 77 | # Map of field names on this resource's document to field names on the 78 | # related resource's document, used as a helper in the process of 79 | # turning a field value from a queryset to a list of objects 80 | # 81 | # TODO Behavior of this is *very* unintuitive and should be changed or 82 | # dropped, or at least refactored 83 | related_resources_hints: Dict[str, str] = {} 84 | 85 | # List of field names corresponding to related resources. If a field is 86 | # mentioned here and in `related_resources`, it can be created/updated 87 | # from within this resource. 88 | save_related_fields: List[str] = [] 89 | 90 | # Map of MongoEngine Document classes to Resource class names. Defines 91 | # which sub-resource should be used for handling a particular subclass of 92 | # this resource's document. 93 | child_document_resources: Dict[Type, str] = {} 94 | 95 | # Whenever a new document is posted and the system doesn't know the type 96 | # of it yet, it will choose a default sub-resource for this document type 97 | default_child_resource_document = None 98 | 99 | # Defines whether MongoEngine's select_related should be used on a 100 | # filtered query set, pulling all the references efficiently. 101 | select_related = False 102 | 103 | # Must start and end with a "/" 104 | uri_prefix = None 105 | 106 | def __init__(self, view_method=None): 107 | """ 108 | Initialize a resource. Optionally, a method class can be given to 109 | view_method (see methods.py) so the resource can behave differently 110 | depending on the method. 
111 | """ 112 | doc_fields = self.document._fields.keys() 113 | if self.fields is None: 114 | self.fields = doc_fields 115 | self._related_resources = self.get_related_resources() 116 | self._rename_fields = self.get_rename_fields() 117 | self._reverse_rename_fields = {} 118 | for k, v in self._rename_fields.items(): 119 | self._reverse_rename_fields[v] = k 120 | assert len(self._rename_fields) == len( 121 | self._reverse_rename_fields 122 | ), "Cannot rename multiple fields to the same name" 123 | self._filters = self.get_filters() 124 | self._child_document_resources = self.get_child_document_resources() 125 | self._default_child_resource_document = ( 126 | self.get_default_child_resource_document() 127 | ) 128 | self.data = None 129 | self._dirty_fields = None 130 | self.view_method = view_method 131 | 132 | @property 133 | def params(self): 134 | """ 135 | Return parameters of the request which is currently being processed. 136 | Params can be passed in two different ways: 137 | 138 | 1. As a querystring (e.g. '/resource/?status=active&_limit=10'). 139 | 2. As a _params property in the JSON payload. For example: 140 | { '_params': { 'status': 'active', '_limit': '10' } } 141 | """ 142 | if not has_request_context(): 143 | # `params` doesn't make sense if we don't have a request 144 | raise AttributeError 145 | 146 | if not hasattr(self, "_params"): 147 | if "_params" in self.raw_data: 148 | self._params = self.raw_data["_params"] 149 | else: 150 | try: 151 | self._params = request.args.to_dict() 152 | except AttributeError: # mocked request with regular dict 153 | self._params = request.args 154 | return self._params 155 | 156 | def _enforce_strict_json(self, val): 157 | """ 158 | Enforce strict json parsing. 159 | 160 | Raise a ValueError if NaN, Infinity, or -Infinity were posted. By 161 | default, json.loads accepts these values, but it allows us to perform 162 | extra validation via a parse_constant kwarg. 
163 | """ 164 | # according to the `json.loads` docs: "parse_constant, if specified, 165 | # will be called with one of the following strings: '-Infinity', 166 | # 'Infinity', 'NaN'". Since none of them are valid JSON, we can simply 167 | # raise an exception here. 168 | raise ValueError 169 | 170 | @property 171 | def raw_data(self): 172 | """Validate and return parsed JSON payload.""" 173 | if not has_request_context(): 174 | # `raw_data` doesn't make sense if we don't have a request 175 | raise AttributeError 176 | 177 | if not hasattr(self, "_raw_data"): 178 | if request.method in ("PUT", "POST") or request.data: 179 | if request.mimetype and "json" not in request.mimetype: 180 | raise ValidationError( 181 | { 182 | "error": "Please send valid JSON with a 'Content-Type: application/json' header." 183 | } 184 | ) 185 | if request.headers.get("Transfer-Encoding") == "chunked": 186 | raise ValidationError( 187 | {"error": "Chunked Transfer-Encoding is not supported."} 188 | ) 189 | 190 | try: 191 | self._raw_data = json.loads( 192 | request.data.decode("utf-8"), 193 | parse_constant=self._enforce_strict_json, 194 | ) 195 | except ValueError: 196 | raise ValidationError( 197 | {"error": "The request contains invalid JSON."} 198 | ) 199 | if not isinstance(self._raw_data, dict): 200 | raise ValidationError({"error": "JSON data must be a dict."}) 201 | else: 202 | self._raw_data = {} 203 | 204 | return self._raw_data 205 | 206 | @classmethod 207 | def uri(cls, path): 208 | """Generate a URI reference for the given path""" 209 | if cls.uri_prefix: 210 | ret = cls.uri_prefix + path 211 | return ret 212 | else: 213 | raise ValueError( 214 | "Cannot generate URI for resources that do not specify a uri_prefix" 215 | ) 216 | 217 | @classmethod 218 | def _url(cls, path): 219 | """Generate a complete URL for the given path. 
Requires application context.""" 220 | if cls.uri_prefix: 221 | url = url_for(cls.uri_prefix.lstrip("/").rstrip("/"), _external=True) 222 | ret = url + path 223 | return ret 224 | else: 225 | raise ValueError( 226 | "Cannot generate URL for resources that do not specify a uri_prefix" 227 | ) 228 | 229 | def get_fields(self): 230 | """ 231 | Return a list of fields that should be included in the response 232 | (unless a `_fields` param didn't include them). 233 | """ 234 | return self.fields 235 | 236 | def get_optional_fields(self): 237 | """ 238 | Return a list of fields that can optionally be included in the 239 | response (but only if a `_fields` param mentioned them explicitly). 240 | """ 241 | return [] 242 | 243 | def get_requested_fields(self, **kwargs): 244 | """ 245 | Process a list of fields requested by the client and return only the 246 | ones which are allowed by get_fields and get_optional_fields. 247 | 248 | If `_fields` param is set to '_all', return a list of all the fields 249 | from get_fields and get_optional_fields combined. 
250 | """ 251 | params = kwargs.get("params", None) 252 | 253 | include_all = False 254 | 255 | if "fields" in kwargs: 256 | fields = kwargs["fields"] 257 | all_fields_set = set(fields) 258 | else: 259 | fields = self.get_fields() 260 | all_fields_set = set(fields) | set(self.get_optional_fields()) 261 | 262 | if params and "_fields" in params: 263 | only_fields = set(params["_fields"].split(",")) 264 | if "_all" in only_fields: 265 | include_all = True 266 | else: 267 | only_fields = None 268 | 269 | requested_fields = [] 270 | if include_all or only_fields is None: 271 | if include_all: 272 | field_selection = all_fields_set 273 | else: 274 | field_selection = fields 275 | for field in field_selection: 276 | requested_fields.append(field) 277 | else: 278 | for field in only_fields: 279 | actual_field = self._reverse_rename_fields.get(field, field) 280 | if actual_field in all_fields_set: 281 | requested_fields.append(actual_field) 282 | 283 | return requested_fields 284 | 285 | def get_max_limit(self): 286 | return self.max_limit 287 | 288 | def get_related_resources(self): 289 | return self.related_resources 290 | 291 | def get_save_related_fields(self): 292 | return self.save_related_fields 293 | 294 | def get_rename_fields(self): 295 | """ 296 | @TODO should automatically support model_id for reference fields (only) and model for related_resources 297 | """ 298 | return self.rename_fields 299 | 300 | def get_child_document_resources(self): 301 | # By default, don't inherit child_document_resources. This lets us have 302 | # multiple resources for a child document without having to reset the 303 | # child_document_resources property in the subclass. 304 | if "child_document_resources" in self.__class__.__dict__: 305 | return self.child_document_resources 306 | else: 307 | return {} 308 | 309 | def get_default_child_resource_document(self): 310 | # See comment on get_child_document_resources. 
311 | if "default_child_resource_document" in self.__class__.__dict__: 312 | return self.default_child_resource_document 313 | else: 314 | return None 315 | 316 | def get_filters(self): 317 | """ 318 | Given the filters declared on this resource, return a mapping 319 | of all allowed filters along with their individual mappings of 320 | suffixes and operators. 321 | 322 | For example, if self.filters declares: 323 | { 'date': [operators.Exact, operators.Gte] } 324 | then this method will return: 325 | { 326 | 'date': { 327 | '': operators.Exact, 328 | 'exact': operators.Exact, 329 | 'gte': operators.Gte 330 | } 331 | } 332 | Then, when a request comes in, Flask-MongoRest will match 333 | `?date__gte=value` to the 'date' field and the 'gte' suffix: 'gte', 334 | and hence use the Gte operator to filter the data. 335 | """ 336 | filters = {} 337 | for field, operators in getattr(self, "filters", {}).items(): 338 | field_filters = {} 339 | for op in operators: 340 | if op.op == "exact": 341 | field_filters[""] = op 342 | field_filters[op.op] = op 343 | filters[field] = field_filters 344 | return filters 345 | 346 | def serialize_field(self, obj, **kwargs): 347 | if self.uri_prefix and hasattr(obj, "id"): 348 | return self._url(str(obj.id)) 349 | else: 350 | return self.serialize(obj, **kwargs) 351 | 352 | def _subresource(self, obj): 353 | """ 354 | Select and create an appropriate sub-resource class for delegation or 355 | return None if there isn't one. 
356 | """ 357 | s_class = self._child_document_resources.get(obj.__class__) 358 | if not s_class and self._default_child_resource_document: 359 | s_class = self._child_document_resources[ 360 | self._default_child_resource_document 361 | ] 362 | if s_class and s_class != self.__class__: 363 | r = s_class(view_method=self.view_method) 364 | r.data = self.data 365 | return r 366 | else: 367 | return None 368 | 369 | def get_field_value(self, obj, field_name, field_instance=None, **kwargs): 370 | """Return a json-serializable field value. 371 | 372 | field_name is the name of the field in `obj` to be serialized. 373 | field_instance is a MongoEngine field definition. 374 | **kwargs are just any options to be passed through to child resources serializers. 375 | """ 376 | has_field_instance = bool(field_instance) 377 | field_instance = ( 378 | field_instance 379 | or self.document._fields.get(field_name, None) 380 | or getattr(self.document, field_name, None) 381 | ) 382 | 383 | # Determine the field value 384 | if has_field_instance: 385 | field_value = obj 386 | elif isinstance(obj, dict): 387 | return obj[field_name] 388 | else: 389 | try: 390 | field_value = getattr(obj, field_name) 391 | except AttributeError: 392 | raise UnknownFieldError 393 | 394 | return self.serialize_field_value( 395 | obj, field_name, field_instance, field_value, **kwargs 396 | ) 397 | 398 | def serialize_field_value( 399 | self, obj, field_name, field_instance, field_value, **kwargs 400 | ): 401 | """Select and delegate to an appropriate serializer method based on type of field instance. 402 | 403 | field_value is an actual value to be serialized. 404 | For other fields, see get_field_value method. 
405 | """ 406 | if isinstance( 407 | field_instance, 408 | (ReferenceField, GenericReferenceField, EmbeddedDocumentField), 409 | ): 410 | return self.serialize_document_field(field_name, field_value, **kwargs) 411 | 412 | elif isinstance(field_instance, ListField): 413 | return self.serialize_list_field( 414 | field_instance, field_name, field_value, **kwargs 415 | ) 416 | 417 | elif isinstance(field_instance, DictField): 418 | return self.serialize_dict_field( 419 | field_instance, field_name, field_value, **kwargs 420 | ) 421 | 422 | elif callable(field_instance): 423 | return self.serialize_callable_field( 424 | obj, field_instance, field_name, field_value, **kwargs 425 | ) 426 | return field_value 427 | 428 | def serialize_callable_field( 429 | self, obj, field_instance, field_name, field_value, **kwargs 430 | ): 431 | """Execute a callable field and return it or serialize 432 | it based on its related resource defined in the `related_resources` map. 433 | """ 434 | if isinstance(field_value, list): 435 | value = field_value 436 | else: 437 | if isbound(field_instance): 438 | value = field_instance() 439 | elif isbound(field_value): 440 | value = field_value() 441 | else: 442 | value = field_instance(obj) 443 | if field_name in self._related_resources: 444 | if isinstance(value, list): 445 | return [ 446 | self._related_resources[field_name]().serialize_field(o, **kwargs) 447 | for o in value 448 | ] 449 | elif value is None: 450 | return None 451 | else: 452 | return self._related_resources[field_name]().serialize_field( 453 | value, **kwargs 454 | ) 455 | return value 456 | 457 | def serialize_dict_field(self, field_instance, field_name, field_value, **kwargs): 458 | """Serialize each value based on an explicit field type 459 | (e.g. if the schema defines a DictField(IntField), where all 460 | the values in the dict should be ints). 
461 | """ 462 | if field_instance.field: 463 | return { 464 | key: self.get_field_value( 465 | elem, field_name, field_instance=field_instance.field, **kwargs 466 | ) 467 | for (key, elem) in field_value.items() 468 | } 469 | # ... or simply return the dict intact, if the field type 470 | # wasn't specified 471 | else: 472 | return field_value 473 | 474 | def serialize_list_field(self, field_instance, field_name, field_value, **kwargs): 475 | """Serialize each item in the list separately.""" 476 | return [ 477 | val 478 | for val in [ 479 | self.get_field_value( 480 | elem, field_name, field_instance=field_instance.field, **kwargs 481 | ) 482 | for elem in field_value 483 | ] 484 | if val 485 | ] 486 | 487 | def serialize_document_field(self, field_name, field_value, **kwargs): 488 | """If this field is a reference or an embedded document, either return 489 | a DBRef or serialize it using a resource found in `related_resources`. 490 | """ 491 | if field_name in self._related_resources: 492 | return ( 493 | field_value 494 | and not isinstance(field_value, DBRef) 495 | and self._related_resources[field_name]().serialize_field( 496 | field_value, **kwargs 497 | ) 498 | ) 499 | else: 500 | if DocumentProxy and isinstance(field_value, DocumentProxy): 501 | # Don't perform a DBRef isinstance check below since 502 | # it might trigger an extra query. 503 | return field_value.to_dbref() 504 | if isinstance(field_value, DBRef): 505 | return field_value 506 | return field_value and field_value.to_dbref() 507 | 508 | def serialize(self, obj, **kwargs): 509 | """ 510 | Given an object, serialize it, turning it into its JSON 511 | representation. 
512 | """ 513 | if not obj: 514 | return {} 515 | 516 | # If a subclass of an obj has been called with a base class' resource, 517 | # use the subclass-specific serialization 518 | subresource = self._subresource(obj) 519 | if subresource: 520 | return subresource.serialize(obj, **kwargs) 521 | 522 | # Get the requested fields 523 | requested_fields = self.get_requested_fields(**kwargs) 524 | 525 | # Drop the kwargs we don't need any more (we're passing `kwargs` to 526 | # child resources so we don't want to pass `fields` and `params` that 527 | # pertain to the parent resource). 528 | kwargs.pop("fields", None) 529 | kwargs.pop("params", None) 530 | 531 | # Fill in the `data` dict by serializing each of the requested fields 532 | # one by one. 533 | data = {} 534 | for field in requested_fields: 535 | 536 | # resolve the user-facing name of the field 537 | renamed_field = self._rename_fields.get(field, field) 538 | 539 | # if the field is callable, execute it with `obj` as the param 540 | if hasattr(self, field) and callable(getattr(self, field)): 541 | value = getattr(self, field)(obj) 542 | 543 | # if the field is associated with a specific resource (via the 544 | # `related_resources` map), use that resource to serialize it 545 | if field in self._related_resources and value is not None: 546 | related_resource = self._related_resources[field]() 547 | if isinstance(value, mongoengine.document.Document): 548 | value = related_resource.serialize_field(value) 549 | elif isinstance(value, dict): 550 | value = { 551 | k: related_resource.serialize_field(v) 552 | for (k, v) in value.items() 553 | } 554 | else: # assume queryset or list 555 | value = [related_resource.serialize_field(o) for o in value] 556 | data[renamed_field] = value 557 | else: 558 | try: 559 | data[renamed_field] = self.get_field_value(obj, field, **kwargs) 560 | except UnknownFieldError: 561 | with contextlib.suppress(UnknownFieldError): 562 | data[renamed_field] = self.value_for_field(obj, field) 
563 | 564 | return data 565 | 566 | def handle_serialization_error(self, exc, obj): 567 | """ 568 | Override this to implement custom behavior whenever serializing an 569 | object fails. 570 | """ 571 | pass 572 | 573 | def value_for_field(self, obj, field): 574 | """ 575 | If we specify a field which doesn't exist on the resource or on the 576 | object, this method lets us return a custom value. 577 | """ 578 | raise UnknownFieldError 579 | 580 | def validate_request(self, obj=None): 581 | """ 582 | Validate the request that's currently being processed and fill in 583 | the self.data dict that'll later be used to save/update an object. 584 | 585 | `obj` points to the object that's being updated, or is empty if a new 586 | object is being created. 587 | """ 588 | # When creating or updating a single object, delegate the validation 589 | # to a more specific subresource, if it exists 590 | if (request.method == "PUT" and obj) or request.method == "POST": 591 | subresource = self._subresource(obj) 592 | if subresource: 593 | subresource._raw_data = self._raw_data 594 | subresource.validate_request(obj=obj) 595 | self.data = subresource.data 596 | return 597 | 598 | # Don't work on original raw data, we may reuse the resource for bulk 599 | # updates. 600 | self.data = self.raw_data.copy() 601 | 602 | # Do renaming in two passes to prevent potential multiple renames 603 | # depending on dict traversal order. 604 | # E.g. if a -> b, b -> c, then a should never be renamed to c. 
605 | fields_to_delete = [] 606 | fields_to_update = {} 607 | for k, v in self._rename_fields.items(): 608 | if v in self.data: 609 | fields_to_update[k] = self.data[v] 610 | fields_to_delete.append(v) 611 | for k in fields_to_delete: 612 | del self.data[k] 613 | for k, v in fields_to_update.items(): 614 | self.data[k] = v 615 | 616 | # If CleanCat schema exists on this resource, use it to perform the 617 | # validation 618 | if self.schema: 619 | if request.method == "PUT" and obj is not None: 620 | obj_data = {key: getattr(obj, key) for key in obj._fields.keys()} 621 | else: 622 | obj_data = None 623 | 624 | schema = self.schema(self.data, obj_data) 625 | try: 626 | self.data = schema.full_clean() 627 | except SchemaValidationError: 628 | raise ValidationError( 629 | {"field-errors": schema.field_errors, "errors": schema.errors} 630 | ) 631 | 632 | def get_queryset(self): 633 | """ 634 | Return a MongoEngine queryset that will later be used to return 635 | matching documents. 636 | """ 637 | return self.document.objects 638 | 639 | def get_object(self, pk, qfilter=None): 640 | """ 641 | Given a PK and an optional queryset filter function, find a matching 642 | document in the queryset. 643 | """ 644 | qs = self.get_queryset() 645 | # If a queryset filter was provided, pass our current queryset in and 646 | # get a new one out 647 | if qfilter: 648 | qs = qfilter(qs) 649 | obj = qs.get(pk=pk) 650 | 651 | # We don't need to fetch related resources for DELETE requests because 652 | # those requests do not serialize the object (a successful DELETE 653 | # simply returns a `{}`, at least by default). We still want to fetch 654 | # related resources for GET and PUT. 
655 | if request.method != "DELETE": 656 | self.fetch_related_resources( 657 | [obj], self.get_requested_fields(params=self.params) 658 | ) 659 | 660 | return obj 661 | 662 | def fetch_related_resources(self, objs, only_fields=None): 663 | """ 664 | Given a list of objects and an optional list of the only fields we 665 | should care about, fetch these objects' related resources. 666 | """ 667 | if not self.related_resources_hints: 668 | return 669 | 670 | # Create a map of field names to MongoEngine Q objects that will 671 | # later be used to fetch the related resources from MongoDB 672 | # Queries for the same document/collection are combined to improve 673 | # efficiency. 674 | document_queryset = {} 675 | for obj in objs: 676 | for field_name in self.related_resources_hints.keys(): 677 | if only_fields is not None and field_name not in only_fields: 678 | continue 679 | method = getattr(obj, field_name) 680 | if callable(method): 681 | q = method() 682 | if field_name in document_queryset: 683 | document_queryset[field_name] = ( 684 | document_queryset[field_name] | q._query_obj 685 | ) 686 | else: 687 | document_queryset[field_name] = q._query_obj 688 | 689 | # For each field name, execute the queries we generated in the block 690 | # above, and map the results to each object that references them. 691 | # TODO This is in dire need of refactoring, or a complete overhaul 692 | hints = {} 693 | for field_name, q_obj in document_queryset.items(): 694 | doc = self.get_related_resources()[field_name].document 695 | 696 | # Create a QuerySet based on the query object 697 | query = doc.objects.filter(q_obj) 698 | 699 | # Don't let MongoDB do the sorting as it won't use the index. 700 | # Store the ordering so we can do client sorting afterwards. 
701 | ordering = query._ordering or query._get_order_by( 702 | query._document._meta["ordering"] 703 | ) 704 | query = query.order_by() 705 | 706 | # Fetch the results 707 | results = list(query) 708 | 709 | # Reapply the ordering and add results to the mapping 710 | if ordering: 711 | document_queryset[field_name] = sorted(results, cmp_fields(ordering)) 712 | else: 713 | document_queryset[field_name] = results 714 | 715 | # For each field name, create a map of obj PKs to a list of 716 | # results they referenced. 717 | hint_index = {} 718 | if field_name in self.related_resources_hints: 719 | hint_field = self.related_resources_hints[field_name] 720 | for obj in document_queryset[field_name]: 721 | hint_field_instance = obj._fields[hint_field] 722 | # Don't trigger a query for SafeReferenceFields 723 | if SafeReferenceField and isinstance( 724 | hint_field_instance, SafeReferenceField 725 | ): 726 | hinted = obj._db_data[hint_field] 727 | if hint_field_instance.dbref: 728 | hinted = hinted.id 729 | else: 730 | hinted = str(getattr(obj, hint_field).id) 731 | if hinted not in hint_index: 732 | hint_index[hinted] = [obj] 733 | else: 734 | hint_index[hinted].append(obj) 735 | 736 | hints[field_name] = hint_index 737 | 738 | # Assign the results to each object 739 | # TODO This is in dire need of refactoring, or a complete overhaul 740 | for obj in objs: 741 | for field_name, hint_index in hints.items(): 742 | obj_id = obj.id 743 | if isinstance(obj_id, DBRef): 744 | obj_id = obj_id.id 745 | elif isinstance(obj_id, ObjectId): 746 | obj_id = str(obj_id) 747 | if obj_id not in hint_index: 748 | setattr(obj, field_name, []) 749 | else: 750 | setattr(obj, field_name, hint_index[obj_id]) 751 | 752 | def apply_filters(self, qs, params=None): 753 | """ 754 | Given this resource's filters, and the params of the request that's 755 | currently being processed, apply additional filtering to the queryset 756 | and return it. 
757 | """ 758 | if params is None: 759 | params = self.params 760 | 761 | for key, value in params.items(): 762 | # If this is a resource identified by a URI, we need 763 | # to extract the object id at this point since 764 | # MongoEngine only understands the object id 765 | if self.uri_prefix: 766 | url = urlparse(value) 767 | uri = url.path 768 | value = uri.lstrip(self.uri_prefix) 769 | 770 | # special handling of empty / null params 771 | # http://werkzeug.pocoo.org/docs/0.9/utils/ url_decode returns '' for empty params 772 | if value == "": 773 | value = None 774 | elif value in ['""', "''"]: 775 | value = "" 776 | 777 | negate = False 778 | op_name = "" 779 | parts = key.split("__") 780 | for i in range(len(parts) + 1, 0, -1): 781 | field = "__".join(parts[:i]) 782 | allowed_operators = self._filters.get(field) 783 | if allowed_operators: 784 | parts = parts[i:] 785 | break 786 | if allowed_operators is None: 787 | continue 788 | 789 | if parts: 790 | # either an operator or a query lookup! See what's allowed. 791 | op_name = parts[-1] 792 | if op_name in allowed_operators: 793 | # operator; drop it 794 | parts.pop() 795 | else: 796 | # assume it's part of a lookup 797 | op_name = "" 798 | if parts and parts[-1] == "not": 799 | negate = True 800 | parts.pop() 801 | 802 | operator = allowed_operators.get(op_name, None) 803 | if operator is None: 804 | continue 805 | if negate and not operator.allow_negation: 806 | continue 807 | if parts: 808 | field = f"{field}__{'__'.join(parts)}" 809 | field = self._reverse_rename_fields.get(field, field) 810 | qs = operator().apply(qs, field, value, negate) 811 | return qs 812 | 813 | def apply_ordering(self, qs, params=None): 814 | """ 815 | Given this resource's allowed_ordering, and the params of the request 816 | that's currently being processed, apply ordering to the queryset 817 | and return it. 
818 | """ 819 | if params is None: 820 | params = self.params 821 | if self.allowed_ordering and params.get("_order_by") in self.allowed_ordering: 822 | order_params = [ 823 | self._reverse_rename_fields.get(p, p) 824 | for p in params["_order_by"].split(",") 825 | ] 826 | qs = qs.order_by(*order_params) 827 | return qs 828 | 829 | def get_skip_and_limit(self, params=None): 830 | """ 831 | Perform validation and return sanitized values for _skip and _limit 832 | params of the request that's currently being processed. 833 | """ 834 | max_limit = self.get_max_limit() 835 | if params is None: 836 | params = self.params 837 | if self.paginate: 838 | # _limit and _skip validation 839 | if not isint(params.get("_limit", 1)): 840 | raise ValidationError( 841 | { 842 | "error": '_limit must be an integer (got "{}" instead).'.format( 843 | params["_limit"] 844 | ) 845 | } 846 | ) 847 | if not isint(params.get("_skip", 1)): 848 | raise ValidationError( 849 | { 850 | "error": '_skip must be an integer (got "{}" instead).'.format( 851 | params["_skip"] 852 | ) 853 | } 854 | ) 855 | if params.get("_limit") and int(params["_limit"]) > max_limit: 856 | raise ValidationError( 857 | { 858 | "error": f"The limit you set is larger than the maximum limit for this resource (max_limit = {max_limit})." 859 | } 860 | ) 861 | if params.get("_skip") and int(params["_skip"]) < 0: 862 | raise ValidationError( 863 | { 864 | "error": '_skip must be a non-negative integer (got "{}" instead).'.format( 865 | params["_skip"] 866 | ) 867 | } 868 | ) 869 | 870 | limit = min(int(params.get("_limit", self.default_limit)), max_limit) 871 | # Fetch one more so we know if there are more results. 872 | return int(params.get("_skip", 0)), limit 873 | else: 874 | return 0, max_limit 875 | 876 | def get_objects(self, qs=None, qfilter=None): 877 | """ 878 | Return objects fetched from the database based on all the parameters 879 | of the request that's currently being processed. 
880 | 881 | Params: 882 | - Custom queryset can be passed via `qs`. Otherwise `self.get_queryset` 883 | is used. 884 | - Pass `qfilter` function to modify the queryset. 885 | """ 886 | params = self.params 887 | 888 | custom_qs = True 889 | if qs is None: 890 | custom_qs = False 891 | qs = self.get_queryset() 892 | 893 | # If a queryset filter was provided, pass our current queryset in and 894 | # get a new one out 895 | if qfilter: 896 | qs = qfilter(qs) 897 | 898 | # Apply filters and ordering, based on the params supplied by the 899 | # request 900 | qs = self.apply_filters(qs, params) 901 | qs = self.apply_ordering(qs, params) 902 | 903 | # Apply limit and skip to the queryset 904 | limit = None 905 | if self.view_method == methods.BulkUpdate: 906 | # limit the number of objects that can be bulk-updated at a time 907 | qs = qs.limit(self.bulk_update_limit) 908 | elif not custom_qs: 909 | # no need to skip/limit if a custom `qs` was provided 910 | skip, limit = self.get_skip_and_limit(params) 911 | qs = qs.skip(skip).limit(limit + 1) 912 | 913 | # Needs to be at the end as it returns a list, not a queryset 914 | if self.select_related: 915 | qs = qs.select_related() 916 | 917 | # Evaluate the queryset 918 | objs = list(qs) 919 | 920 | # Raise a validation error if bulk update would result in more than 921 | # bulk_update_limit updates 922 | if ( 923 | self.view_method == methods.BulkUpdate 924 | and len(objs) >= self.bulk_update_limit 925 | ): 926 | raise ValidationError( 927 | { 928 | "errors": [ 929 | f"It's not allowed to update more than {self.bulk_update_limit} objects at once" 930 | ] 931 | } 932 | ) 933 | 934 | # Determine the value of has_more 935 | if self.view_method != methods.BulkUpdate and self.paginate: 936 | has_more = len(objs) > limit 937 | if has_more: 938 | objs = objs[:-1] 939 | else: 940 | has_more = None 941 | 942 | # bulk-fetch related resources for moar speed 943 | self.fetch_related_resources(objs, 
self.get_requested_fields(params=params)) 944 | 945 | return objs, has_more 946 | 947 | def save_related_objects(self, obj, parent_resources=None): 948 | if not parent_resources: 949 | parent_resources = [self] 950 | else: 951 | parent_resources += [self] 952 | 953 | if self._dirty_fields: 954 | for field_name in set(self._dirty_fields) & set( 955 | self.get_save_related_fields() 956 | ): 957 | try: 958 | related_resource = self.get_related_resources()[field_name] 959 | except KeyError: 960 | related_resource = None 961 | 962 | field_instance = getattr(self.document, field_name) 963 | 964 | # If it's a ReferenceField, just save it. 965 | if isinstance(field_instance, ReferenceField): 966 | instance = getattr(obj, field_name) 967 | if instance: 968 | if related_resource: 969 | related_resource().save_object( 970 | instance, parent_resources=parent_resources 971 | ) 972 | else: 973 | instance.save() 974 | 975 | # If it's a ListField(ReferenceField), save all instances. 976 | if isinstance(field_instance, ListField) and isinstance( 977 | field_instance.field, ReferenceField 978 | ): 979 | instance_list = getattr(obj, field_name) 980 | for instance in instance_list: 981 | if related_resource: 982 | related_resource().save_object( 983 | instance, parent_resources=parent_resources 984 | ) 985 | else: 986 | instance.save() 987 | 988 | def save_object(self, obj, **kwargs): 989 | self.save_related_objects(obj, **kwargs) 990 | obj.save() 991 | obj.reload() 992 | 993 | self._dirty_fields = None # No longer dirty. 994 | 995 | def get_object_dict(self, data=None, update=False): 996 | if data is None: 997 | data = {} 998 | data = self.data or data 999 | filter_fields = set(self.document._fields.keys()) 1000 | if update: 1001 | # We want to update only the fields that appear in the request data 1002 | # rather than re-updating all the document's existing/other fields. 
1003 | filter_fields &= { 1004 | self._reverse_rename_fields.get(field, field) 1005 | for field in self.raw_data.keys() 1006 | } 1007 | update_dict = { 1008 | field: value for field, value in data.items() if field in filter_fields 1009 | } 1010 | return update_dict 1011 | 1012 | def create_object(self, data=None, save=True, parent_resources=None): 1013 | update_dict = self.get_object_dict(data) 1014 | obj = self.document(**update_dict) 1015 | self._dirty_fields = update_dict.keys() 1016 | if save: 1017 | self.save_object(obj) 1018 | return obj 1019 | 1020 | def update_object(self, obj, data=None, save=True, parent_resources=None): 1021 | subresource = self._subresource(obj) 1022 | if subresource: 1023 | return subresource.update_object( 1024 | obj, data=data, save=save, parent_resources=parent_resources 1025 | ) 1026 | 1027 | update_dict = self.get_object_dict(data, update=True) 1028 | 1029 | self._dirty_fields = [] 1030 | 1031 | for field, value in update_dict.items(): 1032 | update = False 1033 | 1034 | # If we're comparing reference fields, only compare ids without 1035 | # hitting the database 1036 | if hasattr(obj, "_db_data") and isinstance( 1037 | obj._fields.get(field), ReferenceField 1038 | ): 1039 | db_val = obj._db_data.get(field) 1040 | id_from_obj = db_val and getattr(db_val, "id", db_val) 1041 | id_from_data = value and getattr(value, "pk", value) 1042 | if id_from_obj != id_from_data: 1043 | update = True 1044 | elif not equal(getattr(obj, field), value): 1045 | update = True 1046 | 1047 | if update: 1048 | setattr(obj, field, value) 1049 | self._dirty_fields.append(field) 1050 | 1051 | if save: 1052 | self.save_object(obj) 1053 | return obj 1054 | 1055 | def delete_object(self, obj, parent_resources=None): 1056 | obj.delete() 1057 | -------------------------------------------------------------------------------- /flask_mongorest/templates/mongorest/debug.html: -------------------------------------------------------------------------------- 1 | 2 | 3 
| 4 | {{ config.APPLICATION_NAME }} -- API 5 | 6 | 7 |
{{ data }}
8 | 9 | 10 | -------------------------------------------------------------------------------- /flask_mongorest/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import decimal 3 | import json 4 | 5 | import mongoengine 6 | from bson.dbref import DBRef 7 | from bson.objectid import ObjectId 8 | 9 | isbound = lambda m: getattr(m, "im_self", None) is not None 10 | 11 | 12 | def isint(int_str): 13 | try: 14 | int(int_str) 15 | return True 16 | except (TypeError, ValueError): 17 | return False 18 | 19 | 20 | class MongoEncoder(json.JSONEncoder): 21 | def default(self, value, **kwargs): 22 | if isinstance(value, ObjectId): 23 | return str(value) 24 | if isinstance(value, DBRef): 25 | return value.id 26 | if isinstance(value, datetime.datetime): 27 | return value.isoformat() 28 | if isinstance(value, datetime.date): 29 | return value.strftime("%Y-%m-%d") 30 | if isinstance(value, decimal.Decimal): 31 | return str(value) 32 | return super(MongoEncoder, self).default(value, **kwargs) 33 | 34 | 35 | def cmp(a, b): 36 | return (a > b) - (a < b) 37 | 38 | 39 | def cmp_fields(ordering): 40 | # Takes a list of fields and directions and returns a 41 | # comparison function for sorted() to perform client-side 42 | # sorting. 43 | # Example: sorted(objs, cmp_fields([('date_created', -1)])) 44 | def _cmp(x, y): 45 | for field, direction in ordering: 46 | result = cmp(getattr(x, field), getattr(y, field)) * direction 47 | if result: 48 | return result 49 | return 0 50 | 51 | return _cmp 52 | 53 | 54 | def equal(a, b): 55 | """ 56 | Compare two objects. In addition to the "==" operator, this function 57 | ensures that the data of two mongoengine objects is the same. Also, it 58 | assumes that a UTC-TZ-aware datetime is equal to an unaware datetime if 59 | the date and time components match. 
60 | """ 61 | # When comparing dicts (we serialize documents using to_dict) or lists 62 | # we may encounter datetime instances in the values, so compare them item 63 | # by item. 64 | if isinstance(a, dict) and isinstance(b, dict): 65 | if sorted(a.keys()) != sorted(b.keys()): 66 | return False 67 | return all(equal(b[k], v) for k, v in a.items()) 68 | 69 | if isinstance(a, list) and isinstance(b, list): 70 | if len(a) != len(b): 71 | return False 72 | return all([equal(m, n) for (m, n) in zip(a, b)]) 73 | 74 | # Two mongoengine objects are equal if their ID is equal. However, 75 | # in this case we want to check if the data is equal. Note this 76 | # doesn't look into mongoengine documents which are nested within 77 | # mongoengine documents. 78 | if isinstance(a, mongoengine.Document) and isinstance(b, mongoengine.Document): 79 | # Don't evaluate lazy documents 80 | if getattr(a, "_lazy", False) and getattr(b, "_lazy", False): 81 | return True 82 | return equal(dict(a.to_mongo()), dict(b.to_mongo())) 83 | 84 | # Since comparing an aware and unaware datetime results in an 85 | # exception and we may assign unaware datetimes to objects that 86 | # previously had an aware datetime, we convert aware datetimes 87 | # to their unaware equivalent before comparing. 88 | if isinstance(a, datetime.datetime) and isinstance(b, datetime.datetime): 89 | # This doesn't cover all the cases, but it covers the most 90 | # important case where the utcoffset is 0. 91 | if a.utcoffset() is not None and a.utcoffset() == datetime.timedelta(0): 92 | a = a.replace(tzinfo=None) 93 | if b.utcoffset() is not None and b.utcoffset() == datetime.timedelta(0): 94 | b = b.replace(tzinfo=None) 95 | 96 | try: 97 | return a == b 98 | except Exception: # Exception during comparison, mainly datetimes. 
99 | return False 100 | -------------------------------------------------------------------------------- /flask_mongorest/views.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import List, Type 3 | 4 | import mimerender 5 | import mongoengine 6 | from flask import render_template, request 7 | from flask.views import MethodView 8 | from werkzeug.exceptions import NotFound, Unauthorized 9 | 10 | from flask_mongorest import methods 11 | from flask_mongorest.authentication import AuthenticationBase 12 | from flask_mongorest.exceptions import ValidationError 13 | from flask_mongorest.methods import METHODS_TYPE 14 | from flask_mongorest.utils import MongoEncoder 15 | 16 | mimerender = mimerender.FlaskMimeRender() 17 | 18 | render_json = lambda **payload: json.dumps(payload, allow_nan=False, cls=MongoEncoder) 19 | render_html = lambda **payload: render_template( 20 | "mongorest/debug.html", 21 | data=json.dumps(payload, cls=MongoEncoder, sort_keys=True, indent=4), 22 | ) 23 | 24 | 25 | def get_exception_message(e): 26 | """ME ValidationError has compatibility code with py2.6 27 | that doesn't follow py3 .args interface. This works around that. 28 | """ 29 | from mongoengine.errors import ValidationError as MEValidationError 30 | 31 | if isinstance(e, MEValidationError) and not e.args: 32 | return e.message 33 | else: 34 | return e.args[0] 35 | 36 | 37 | def serialize_mongoengine_validation_error(e): 38 | """ 39 | Take a MongoEngine ValidationError as an argument, and returns a 40 | serializable error dict. Note that we can have nested ValidationErrors. 
41 | """ 42 | 43 | def serialize_errors(e): 44 | if isinstance(e, Exception): 45 | return get_exception_message(e) 46 | elif hasattr(e, "items"): 47 | return {k: serialize_errors(v) for (k, v) in e.items()} 48 | else: 49 | return str(e) 50 | 51 | if e.errors: 52 | return {"field-errors": serialize_errors(e.errors)} 53 | else: 54 | return {"error": get_exception_message(e)} 55 | 56 | 57 | class ResourceView(MethodView): 58 | resource = None 59 | methods: List[METHODS_TYPE] = [] # type: ignore 60 | authentication_methods: List[Type[AuthenticationBase]] = [] 61 | 62 | def __init__(self): 63 | assert self.resource and self.methods 64 | 65 | @mimerender(default="json", json=render_json, html=render_html) 66 | def dispatch_request(self, *args, **kwargs): 67 | # keep all the logic in a helper method (_dispatch_request) so that 68 | # it's easy for subclasses to override this method (when they don't want to use 69 | # this mimerender decorator) without them also having to copy/paste all the 70 | # authentication logic, etc. 
71 | return self._dispatch_request(*args, **kwargs) 72 | 73 | def _dispatch_request(self, *args, **kwargs): 74 | authorized = bool(len(self.authentication_methods) == 0) 75 | for authentication_method in self.authentication_methods: 76 | if authentication_method().authorized(): 77 | authorized = True 78 | if not authorized: 79 | return {"error": "Unauthorized"}, "401 Unauthorized" 80 | 81 | try: 82 | self._resource = self.requested_resource(request) 83 | return super(ResourceView, self).dispatch_request(*args, **kwargs) 84 | except mongoengine.queryset.DoesNotExist as e: 85 | return {"error": "Empty query: " + str(e)}, "404 Not Found" 86 | except ValidationError as e: 87 | return e.args[0], "400 Bad Request" 88 | except Unauthorized: 89 | return {"error": "Unauthorized"}, "401 Unauthorized" 90 | except NotFound as e: 91 | return {"error": str(e)}, "404 Not Found" 92 | 93 | def handle_validation_error(self, e): 94 | if isinstance(e, ValidationError): 95 | raise 96 | elif isinstance(e, mongoengine.ValidationError): 97 | raise ValidationError(serialize_mongoengine_validation_error(e)) 98 | else: 99 | raise 100 | 101 | def requested_resource(self, request): 102 | """In the case where the Resource that this view is associated with points to a Document class 103 | that allows inheritance, this method should indicate the specific Resource class to use 104 | when processing POST and PUT requests through information available in the request 105 | itself or through other means. 
106 | """ 107 | # Default behavior is to use the (base) resource class 108 | return self.resource() 109 | 110 | def get(self, **kwargs): 111 | pk = kwargs.pop("pk", None) 112 | 113 | # Set the view_method on a resource instance 114 | if pk: 115 | self._resource.view_method = methods.Fetch 116 | else: 117 | self._resource.view_method = methods.List 118 | 119 | # Create a queryset filter to control read access to the 120 | # underlying objects 121 | qfilter = lambda qs: self.has_read_permission(request, qs.clone()) 122 | if pk is None: 123 | result = self._resource.get_objects(qfilter=qfilter) 124 | 125 | # Result usually contains objects and a has_more bool. However, in case where 126 | # more data is returned, we include it at the top level of the response dict 127 | if len(result) == 2: 128 | objs, has_more = result 129 | extra = {} 130 | elif len(result) == 3: 131 | objs, has_more, extra = result 132 | else: 133 | raise ValueError("Unsupported value of resource.get_objects") 134 | 135 | data = [] 136 | for obj in objs: 137 | try: 138 | data.append(self._resource.serialize(obj, params=request.args)) 139 | except Exception as e: 140 | fixed_obj = self._resource.handle_serialization_error(e, obj) 141 | if fixed_obj is not None: 142 | data.append(fixed_obj) 143 | 144 | # Serialize the objects one by one 145 | ret = {"data": data} 146 | 147 | if has_more is not None: 148 | ret["has_more"] = has_more 149 | 150 | if extra: 151 | ret.update(extra) 152 | else: 153 | obj = self._resource.get_object(pk, qfilter=qfilter) 154 | ret = self._resource.serialize(obj, params=request.args) 155 | return ret 156 | 157 | def post(self, **kwargs): 158 | if "pk" in kwargs: 159 | raise NotFound("Did you mean to use PUT?") 160 | 161 | # Set the view_method on a resource instance 162 | self._resource.view_method = methods.Create 163 | 164 | self._resource.validate_request() 165 | try: 166 | obj = self._resource.create_object() 167 | except Exception as e: 168 | 
self.handle_validation_error(e) 169 | 170 | # Check if we have permission to create this object 171 | if not self.has_add_permission(request, obj): 172 | raise Unauthorized 173 | 174 | ret = self._resource.serialize(obj, params=request.args) 175 | if isinstance(obj, mongoengine.Document) and self._resource.uri_prefix: 176 | return ret, "201 Created", {"Location": self._resource._url(str(obj.id))} 177 | else: 178 | return ret 179 | 180 | def process_object(self, obj): 181 | """Validate and update an object""" 182 | # Check if we have permission to change this object 183 | if not self.has_change_permission(request, obj): 184 | raise Unauthorized 185 | 186 | self._resource.validate_request(obj) 187 | 188 | try: 189 | obj = self._resource.update_object(obj) 190 | except Exception as e: 191 | self.handle_validation_error(e) 192 | 193 | def process_objects(self, objs): 194 | """ 195 | Update each object in the list one by one, and return the total count 196 | of updated objects. 197 | """ 198 | count = 0 199 | try: 200 | for obj in objs: 201 | self.process_object(obj) 202 | count += 1 # noqa: SIM113 203 | except ValidationError as e: 204 | e.args[0]["count"] = count 205 | raise e 206 | else: 207 | return {"count": count} 208 | 209 | def put(self, **kwargs): 210 | pk = kwargs.pop("pk", None) 211 | 212 | # Set the view_method on a resource instance 213 | if pk: 214 | self._resource.view_method = methods.Update 215 | else: 216 | self._resource.view_method = methods.BulkUpdate 217 | 218 | if pk is None: 219 | # Bulk update where the body contains the new values for certain 220 | # fields. 221 | 222 | # Currently, fetches all the objects and validates them separately. 223 | # If one of them fails, a ValidationError for this object will be 224 | # triggered. 225 | # Ideally, this would be translated into an update statement for 226 | # performance reasons and would perform the update either for all 227 | # objects, or for none, if (generic) validation fails. 
Since this 228 | # is a bulk update, only the count of objects which were updated is 229 | # returned. 230 | 231 | # Get a list of all objects matching the filters, capped at this 232 | # resource's `bulk_update_limit` 233 | result = self._resource.get_objects() 234 | if len(result) == 2: 235 | objs, has_more = result 236 | elif len(result) == 3: 237 | objs, has_more, extra = result 238 | 239 | # Update all the objects and return their count 240 | return self.process_objects(objs) 241 | else: 242 | obj = self._resource.get_object(pk) 243 | self.process_object(obj) 244 | ret = self._resource.serialize(obj, params=request.args) 245 | return ret 246 | 247 | def delete(self, **kwargs): 248 | pk = kwargs.pop("pk", None) 249 | 250 | # Set the view_method on a resource instance 251 | self._resource.view_method = methods.Delete 252 | 253 | obj = self._resource.get_object(pk) 254 | 255 | # Check if we have permission to delete this object 256 | if not self.has_delete_permission(request, obj): 257 | raise Unauthorized 258 | 259 | self._resource.delete_object(obj) 260 | return {} 261 | 262 | # This takes a QuerySet as an argument and then 263 | # returns a query set that this request can read 264 | def has_read_permission(self, request, qs): 265 | return qs 266 | 267 | def has_add_permission(self, request, obj): 268 | return True 269 | 270 | def has_change_permission(self, request, obj): 271 | return True 272 | 273 | def has_delete_permission(self, request, obj): 274 | return True 275 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | skip-magic-trailing-comma = true 3 | exclude = ''' 4 | /( 5 | \.git 6 | | \.venv 7 | | venv 8 | | src 9 | | \.eggs 10 | )/ 11 | ''' 12 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | cleancat 2 | mongoengine 3 | flask-mongoengine 4 | mimerender 5 | python-dateutil 6 | Flask>=0.9 7 | pymongo 8 | flake8 9 | flake8-bugbear 10 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [nosetests] 2 | verbosity=2 3 | detailed-errors=True 4 | cover-package=flask_mongorest 5 | cover-erase=True 6 | 7 | [flake8] 8 | ignore= 9 | # !!! make sure you have a comma at the end of each line EXCEPT the LAST one 10 | # Previously existing rules 11 | E2,E302,E5,W391,F403,F405,E501,E722,E731,F999,W504, 12 | # Future imports 13 | FI, 14 | # Missing docstrings 15 | D1, 16 | # One-line docstring should fit on one line with quotes. 17 | # We ignore this because it's OK to buy yourself a few extra characters 18 | # for the summary line even if the summary line is *the only* line. 19 | D200, 20 | # 1 blank line required between summary line and description 21 | D205, 22 | # Multi-line docstring summary should start at the first line. 23 | # We ignore this because we agreed in #20553 that we want to put the 24 | # summary line below """ for multi-line docstrings. 25 | D212, 26 | # First line should end with a period 27 | D400, 28 | # First line should end with a period, question mark, or exclamation point. 29 | # TODO We should fix this. 30 | #D415, 31 | # variable in function should be lowercase - we use CONSTANT_LIKE stuff in functions 32 | #N806, 33 | # This is not PEP8-compliant and conflicts with black 34 | W503, 35 | W504, 36 | # This is not PEP8-compliant and conflicts with black 37 | E203, 38 | # Loop control variable 'x' not used within the loop body.
39 | #B007, 40 | # Do not call assert False 41 | #B011 42 | # Too intrusive, sometimes makes code less readable 43 | SIM106 44 | # Allow f-strings 45 | SFS301, 46 | # Allow .format 47 | SFS201, 48 | # exception class names -- these would require breaking changes 49 | N818 50 | exclude=build,dist,venv,.tox,.eggs,src 51 | max-complexity=35 52 | banned-modules= 53 | typing.Text = use str 54 | require-code=True 55 | 56 | [isort] 57 | skip=build,dist,venv,.tox,.eggs,src 58 | known_tests=tests 59 | sections=FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,TESTS,LOCALFOLDER 60 | default_section=THIRDPARTY 61 | use_parentheses=true 62 | multi_line_output=3 63 | include_trailing_comma=True 64 | force_grid_wrap=0 65 | combine_as_imports=True 66 | line_length=87 67 | 68 | [mypy] 69 | python_version = 3.7 70 | ignore_missing_imports = True 71 | no_implicit_optional = True 72 | strict_equality = True 73 | follow_imports = normal 74 | warn_unreachable = True 75 | show_error_context = True 76 | pretty = True 77 | files = flask_mongorest -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | 3 | from setuptools import setup 4 | 5 | # Stops exit traceback on tests 6 | # TODO this makes flake8's F401 fail - maybe there's a better way 7 | with contextlib.suppress(Exception): 8 | import multiprocessing # noqa 9 | 10 | setup( 11 | name="Flask-MongoRest", 12 | version="0.2.3", 13 | url="http://github.com/closeio/flask-mongorest", 14 | license="BSD", 15 | author="Close.io", 16 | author_email="engineering@close.io", 17 | maintainer="Close.io", 18 | maintainer_email="engineering@close.io", 19 | description="Flask restful API framework for MongoDB/MongoEngine", 20 | long_description=__doc__, 21 | packages=["flask_mongorest"], 22 | package_data={"flask_mongorest": ["templates/mongorest/*"]}, 23 | test_suite="nose.collector", 24 | zip_safe=False, 25 | 
platforms="any", 26 | setup_requires=[ 27 | "Flask-MongoEngine", 28 | "mimerender", 29 | "nose", 30 | "python-dateutil", 31 | "cleancat", 32 | ], 33 | classifiers=[ 34 | "Development Status :: 4 - Beta", 35 | "Environment :: Web Environment", 36 | "Intended Audience :: Developers", 37 | "License :: OSI Approved :: BSD License", 38 | "Operating System :: OS Independent", 39 | "Programming Language :: Python", 40 | "Topic :: Internet :: WWW/HTTP :: Dynamic Content", 41 | "Topic :: Software Development :: Libraries :: Python Modules", 42 | ], 43 | ) 44 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import copy 4 | import datetime 5 | import json 6 | import unittest 7 | 8 | from mongoengine.context_managers import query_counter 9 | from mongoengine.errors import ValidationError 10 | 11 | import example.app as example 12 | 13 | try: 14 | from mongoengine import SafeReferenceField 15 | except ImportError: 16 | SafeReferenceField = None 17 | 18 | 19 | # HACK: 20 | # Because mongoengine doesn't allow you to customize the connection alias, and 21 | # because flask_mongoengine uses a different DEFAULT_CONNECTION_NAME from 22 | # mongoengine, we need to override this method to use flask_mongoengine's 23 | # get_db() method instead of mongoengine's get_db() method. 24 | try: 25 | from flask_mongoengine import get_db 26 | 27 | class OurQueryCounter(query_counter): 28 | def __init__(self): 29 | self.counter = 0 30 | self.db = get_db() 31 | 32 | query_counter = OurQueryCounter 33 | 34 | 35 | except ImportError: 36 | # Older versions of flask-mongoengine don't have this issue. 
37 | pass 38 | 39 | 40 | def response_success(response, code=None): 41 | if code is None: 42 | assert ( 43 | 200 <= response.status_code < 300 44 | ), f"Received {response.status_code} response: {response.data}" 45 | else: 46 | assert ( 47 | code == response.status_code 48 | ), f"Received {response.status_code} response: {response.data}" 49 | 50 | 51 | def response_error(response, code=None): 52 | if code is None: 53 | assert ( 54 | 400 <= response.status_code < 500 55 | ), f"Received {response.status_code} response: {response.data}" 56 | else: 57 | assert ( 58 | code == response.status_code 59 | ), f"Received {response.status_code} response: {response.data}" 60 | 61 | 62 | def compare_req_resp(req_obj, resp_obj): 63 | for k, v in req_obj.items(): 64 | assert ( 65 | k in resp_obj 66 | ), f"Key {k!r} not in response (keys are {resp_obj.keys()!r})" 67 | assert ( 68 | resp_obj[k] == v 69 | ), f"Value for key {k!r} should be {v!r} but is {resp_obj[k]}" 70 | 71 | 72 | def resp_json(resp): 73 | return json.loads(resp.get_data(as_text=True)) 74 | 75 | 76 | class MongoRestTestCase(unittest.TestCase): 77 | def setUp(self): 78 | self.user_1 = { 79 | "email": "1@b.com", 80 | "first_name": "alan", 81 | "last_name": "baker", 82 | "datetime": "2012-10-09T10:00:00", 83 | } 84 | 85 | self.user_2 = { 86 | "email": "2@b.com", 87 | "first_name": "olivia", 88 | "last_name": "baker", 89 | "datetime": "2012-11-09T11:00:00", 90 | } 91 | 92 | self.post_1 = { 93 | "title": "first post!", 94 | # author 95 | # editor 96 | "tags": ["tag1", "tag2", "tag3"], 97 | # user_lists 98 | "sections": [ 99 | {"text": "this is the first section of the first post.", "lang": "en"}, 100 | {"text": "this is the second section of the first post.", "lang": "de"}, 101 | {"text": "this is the third section of the first post.", "lang": "fr"}, 102 | ], 103 | "content": {"text": "this is the content for my first post.", "lang": "cn"}, 104 | "is_published": True, 105 | } 106 | 107 | self.post_2 = {"title": "Second 
post", "is_published": False} 108 | 109 | self.app = example.app.test_client() 110 | example.documents.User.drop_collection() 111 | example.documents.Post.drop_collection() 112 | example.TestDocument.drop_collection() 113 | example.A.drop_collection() 114 | example.B.drop_collection() 115 | example.C.drop_collection() 116 | example.MethodTestDoc.drop_collection() 117 | example.DictDoc.drop_collection() 118 | 119 | # create user 1 120 | resp = self.app.post("/user/", data=json.dumps(self.user_1)) 121 | response_success(resp) 122 | self.user_1_obj = resp_json(resp) 123 | compare_req_resp(self.user_1, self.user_1_obj) 124 | 125 | # create user 2 126 | resp = self.app.post("/user/", data=json.dumps(self.user_2)) 127 | response_success(resp) 128 | self.user_2_obj = resp_json(resp) 129 | compare_req_resp(self.user_2, self.user_2_obj) 130 | 131 | def tearDown(self): 132 | # delete user 1 133 | resp = self.app.delete(f"/user/{self.user_1_obj['id']}/") 134 | response_success(resp) 135 | resp = self.app.get(f"/user/{self.user_1_obj['id']}/") 136 | response_error(resp, code=404) 137 | 138 | # delete user 2 139 | resp = self.app.delete(f"/user/{self.user_2_obj['id']}/") 140 | response_success(resp) 141 | resp = self.app.get(f"/user/{self.user_2_obj['id']}/") 142 | response_error(resp, code=404) 143 | 144 | def test_update_user(self): 145 | self.user_1_obj["first_name"] = "anthony" 146 | self.user_1_obj["datetime"] = datetime.datetime.utcnow().isoformat() 147 | resp = self.app.put( 148 | f"/user/{self.user_1_obj['id']}/", data=json.dumps(self.user_1_obj) 149 | ) 150 | response_success(resp) 151 | 152 | # check for request params in response, except for date (since the format will differ) 153 | data_to_check = copy.copy(self.user_1_obj) 154 | del data_to_check["datetime"] 155 | data = resp_json(resp) 156 | compare_req_resp(data_to_check, data) 157 | 158 | # response from PUT should be completely identical as a subsequent GET 159 | # (including precision of datetimes) 160 | resp 
= self.app.get(f"/user/{self.user_1_obj['id']}/") 161 | data2 = resp_json(resp) 162 | self.assertEqual(data, data2) 163 | 164 | def test_unicode(self): 165 | """ 166 | Make sure unicode data payloads are properly decoded. 167 | """ 168 | self.user_1_obj["first_name"] = "Jörg" 169 | 170 | # Don't encode unicode characters 171 | resp = self.app.put( 172 | f"/user/{self.user_1_obj['id']}/", 173 | data=json.dumps(self.user_1_obj, ensure_ascii=False), 174 | ) 175 | response_success(resp) 176 | data = resp_json(resp) 177 | compare_req_resp(self.user_1_obj, data) 178 | 179 | # Encode unicode characters as "\uxxxx" (default) 180 | resp = self.app.put( 181 | f"/user/{self.user_1_obj['id']}/", 182 | data=json.dumps(self.user_1_obj, ensure_ascii=True), 183 | ) 184 | response_success(resp) 185 | data = resp_json(resp) 186 | compare_req_resp(self.user_1_obj, data) 187 | 188 | def test_model_validation_unicode(self): 189 | # MongoEngine validation error (no schema) 190 | resp = self.app.post("/test/", data=json.dumps({"email": "💩"})) 191 | response_error(resp) 192 | errors = resp_json(resp) 193 | self.assertTrue( 194 | errors 195 | in [ 196 | {"field-errors": {"email": "Invalid email address: 💩"}}, 197 | { 198 | # Workaround for 199 | # https://github.com/MongoEngine/mongoengine/pull/1384 200 | "field-errors": {"email": "Invalid Mail-address: 💩"} 201 | }, 202 | ] 203 | ) 204 | 205 | # Schema validation error 206 | resp = self.app.post( 207 | "/user/", 208 | data=json.dumps({"email": "test@example.com", "datetime": "invalid"}), 209 | ) 210 | response_error(resp) 211 | errors = resp_json(resp) 212 | self.assertEqual( 213 | errors, {"errors": [], "field-errors": {"datetime": "Invalid date 💩"}} 214 | ) 215 | 216 | def test_model_validation(self): 217 | resp = self.app.post( 218 | "/user/", 219 | data=json.dumps( 220 | { 221 | "email": "invalid", 222 | "first_name": "joe", 223 | "last_name": "baker", 224 | "datetime": "2012-08-13T05:25:04.362Z", 225 | "datetime_local": 
"2012-08-13T05:25:04.362-03:30", 226 | } 227 | ), 228 | ) 229 | response_error(resp) 230 | errors = resp_json(resp) 231 | self.assertTrue("field-errors" in errors) 232 | self.assertEqual(set(errors["field-errors"]), {"email"}) 233 | 234 | resp = self.app.put( 235 | f"/user/{self.user_1_obj['id']}/", 236 | data=json.dumps( 237 | {"email": "invalid", "first_name": "joe", "last_name": "baker"} 238 | ), 239 | ) 240 | response_error(resp) 241 | errors = resp_json(resp) 242 | self.assertTrue("field-errors" in errors) 243 | self.assertEqual(set(errors["field-errors"]), {"email"}) 244 | 245 | resp = self.app.put( 246 | f"/user/{self.user_1_obj['id']}/", 247 | data=json.dumps( 248 | { 249 | "emails": [ 250 | "one@example.com", 251 | "invalid", 252 | "second@example.com", 253 | "invalid2", 254 | ] 255 | } 256 | ), 257 | ) 258 | 259 | response_error(resp) 260 | errors = resp_json(resp) 261 | self.assertTrue("field-errors" in errors) 262 | self.assertEqual(set(errors["field-errors"]), {"emails"}) 263 | self.assertEqual(set(errors["field-errors"]["emails"]["errors"]), {"1", "3"}) 264 | 265 | def test_resource_fields(self): 266 | resp = self.app.post( 267 | "/testfields/", 268 | data=json.dumps( 269 | {"name": "thename", "other": "othervalue", "upper_name": "INVALID"} 270 | ), 271 | ) 272 | response_success(resp) 273 | obj = resp_json(resp) 274 | 275 | self.assertEqual(set(obj), {"id", "name", "upper_name"}) 276 | self.assertEqual(obj["name"], "thename") 277 | self.assertEqual(obj["upper_name"], "THENAME") 278 | 279 | resp = self.app.get(f"/test/{obj['id']}/") 280 | response_success(resp) 281 | obj = resp_json(resp) 282 | 283 | self.assertEqual(obj["name"], "thename") 284 | # We can edit all the fields since we don't have a schema 285 | # self.assertEqual(obj['other'], None) 286 | 287 | resp = self.app.put( 288 | f"/test/{obj['id']}/", data=json.dumps({"other": "new othervalue"}) 289 | ) 290 | response_success(resp) 291 | obj = resp_json(resp) 292 | self.assertEqual(obj["name"], 
"thename") 293 | self.assertEqual(obj["other"], "new othervalue") 294 | 295 | resp = self.app.put( 296 | f"/testfields/{obj['id']}/", 297 | data=json.dumps({"name": "namevalue2", "upper_name": "INVALID"}), 298 | ) 299 | response_success(resp) 300 | obj = resp_json(resp) 301 | self.assertEqual(obj["name"], "namevalue2") 302 | self.assertEqual(obj["upper_name"], "NAMEVALUE2") 303 | 304 | def test_restricted_auth(self): 305 | self.post_1["author_id"] = self.user_1_obj["id"] 306 | self.post_1["editor"] = self.user_2_obj["id"] 307 | self.post_1["user_lists"] = [self.user_1_obj["id"], self.user_2_obj["id"]] 308 | 309 | resp = self.app.get("/user/") 310 | objs = resp_json(resp)["data"] 311 | self.assertEqual(len(objs), 2) 312 | 313 | post = self.post_1.copy() 314 | 315 | # Try to create an already published Post 316 | post["is_published"] = True 317 | resp = self.app.post("/restricted/", data=json.dumps(post)) 318 | # Not allowed, must be added in unpublished state 319 | response_success(resp, code=401) 320 | 321 | # Try again, but with is_published set to False 322 | post["is_published"] = False 323 | resp = self.app.post("/restricted/", data=json.dumps(post)) 324 | # Should be OK 325 | response_success(resp, code=200) 326 | 327 | # Get data about the Post we just POSTed 328 | data = resp_json(resp) 329 | 330 | # Look at current number of posts through an unrestricted view 331 | resp = self.app.get("/posts/") 332 | tmp = resp_json(resp) 333 | nposts = len(tmp["data"]) 334 | # Should see 2 335 | self.assertEqual(2, nposts) 336 | 337 | # extra data returned in get_objects 338 | self.assertEqual(tmp["more"], "stuff") 339 | 340 | # Now look at posts through a restricted view 341 | resp = self.app.get("/restricted/") 342 | tmp = resp_json(resp) 343 | npubposts = len(tmp["data"]) 344 | # Should only see 1 (published) 345 | self.assertEqual(1, npubposts) 346 | 347 | # Try to change the title 348 | post["title"] = "New title" 349 | resp = 
self.app.put(f"/restricted/{str(data['id'])}/", data=json.dumps(post)) 350 | # Works because we haven't published it yet 351 | response_success(resp, code=200) 352 | 353 | # Now let's publish it 354 | post["is_published"] = True 355 | resp = self.app.put(f"/restricted/{str(data['id'])}/", data=json.dumps(post)) 356 | # This works because the object we are changing is still 357 | # in the unpublished state before we update it 358 | response_success(resp, code=200) 359 | 360 | # Now change the title again 361 | post["title"] = "Another title" 362 | resp = self.app.put(f"/restricted/{str(data['id'])}/", data=json.dumps(post)) 363 | # Can't do it, object has already been published 364 | response_success(resp, code=401) 365 | 366 | # Try to delete this post 367 | resp = self.app.delete(f"/restricted/{str(data['id'])}/", data=json.dumps(post)) 368 | # Again, won't work because it was already published 369 | response_success(resp, code=401) 370 | 371 | # OK, let's create another post 372 | post = self.post_1.copy() 373 | 374 | # Create it in the unpublished state 375 | post["is_published"] = False 376 | resp = self.app.post("/restricted/", data=json.dumps(post)) 377 | # Should work 378 | response_success(resp, code=200) 379 | 380 | data = resp_json(resp) 381 | 382 | # Now let's try and delete an unpublished post 383 | resp = self.app.delete(f"/restricted/{data['id']}/", data=json.dumps(post)) 384 | # Should work 385 | response_success(resp, code=200) 386 | 387 | def test_get(self): 388 | resp = self.app.get("/user/") 389 | objs = resp_json(resp)["data"] 390 | self.assertEqual(len(objs), 2) 391 | 392 | def test_get_primary_user(self): 393 | self.post_1["author_id"] = self.user_1_obj["id"] 394 | self.post_1["editor"] = self.user_2_obj["id"] 395 | self.post_1["user_lists"] = [self.user_1_obj["id"], self.user_2_obj["id"]] 396 | resp = self.app.post("/posts/", data=json.dumps(self.post_1)) 397 | resp = self.app.get("/posts/?_include_primary_user=1") 398 | objs = 
resp_json(resp)["data"] 399 | self.assertEqual(len(objs), 1) 400 | self.assertEqual(objs[0]["title"], "first post!") 401 | self.assertTrue(len(objs[0]["primary_user"]) > 0) 402 | 403 | def test_get_empty_primary_user(self): 404 | resp = self.app.post("/posts/", data=json.dumps(self.post_2)) 405 | resp = self.app.get("/posts/?_include_primary_user=1") 406 | objs = resp_json(resp)["data"] 407 | self.assertEqual(len(objs), 1) 408 | self.assertEqual(objs[0]["title"], "Second post") 409 | self.assertEqual(objs[0]["primary_user"], None) 410 | 411 | def test_post(self): 412 | self.post_1["author_id"] = self.user_1_obj["id"] 413 | self.post_1["editor"] = self.user_2_obj["id"] 414 | self.post_1["user_lists"] = [self.user_1_obj["id"], self.user_2_obj["id"]] 415 | resp = self.app.post("/posts/", data=json.dumps(self.post_1)) 416 | response_success(resp) 417 | compare_req_resp(self.post_1, resp_json(resp)) 418 | self.post_1_obj = resp_json(resp) 419 | resp = self.app.get(f"/posts/{self.post_1_obj['id']}/") 420 | response_success(resp) 421 | compare_req_resp(self.post_1_obj, resp_json(resp)) 422 | 423 | self.post_1_obj["author_id"] = self.user_2_obj["id"] 424 | resp = self.app.put( 425 | f"/posts/{self.post_1_obj['id']}/", data=json.dumps(self.post_1_obj) 426 | ) 427 | response_success(resp) 428 | jd = resp_json(resp) 429 | self.assertEqual(self.post_1_obj["author_id"], jd["author_id"]) 430 | 431 | response_success(resp) 432 | compare_req_resp(self.post_1_obj, resp_json(resp)) 433 | self.post_1_obj = resp_json(resp) 434 | 435 | resp = self.app.post("/posts/", data=json.dumps(self.post_2)) 436 | response_success(resp) 437 | compare_req_resp(self.post_2, resp_json(resp)) 438 | self.post_2_obj = resp_json(resp) 439 | 440 | # test filtering 441 | 442 | resp = self.app.get("/posts/?title__startswith=first") 443 | response_success(resp) 444 | data_list = resp_json(resp)["data"] 445 | compare_req_resp(self.post_1_obj, data_list[0]) 446 | 447 | resp = 
self.app.get("/posts/?title__startswith=second") 448 | response_success(resp) 449 | data_list = resp_json(resp)["data"] 450 | self.assertEqual(data_list, []) 451 | 452 | resp = self.app.get( 453 | f"/posts/?title__in={self.post_1_obj['title']},{self.post_2_obj['title']}" 454 | ) 455 | response_success(resp) 456 | posts = resp_json(resp) 457 | self.assertEqual(len(posts["data"]), 2) 458 | 459 | resp = self.app.get("/posts/?title__in=") 460 | response_success(resp) 461 | posts = resp_json(resp) 462 | self.assertEqual(len(posts["data"]), 0) 463 | 464 | resp = self.app.get(f"/user/?datetime={'2012-10-09 10:00:00'}") 465 | response_success(resp) 466 | users = resp_json(resp) 467 | self.assertEqual(len(users["data"]), 1) 468 | 469 | resp = self.app.get(f"/user/?datetime__gt={'2012-10-08 10:00:00'}") 470 | response_success(resp) 471 | users = resp_json(resp) 472 | self.assertEqual(len(users["data"]), 2) 473 | 474 | resp = self.app.get(f"/user/?datetime__gte={'2012-10-09 10:00:00'}") 475 | response_success(resp) 476 | users = resp_json(resp) 477 | self.assertEqual(len(users["data"]), 2) 478 | 479 | # test negation 480 | 481 | # exclude many 482 | resp = self.app.get( 483 | "/posts/?title__not__in={},{}".format( 484 | self.post_1_obj["title"], self.post_2_obj["title"] 485 | ) 486 | ) 487 | response_success(resp) 488 | posts = resp_json(resp) 489 | self.assertEqual(len(posts["data"]), 0) 490 | 491 | # exclude one 492 | resp = self.app.get(f"/posts/?title__not__in={self.post_1_obj['title']}") 493 | response_success(resp) 494 | posts = resp_json(resp) 495 | self.assertEqual(len(posts["data"]), 1) 496 | 497 | resp = self.app.get(f"/posts/?author_id={self.user_2_obj['id']}") 498 | response_success(resp) 499 | data_list = resp_json(resp)["data"] 500 | compare_req_resp(self.post_1_obj, data_list[0]) 501 | 502 | resp = self.app.get("/posts/?is_published=true") 503 | response_success(resp) 504 | data_list = resp_json(resp)["data"] 505 | self.assertEqual(len(data_list), 1) 506 | 
compare_req_resp(self.post_1_obj, data_list[0]) 507 | 508 | resp = self.app.get("/posts/?is_published=false") 509 | response_success(resp) 510 | data_list = resp_json(resp)["data"] 511 | self.assertEqual(len(data_list), 1) 512 | compare_req_resp(self.post_2_obj, data_list[0]) 513 | 514 | # default exact filtering 515 | resp = self.app.get("/posts/?title__exact=first post!") 516 | data_list_1 = resp_json(resp)["data"] 517 | resp = self.app.get("/posts/?title=first post!") 518 | data_list_2 = resp_json(resp)["data"] 519 | self.assertEqual(data_list_1, data_list_2) 520 | 521 | # test bulk update 522 | resp = self.app.put( 523 | "/posts/?title__startswith=first", 524 | data=json.dumps({"description": "Some description"}), 525 | ) 526 | response_success(resp) 527 | data = resp_json(resp) 528 | self.assertEqual(data["count"], 1) 529 | 530 | resp = self.app.put( 531 | "/posts/", data=json.dumps({"description": "Other description"}) 532 | ) 533 | response_success(resp) 534 | data = resp_json(resp) 535 | self.assertEqual(data["count"], 2) 536 | 537 | resp = self.app.get("/posts/") 538 | response_success(resp) 539 | data_list = resp_json(resp)["data"] 540 | self.assertEqual(data_list[0]["description"], "Other description") 541 | self.assertEqual(data_list[1]["description"], "Other description") 542 | 543 | resp = self.app.put( 544 | "/posts/", data=json.dumps({"description": "X" * 121}) # too long 545 | ) 546 | response_error(resp) 547 | data = resp_json(resp) 548 | self.assertEqual(data["count"], 0) 549 | self.assertEqual(set(data["field-errors"]), {"description"}) 550 | 551 | def test_post_auto_art_tag(self): 552 | # create a post by vangogh and an 'art' tag should be added automatically 553 | 554 | # create vangogh 555 | resp = self.app.post( 556 | "/user/", 557 | data=json.dumps( 558 | { 559 | "email": "vincent@vangogh.com", 560 | "first_name": "Vincent", 561 | "last_name": "Vangogh", 562 | } 563 | ), 564 | ) 565 | response_success(resp) 566 | author = 
resp_json(resp)["id"] 567 | 568 | # create a post 569 | resp = self.app.post("/posts/", data=json.dumps(self.post_1)) 570 | response_success(resp) 571 | post = resp_json(resp) 572 | 573 | resp = self.app.put( 574 | f"/posts/{post['id']}/", data=json.dumps({"author_id": author}) 575 | ) 576 | response_success(resp) 577 | post = resp_json(resp) 578 | post_obj = example.documents.Post.objects.get(pk=post["id"]) 579 | self.assertTrue("art" in post_obj.tags) 580 | self.assertTrue("art" in post["tags"]) 581 | 582 | @unittest.skipIf(not SafeReferenceField, "SafeReferenceField not available") 583 | def test_broken_reference(self): 584 | # create a new user 585 | resp = self.app.post( 586 | "/user/", 587 | data=json.dumps( 588 | { 589 | "email": "3@b.com", 590 | "first_name": "steve", 591 | "last_name": "wiseman", 592 | "datetime": "2012-11-09T11:00:00", 593 | } 594 | ), 595 | ) 596 | response_success(resp) 597 | user_3 = resp_json(resp) 598 | 599 | post = self.post_1.copy() 600 | post["author_id"] = self.user_1_obj["id"] 601 | post["editor"] = self.user_2_obj["id"] 602 | post["user_lists"] = [user_3["id"]] 603 | resp = self.app.post("/posts/", data=json.dumps(post)) 604 | response_success(resp) 605 | compare_req_resp(post, resp_json(resp)) 606 | 607 | post = resp_json(resp) 608 | 609 | # remove the user and see if its reference is cleaned up properly 610 | resp = self.app.delete(f"/user/{user_3['id']}/") 611 | response_success(resp) 612 | 613 | resp = self.app.get(f"/posts/{post['id']}/") 614 | response_success(resp) 615 | 616 | self.assertEqual(resp_json(resp)["user_lists"], []) 617 | 618 | post["user_lists"] = [] 619 | compare_req_resp(post, resp_json(resp)) 620 | 621 | def test_dummy_auth(self): 622 | resp = self.app.get("/auth/") 623 | response_success(resp, code=401) 624 | 625 | def test_pagination(self): 626 | # create 101 posts 627 | post = self.post_1.copy() 628 | for i in range(1, 102): 629 | post["title"] = f"Post #{i}" 630 | resp = self.app.post("/posts/", 
data=json.dumps(post)) 631 | response_success(resp) 632 | 633 | resp = self.app.get("/posts/?_limit=10") 634 | response_success(resp) 635 | data = resp_json(resp) 636 | self.assertEqual(len(data["data"]), 10) 637 | self.assertEqual(data["has_more"], True) 638 | 639 | resp = self.app.get("/posts/?_skip=100") 640 | response_success(resp) 641 | data = resp_json(resp) 642 | self.assertEqual(len(data["data"]), 1) 643 | self.assertEqual(data["has_more"], False) 644 | 645 | resp = self.app.get("/posts/?_limit=1") 646 | response_success(resp) 647 | data = resp_json(resp) 648 | self.assertEqual(len(data["data"]), 1) 649 | self.assertEqual(data["has_more"], True) 650 | 651 | resp = self.app.get("/posts/?_limit=0") 652 | response_success(resp) 653 | data = resp_json(resp) 654 | self.assertEqual(len(data["data"]), 0) 655 | self.assertEqual(data["has_more"], True) 656 | 657 | resp = self.app.get("/posts/?_skip=100&_limit=1") 658 | response_success(resp) 659 | data = resp_json(resp) 660 | self.assertEqual(len(data["data"]), 1) 661 | self.assertEqual(data["has_more"], False) 662 | 663 | # default limit 664 | resp = self.app.get("/posts/") 665 | response_success(resp) 666 | data = resp_json(resp) 667 | self.assertEqual(len(data["data"]), 100) 668 | 669 | # _limit > max_limit 670 | resp = self.app.get("/posts/?_limit=101") 671 | response_error(resp, code=400) 672 | data = resp_json(resp) 673 | self.assertEqual( 674 | data["error"], 675 | "The limit you set is larger than the maximum limit for this resource (max_limit = 100).", 676 | ) 677 | 678 | # respect custom max_limit 679 | resp = self.app.get("/posts10/?_limit=11") 680 | response_error(resp, code=400) 681 | data = resp_json(resp) 682 | self.assertEqual( 683 | data["error"], 684 | "The limit you set is larger than the maximum limit for this resource (max_limit = 10).", 685 | ) 686 | 687 | resp = self.app.get("/posts10/") 688 | response_success(resp) 689 | data = resp_json(resp) 690 | self.assertEqual(len(data["data"]), 10) 691 
| 692 | resp = self.app.get("/posts10/?_limit=5") 693 | response_success(resp) 694 | data = resp_json(resp) 695 | self.assertEqual(len(data["data"]), 5) 696 | 697 | resp = self.app.get("/posts250/?_limit=251") 698 | response_error(resp, code=400) 699 | data = resp_json(resp) 700 | self.assertEqual( 701 | data["error"], 702 | "The limit you set is larger than the maximum limit for this resource (max_limit = 250).", 703 | ) 704 | 705 | resp = self.app.get("/posts250/") 706 | response_success(resp) 707 | data = resp_json(resp) 708 | self.assertEqual(len(data["data"]), 100) 709 | 710 | resp = self.app.get("/posts250/?_limit=10") 711 | response_success(resp) 712 | data = resp_json(resp) 713 | self.assertEqual(len(data["data"]), 10) 714 | 715 | def test_garbage_args(self): 716 | resp = self.app.get("/posts/?_limit=garbage") 717 | response_error(resp, code=400) 718 | self.assertEqual( 719 | resp_json(resp)["error"], 720 | '_limit must be an integer (got "garbage" instead).', 721 | ) 722 | 723 | resp = self.app.get("/posts/?_skip=garbage") 724 | response_error(resp, code=400) 725 | self.assertEqual( 726 | resp_json(resp)["error"], 727 | '_skip must be an integer (got "garbage" instead).', 728 | ) 729 | 730 | resp = self.app.get("/posts/?_skip=-1") 731 | response_error(resp, code=400) 732 | self.assertEqual( 733 | resp_json(resp)["error"], 734 | '_skip must be a non-negative integer (got "-1" instead).', 735 | ) 736 | 737 | def test_fields(self): 738 | resp = self.app.get(f"/user/{self.user_1_obj['id']}/?_fields=email") 739 | response_success(resp) 740 | user = resp_json(resp) 741 | self.assertEqual(set(user), {"email"}) 742 | 743 | resp = self.app.get( 744 | f"/user/{self.user_1_obj['id']}/?_fields=first_name,last_name" 745 | ) 746 | response_success(resp) 747 | user = resp_json(resp) 748 | self.assertEqual(set(user), {"first_name", "last_name"}) 749 | 750 | # Make sure all fields can still be posted. 
751 | test_user_data = { 752 | "email": "u@example.com", 753 | "first_name": "first", 754 | "last_name": "first", 755 | "balance": 54, 756 | } 757 | 758 | resp = self.app.post("/user/?_fields=id", data=json.dumps(test_user_data)) 759 | response_success(resp) 760 | user = resp_json(resp) 761 | self.assertEqual(set(user), {"id"}) 762 | 763 | def test_invalid_json(self): 764 | resp = self.app.post("/user/", data='{"}') 765 | response_error(resp, code=400) 766 | resp = resp_json(resp) 767 | self.assertEqual(resp["error"], "The request contains invalid JSON.") 768 | 769 | def test_chunked_request(self): 770 | resp = self.app.post( 771 | "/a/", 772 | data=json.dumps({"txt": "test"}), 773 | headers={"Transfer-Encoding": "chunked"}, 774 | ) 775 | response_error(resp, code=400) 776 | self.assertEqual( 777 | resp_json(resp), {"error": "Chunked Transfer-Encoding is not supported."} 778 | ) 779 | 780 | # Original MongoEngine does not support assigning a string ID to a dbref 781 | # reference -- we'd have to use a schema. 
782 | @unittest.skipIf(not SafeReferenceField, "SafeReferenceField not available") 783 | def test_dbref_vs_objectid(self): 784 | resp = self.app.post("/a/", data=json.dumps({"txt": "some text 1"})) 785 | response_success(resp) 786 | a1 = resp_json(resp) 787 | 788 | resp = self.app.post("/a/", data=json.dumps({"txt": "some text 2"})) 789 | response_success(resp) 790 | a2 = resp_json(resp) 791 | 792 | resp = self.app.post("/b/", data=json.dumps({"ref": a1["id"], "txt": "text"})) 793 | response_success(resp) 794 | dbref_obj = resp_json(resp) 795 | 796 | resp = self.app.post("/c/", data=json.dumps({"ref": a1["id"], "txt": "text"})) 797 | response_success(resp) 798 | objectid_obj = resp_json(resp) 799 | 800 | # compare objects with a dbref reference and an objectid reference 801 | resp = self.app.get(f"/b/{dbref_obj['id']}/") 802 | response_success(resp) 803 | dbref_obj = resp_json(resp) 804 | 805 | resp = self.app.get(f"/c/{objectid_obj['id']}/") 806 | response_success(resp) 807 | objectid_obj = resp_json(resp) 808 | 809 | self.assertEqual(dbref_obj["ref"], objectid_obj["ref"]) 810 | self.assertEqual(dbref_obj["txt"], objectid_obj["txt"]) 811 | 812 | # make sure both dbref and objectid are updated correctly 813 | resp = self.app.put( 814 | f"/b/{dbref_obj['id']}/", data=json.dumps({"ref": a2["id"]}) 815 | ) 816 | response_success(resp) 817 | 818 | resp = self.app.put( 819 | f"/c/{objectid_obj['id']}/", data=json.dumps({"ref": a2["id"]}) 820 | ) 821 | response_success(resp) 822 | 823 | resp = self.app.get(f"/b/{dbref_obj['id']}/") 824 | response_success(resp) 825 | dbref_obj = resp_json(resp) 826 | 827 | resp = self.app.get(f"/c/{objectid_obj['id']}/") 828 | response_success(resp) 829 | objectid_obj = resp_json(resp) 830 | 831 | self.assertEqual(dbref_obj["ref"], a2["id"]) 832 | self.assertEqual(dbref_obj["ref"], objectid_obj["ref"]) 833 | self.assertEqual(dbref_obj["txt"], objectid_obj["txt"]) 834 | 835 | def test_view_methods(self): 836 | doc = 
example.ViewMethodTestDoc.objects.create(txt="doc1") 837 | 838 | resp = self.app.get(f"/test_view_method/{doc.pk}/") 839 | response_success(resp) 840 | self.assertEqual(resp_json(resp), {"method": "Fetch"}) 841 | 842 | resp = self.app.get("/test_view_method/") 843 | response_success(resp) 844 | self.assertEqual(resp_json(resp), {"method": "List"}) 845 | 846 | resp = self.app.post("/test_view_method/", data=json.dumps({"txt": "doc2"})) 847 | response_success(resp) 848 | self.assertEqual(resp_json(resp), {"method": "Create"}) 849 | 850 | resp = self.app.put( 851 | f"/test_view_method/{doc.pk}/", data=json.dumps({"txt": "doc1new"}) 852 | ) 853 | response_success(resp) 854 | self.assertEqual(resp_json(resp), {"method": "Update"}) 855 | 856 | resp = self.app.put("/test_view_method/", data=json.dumps({"txt": "doc"})) 857 | response_success(resp) 858 | self.assertEqual(resp_json(resp), {"method": "BulkUpdate"}) 859 | 860 | resp = self.app.delete( 861 | f"/test_view_method/{doc.pk}/", data=json.dumps({"txt": "doc"}) 862 | ) 863 | response_success(resp) 864 | self.assertEqual(resp_json(resp), {"method": "Delete"}) 865 | 866 | def test_methods_success(self): 867 | doc1 = example.MethodTestDoc.objects.create(txt="doc1") 868 | doc2 = example.MethodTestDoc.objects.create(txt="doc2") 869 | 870 | resp = self.app.get(f"/fetch_only/{doc1.pk}/") 871 | response_success(resp) 872 | 873 | resp = self.app.get("/list_only/") 874 | response_success(resp) 875 | 876 | resp = self.app.post("/create_only/", data=json.dumps({"txt": "created"})) 877 | response_success(resp) 878 | 879 | resp = self.app.put( 880 | f"/update_only/{doc2.pk}/", data=json.dumps({"txt": "works"}) 881 | ) 882 | response_success(resp) 883 | 884 | resp = self.app.put("/bulk_update_only/", data=json.dumps({"txt": "both work"})) 885 | response_success(resp) 886 | 887 | resp = self.app.delete(f"/delete_only/{doc1.pk}/") 888 | response_success(resp) 889 | 890 | def test_fetch_method_permissions(self): 891 | doc1 = 
example.MethodTestDoc.objects.create(txt="doc1") 892 | 893 | # fetch 894 | resp = self.app.get(f"/fetch_only/{doc1.pk}/") 895 | response_success(resp) 896 | 897 | # list 898 | resp = self.app.get("/fetch_only/") 899 | response_error(resp, code=404) 900 | 901 | # create 902 | resp = self.app.post("/fetch_only/", data=json.dumps({"txt": "doesnt work"})) 903 | response_error(resp, code=404) 904 | 905 | # put 906 | resp = self.app.put( 907 | f"/fetch_only/{doc1.pk}/", data=json.dumps({"txt": "doesnt work"}) 908 | ) 909 | response_error(resp, code=405) 910 | 911 | # bulk put 912 | resp = self.app.put("/fetch_only/", data=json.dumps({"txt": "doesnt work"})) 913 | response_error(resp, code=404) 914 | 915 | # delete 916 | resp = self.app.delete(f"/fetch_only/{doc1.pk}/") 917 | response_error(resp, code=405) 918 | 919 | def test_list_method_permissions(self): 920 | doc1 = example.MethodTestDoc.objects.create(txt="doc1") 921 | 922 | # list 923 | resp = self.app.get("/list_only/") 924 | response_success(resp) 925 | 926 | # fetch 927 | resp = self.app.get(f"/list_only/{doc1.pk}/") 928 | response_error(resp, code=405) 929 | 930 | # create 931 | resp = self.app.post("/list_only/", data=json.dumps({"txt": "doesnt work"})) 932 | response_error(resp, code=405) 933 | 934 | # put 935 | resp = self.app.put( 936 | f"/list_only/{doc1.pk}/", data=json.dumps({"txt": "doesnt work"}) 937 | ) 938 | response_error(resp, code=405) 939 | 940 | # bulk put 941 | resp = self.app.put("/list_only/", data=json.dumps({"txt": "doesnt work"})) 942 | response_error(resp, code=405) 943 | 944 | # delete 945 | resp = self.app.delete(f"/list_only/{doc1.pk}/") 946 | response_error(resp, code=405) 947 | 948 | def test_create_method_permissions(self): 949 | doc1 = example.MethodTestDoc.objects.create(txt="doc1") 950 | 951 | # create 952 | resp = self.app.post("/create_only/", data=json.dumps({"txt": "works"})) 953 | response_success(resp) 954 | 955 | # list 956 | resp = self.app.get("/create_only/") 957 | 
response_error(resp, code=405) 958 | 959 | # fetch 960 | resp = self.app.get(f"/create_only/{doc1.pk}/") 961 | response_error(resp, code=405) 962 | 963 | # put 964 | resp = self.app.put( 965 | f"/create_only/{doc1.pk}/", data=json.dumps({"txt": "doesnt work"}) 966 | ) 967 | response_error(resp, code=405) 968 | 969 | # bulk put 970 | resp = self.app.put("/create_only/", data=json.dumps({"txt": "doesnt work"})) 971 | response_error(resp, code=405) 972 | 973 | # delete 974 | resp = self.app.delete(f"/create_only/{doc1.pk}/") 975 | response_error(resp, code=405) 976 | 977 | def test_update_method_permissions(self): 978 | doc1 = example.MethodTestDoc.objects.create(txt="doc1") 979 | 980 | # put 981 | resp = self.app.put( 982 | f"/update_only/{doc1.pk}/", data=json.dumps({"txt": "doesnt work"}) 983 | ) 984 | response_success(resp) 985 | 986 | # create 987 | resp = self.app.post("/update_only/", data=json.dumps({"txt": "works"})) 988 | response_error(resp, code=404) 989 | 990 | # list 991 | resp = self.app.get("/update_only/") 992 | response_error(resp, code=404) 993 | 994 | # fetch 995 | resp = self.app.get(f"/update_only/{doc1.pk}/") 996 | response_error(resp, code=405) 997 | 998 | # bulk put 999 | resp = self.app.put("/update_only/", data=json.dumps({"txt": "doesnt work"})) 1000 | response_error(resp, code=404) 1001 | 1002 | # delete 1003 | resp = self.app.delete(f"/update_only/{doc1.pk}/") 1004 | response_error(resp, code=405) 1005 | 1006 | def test_bulk_update_method_permissions(self): 1007 | doc1 = example.MethodTestDoc.objects.create(txt="doc1") 1008 | 1009 | # bulk put 1010 | resp = self.app.put("/bulk_update_only/", data=json.dumps({"txt": "works"})) 1011 | response_success(resp) 1012 | 1013 | # put 1014 | resp = self.app.put( 1015 | f"/bulk_update_only/{doc1.pk}/", data=json.dumps({"txt": "doesnt work"}) 1016 | ) 1017 | response_error(resp, code=405) 1018 | 1019 | # create 1020 | resp = self.app.post("/bulk_update_only/", data=json.dumps({"txt": "works"})) 1021 
| response_error(resp, code=405) 1022 | 1023 | # list 1024 | resp = self.app.get("/bulk_update_only/") 1025 | response_error(resp, code=405) 1026 | 1027 | # fetch 1028 | resp = self.app.get(f"/bulk_update_only/{doc1.pk}/") 1029 | response_error(resp, code=405) 1030 | 1031 | # delete 1032 | resp = self.app.delete(f"/bulk_update_only/{doc1.pk}/") 1033 | response_error(resp, code=405) 1034 | 1035 | def test_delete_method_permissions(self): 1036 | doc1 = example.MethodTestDoc.objects.create(txt="doc1") 1037 | 1038 | # delete 1039 | resp = self.app.delete(f"/delete_only/{doc1.pk}/") 1040 | response_success(resp) 1041 | 1042 | # bulk put 1043 | resp = self.app.put("/delete_only/", data=json.dumps({"txt": "works"})) 1044 | response_error(resp, code=404) 1045 | 1046 | # put 1047 | resp = self.app.put( 1048 | f"/delete_only/{doc1.pk}/", data=json.dumps({"txt": "doesnt work"}) 1049 | ) 1050 | response_error(resp, code=405) 1051 | 1052 | # create 1053 | resp = self.app.post("/delete_only/", data=json.dumps({"txt": "works"})) 1054 | response_error(resp, code=404) 1055 | 1056 | # list 1057 | resp = self.app.get("/delete_only/") 1058 | response_error(resp, code=404) 1059 | 1060 | # fetch 1061 | resp = self.app.get(f"/delete_only/{doc1.pk}/") 1062 | response_error(resp, code=405) 1063 | 1064 | def test_request_bad_accept(self): 1065 | """Make sure we gracefully handle requests where an invalid Accept header is sent.""" 1066 | resp = self.app.get( 1067 | f"/user/{self.user_1_obj['id']}/", headers={"Accept": "whatever"} 1068 | ) 1069 | response_error(resp) 1070 | self.assertEqual(resp.data, b"Invalid Accept header requested") 1071 | 1072 | def test_bulk_update_limit(self): 1073 | """ 1074 | Make sure that the limit on the number of objects that can be 1075 | bulk-updated at once works. 
1076 | """ 1077 | limit = example.PostResource.bulk_update_limit 1078 | 1079 | for i in range(limit + 1): 1080 | resp = self.app.post( 1081 | "/posts/", 1082 | data=json.dumps({"title": f"Title {i}", "is_published": False}), 1083 | ) 1084 | response_success(resp) 1085 | 1086 | # bulk update all posts 1087 | resp = self.app.put("/posts/", data=json.dumps({"title": "Title"})) 1088 | response_error(resp, code=400) 1089 | self.assertEqual( 1090 | resp_json(resp), 1091 | {"errors": ["It's not allowed to update more than 10 objects at once"]}, 1092 | ) 1093 | self.assertEqual( 1094 | 11, example.documents.Post.objects.filter(title__ne="Title").count() 1095 | ) 1096 | 1097 | 1098 | class MongoRestSchemaTestCase(unittest.TestCase): 1099 | def setUp(self): 1100 | self.app = example.app.test_client() 1101 | example.documents.Language.drop_collection() 1102 | example.documents.Person.drop_collection() 1103 | 1104 | def tearDown(self): 1105 | pass 1106 | 1107 | def test_person(self): 1108 | resp = self.app.post( 1109 | "/person/", 1110 | data=json.dumps( 1111 | {"name": "John", "languages": [{"name": "English"}, {"name": "German"}]} 1112 | ), 1113 | ) 1114 | response_success(resp) 1115 | person = resp_json(resp) 1116 | 1117 | person_id = person["id"] 1118 | 1119 | self.assertEqual(len(person["languages"]), 2) 1120 | self.assertEqual(person["name"], "John") 1121 | self.assertEqual(person["languages"][0]["name"], "English") 1122 | self.assertEqual(person["languages"][1]["name"], "German") 1123 | 1124 | english_id = person["languages"][0]["id"] 1125 | german_id = person["languages"][1]["id"] 1126 | 1127 | # No change (same data) 1128 | resp = self.app.put( 1129 | f"/person/{person_id}/", 1130 | data=json.dumps( 1131 | { 1132 | "name": "John", 1133 | "languages": [ 1134 | {"id": english_id, "name": "English"}, 1135 | {"id": german_id, "name": "German"}, 1136 | ], 1137 | } 1138 | ), 1139 | ) 1140 | response_success(resp) 1141 | person = resp_json(resp) 1142 | 
self.assertEqual(len(person["languages"]), 2) 1143 | self.assertEqual(person["name"], "John") 1144 | self.assertEqual(person["languages"][0]["id"], english_id) 1145 | self.assertEqual(person["languages"][1]["id"], german_id) 1146 | self.assertEqual(person["languages"][0]["name"], "English") 1147 | self.assertEqual(person["languages"][1]["name"], "German") 1148 | 1149 | # No change (omitted fields of related document) 1150 | resp = self.app.put( 1151 | f"/person/{person_id}/", 1152 | data=json.dumps( 1153 | {"name": "John", "languages": [{"id": english_id}, {"id": german_id}]} 1154 | ), 1155 | ) 1156 | response_success(resp) 1157 | person = resp_json(resp) 1158 | self.assertEqual(len(person["languages"]), 2) 1159 | self.assertEqual(person["name"], "John") 1160 | self.assertEqual(person["languages"][0]["id"], english_id) 1161 | self.assertEqual(person["languages"][1]["id"], german_id) 1162 | self.assertEqual(person["languages"][0]["name"], "English") 1163 | self.assertEqual(person["languages"][1]["name"], "German") 1164 | 1165 | # Also no change (empty data) 1166 | resp = self.app.put(f"/person/{person_id}/", data=json.dumps({})) 1167 | response_success(resp) 1168 | person = resp_json(resp) 1169 | self.assertEqual(len(person["languages"]), 2) 1170 | self.assertEqual(person["name"], "John") 1171 | self.assertEqual(person["languages"][0]["id"], english_id) 1172 | self.assertEqual(person["languages"][1]["id"], german_id) 1173 | self.assertEqual(person["languages"][0]["name"], "English") 1174 | self.assertEqual(person["languages"][1]["name"], "German") 1175 | 1176 | # Change value 1177 | resp = self.app.put( 1178 | f"/person/{person_id}/", 1179 | data=json.dumps( 1180 | { 1181 | "languages": [ 1182 | {"id": english_id, "name": "English"}, 1183 | {"id": german_id, "name": "French"}, 1184 | ] 1185 | } 1186 | ), 1187 | ) 1188 | response_success(resp) 1189 | person = resp_json(resp) 1190 | self.assertEqual(len(person["languages"]), 2) 1191 | self.assertEqual(person["name"], 
"John") 1192 | self.assertEqual(person["languages"][0]["id"], english_id) 1193 | self.assertEqual(person["languages"][1]["id"], german_id) 1194 | self.assertEqual(person["languages"][0]["name"], "English") 1195 | self.assertEqual(person["languages"][1]["name"], "French") 1196 | 1197 | # Insert item / rename back 1198 | resp = self.app.put( 1199 | f"/person/{person_id}/", 1200 | data=json.dumps( 1201 | { 1202 | "languages": [ 1203 | {"id": english_id}, 1204 | {"name": "Spanish"}, 1205 | {"id": german_id, "name": "German"}, 1206 | ] 1207 | } 1208 | ), 1209 | ) 1210 | response_success(resp) 1211 | person = resp_json(resp) 1212 | self.assertEqual(len(person["languages"]), 3) 1213 | self.assertEqual(person["name"], "John") 1214 | self.assertEqual(person["languages"][0]["id"], english_id) 1215 | self.assertEqual(person["languages"][2]["id"], german_id) 1216 | self.assertEqual(person["languages"][0]["name"], "English") 1217 | self.assertEqual(person["languages"][1]["name"], "Spanish") 1218 | self.assertEqual(person["languages"][2]["name"], "German") 1219 | 1220 | # Remove item 1221 | resp = self.app.put( 1222 | f"/person/{person_id}/", data=json.dumps({"languages": [{"id": german_id}]}) 1223 | ) 1224 | response_success(resp) 1225 | person = resp_json(resp) 1226 | self.assertEqual(len(person["languages"]), 1) 1227 | self.assertEqual(person["name"], "John") 1228 | self.assertEqual(person["languages"][0]["id"], german_id) 1229 | self.assertEqual(person["languages"][0]["name"], "German") 1230 | 1231 | # Assign back (item is still in the database) 1232 | resp = self.app.put( 1233 | f"/person/{person_id}/", 1234 | data=json.dumps({"languages": [{"id": german_id}, {"id": english_id}]}), 1235 | ) 1236 | response_success(resp) 1237 | person = resp_json(resp) 1238 | self.assertEqual(len(person["languages"]), 2) 1239 | self.assertEqual(person["name"], "John") 1240 | self.assertEqual(person["languages"][0]["id"], german_id) 1241 | self.assertEqual(person["languages"][0]["name"], 
"German") 1242 | self.assertEqual(person["languages"][1]["id"], english_id) 1243 | self.assertEqual(person["languages"][1]["name"], "English") 1244 | 1245 | # Test invalid ID 1246 | resp = self.app.put( 1247 | f"/person/{person_id}/", data=json.dumps({"languages": [{"id": "INVALID"}]}) 1248 | ) 1249 | response_error(resp) 1250 | 1251 | def test_datetime(self): 1252 | resp = self.app.post( 1253 | "/datetime/", data=json.dumps({"datetime": "2010-01-01T00:00:00"}) 1254 | ) 1255 | response_success(resp) 1256 | datetime = resp_json(resp) 1257 | self.assertEqual(datetime["datetime"], "2010-01-01T00:00:00") 1258 | 1259 | with query_counter() as c: 1260 | resp = self.app.put( 1261 | f"/datetime/{datetime['id']}/", 1262 | data=json.dumps({"datetime": "2010-01-02T00:00:00"}), 1263 | ) 1264 | response_success(resp) 1265 | datetime = resp_json(resp) 1266 | self.assertEqual(datetime["datetime"], "2010-01-02T00:00:00") 1267 | 1268 | self.assertEqual(c, 3) # query, update, query (reload) 1269 | 1270 | with query_counter() as c: 1271 | resp = self.app.put( 1272 | f"/datetime/{datetime['id']}/", 1273 | data=json.dumps({"datetime": "2010-01-02T00:00:00"}), 1274 | ) 1275 | response_success(resp) 1276 | datetime = resp_json(resp) 1277 | self.assertEqual(datetime["datetime"], "2010-01-02T00:00:00") 1278 | 1279 | # Ideally this would be one query since we're not modifying, but 1280 | # in the generic case the save method may have other side effects 1281 | # and we don't know if the object was modified, so we currently 1282 | # always reload. 
1283 | self.assertEqual(c, 2) # 2x query (with reload) 1284 | 1285 | # Same as above, with no body 1286 | with query_counter() as c: 1287 | resp = self.app.put(f"/datetime/{datetime['id']}/", data=json.dumps({})) 1288 | response_success(resp) 1289 | datetime = resp_json(resp) 1290 | self.assertEqual(datetime["datetime"], "2010-01-02T00:00:00") 1291 | 1292 | self.assertEqual(c, 2) # 2x query (with reload) 1293 | 1294 | def test_receive_bad_json(self): 1295 | """ 1296 | Python is stupid and by default lets us accept an invalid JSON. Test 1297 | that flask-mongorest handles it correctly. 1298 | """ 1299 | # test create 1300 | resp = self.app.post( 1301 | "/dict_doc/", 1302 | data=json.dumps( 1303 | { 1304 | "dict": { 1305 | "nan": float("NaN"), 1306 | "inf": float("inf"), 1307 | "-inf": float("-inf"), 1308 | } 1309 | } 1310 | ), 1311 | ) 1312 | response_error(resp, code=400) 1313 | self.assertEqual( 1314 | resp_json(resp), {"error": "The request contains invalid JSON."} 1315 | ) 1316 | 1317 | # test update 1318 | resp = self.app.post("/dict_doc/", data=json.dumps({"dict": {"aaa": "bbb"}})) 1319 | response_success(resp) 1320 | resp = self.app.put( 1321 | f"/dict_doc/{resp_json(resp)['id']}/", 1322 | data=json.dumps( 1323 | { 1324 | "dict": { 1325 | "nan": float("NaN"), 1326 | "inf": float("inf"), 1327 | "-inf": float("-inf"), 1328 | } 1329 | } 1330 | ), 1331 | ) 1332 | response_error(resp, code=400) 1333 | self.assertEqual( 1334 | resp_json(resp), {"error": "The request contains invalid JSON."} 1335 | ) 1336 | 1337 | def test_send_bad_json(self): 1338 | """ 1339 | Make sure that - even if we store invalid JSON in database, we error out 1340 | instead of sending invalid data to the user. 
1341 | """ 1342 | doc = example.DictDoc.objects.create( 1343 | dict={"NaN": float("NaN"), "inf": float("inf"), "-inf": float("-inf")} 1344 | ) 1345 | 1346 | # test fetch 1347 | self.assertRaises(ValueError, self.app.get, f"/dict_doc/{doc.id}/") 1348 | 1349 | # test list 1350 | self.assertRaises(ValueError, self.app.get, "/dict_doc/") 1351 | 1352 | 1353 | class InternalTestCase(unittest.TestCase): 1354 | """ 1355 | Test internal methods. 1356 | """ 1357 | 1358 | def test_serialize_mongoengine_validation_error(self): 1359 | from flask_mongorest.views import serialize_mongoengine_validation_error 1360 | 1361 | error = ValidationError(errors={"a": ValidationError("Invalid value")}) 1362 | result = serialize_mongoengine_validation_error(error) 1363 | self.assertEqual(result, {"field-errors": {"a": "Invalid value"}}) 1364 | 1365 | error = ValidationError("Invalid value") 1366 | result = serialize_mongoengine_validation_error(error) 1367 | self.assertEqual(result, {"error": "Invalid value"}) 1368 | 1369 | error = ValidationError(errors={"a": "Invalid value"}) 1370 | result = serialize_mongoengine_validation_error(error) 1371 | self.assertEqual(result, {"field-errors": {"a": "Invalid value"}}) 1372 | 1373 | 1374 | if __name__ == "__main__": 1375 | unittest.main() 1376 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py27,py35 3 | 4 | [testenv] 5 | commands=nosetests 6 | deps= 7 | nose 8 | py27: -rrequirements.txt 9 | py35: -rrequirements3.txt 10 | --------------------------------------------------------------------------------