├── .gitignore
├── .travis.yml
├── CHANGES.md
├── LICENSE
├── README.rst
├── flask_restful_extend
│   ├── __init__.py
│   ├── error_handling.py
│   ├── extend_json.py
│   ├── marshal.py
│   ├── model_converter.py
│   ├── model_reqparse.py
│   ├── model_validates.py
│   └── reqparse_fixed_type.py
├── setup.py
└── tests
    ├── __init__.py
    ├── error_handle_test.py
    ├── json_extend_test.py
    ├── model_test.py
    └── my_test_case.py

/.gitignore:
--------------------------------------------------------------------------------
*.pyc
*.sublime-project
*.sublime-workspace
*.egg-info/
dist/
build/
.idea/
--------------------------------------------------------------------------------

/.travis.yml:
--------------------------------------------------------------------------------
language: python

python:
  - "2.7"
  - "3.4"

install: "pip install --editable ."

script: nosetests
--------------------------------------------------------------------------------

/CHANGES.md:
--------------------------------------------------------------------------------
Version 0.3.5 (2015-06-27)
=========================
Breaking changes:

### Change in error_handle's handling logic
flask-restful-extend's error_handle now only handles HTTPException; any other Exception is re-raised directly.
So choose the type of exception you raise carefully: raise an HTTPException for invalid client requests, and a plain Exception for internal server errors.

### Removed fix\_argument\_convert() and some fixed types
flask-restful's reqparse used to handle None like this:
`If the arg type is str, return None directly. Otherwise, pass the None value to the type constructor and try to construct a value (if the constructor can't handle None, an exception is eventually raised).`
Now it is:
`For any arg type, return None directly.`

To cope with flask-restful's old behavior, I previously patched it like this:
```
If the arg type is str, return the string "None". Otherwise, hand over to flask-restful's native reqparse.
In addition, I predefined some fixed types: wrappers around the native types that handle None values correctly (by returning None directly).
```
Now, following flask-restful's new behavior, mine changes accordingly:
```
No special handling for str arg types; like every other type, None is returned directly.
Fixed types no longer need to deal with None, because None values are skipped and these constructors are never called.
Note: previously, the constructors were called when a None value appeared; now they are not. If some constructors relied on this (i.e. they contain logic written specifically for None), that logic no longer takes effect.
```


Version 0.3.1 (2014-12)
=========================

- Refactored the implementation of the SQLAlchemy model validator.

  **Notice:** the usage has changed.

  Before:
  ```python
  class People(db.Model):
      name = Column(String(100))
      age = Column(Integer)
      IQ = Column(Integer)

      validate_rules = [
          ('name', 'min_length', 1),
          ('name', 'max_length', 100),
          (['age', 'IQ'], 'min', 0)
      ]
  ```

  After:
  ```python
  class People(db.Model):
      name = Column(String(100))
      age = Column(Integer)
      IQ = Column(Integer)

      validator = complex_validates({
          'name': (('min_length', 1), ('max_length', 100)),
          ('age', 'IQ'): [('min', 0)]
      })
  ```

- `_CantEncodeObjException` is renamed to `CantEncodeObjException`

- Changed the default callback\_name\_source from 'jsonp' to 'callback'
--------------------------------------------------------------------------------

/LICENSE:
--------------------------------------------------------------------------------
The MIT License (MIT)

Copyright (c) 2014 anjianshi

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------

/README.rst:
--------------------------------------------------------------------------------
*Translating this document into English was a bit difficult; I have tried my best.*

Flask-RESTFul-extend
====================

Improves Flask-RESTful's behavior and adds some new features.

All features shown below are optional; choose them according to your needs.


Improve error handling
----------------------
source: error_handling.py_

Flask-RESTful's error handler only outputs your error message for exceptions raised by itself (via `restful.abort()`).
For a standard `HTTPException` such as `werkzeug.exceptions.BadRequest`, the message you specified is lost.

`ErrorHandledApi` makes such exceptions output their messages in the same way.

.. code-block:: python

    from werkzeug.exceptions import BadRequest

    api = restful_extend.ErrorHandledApi(app)  # instead of `api = restful.Api(app)`

    class MyRoutes(Resource):
        def get(self):
            raise BadRequest("errmsg")  # now, 'errmsg' is output to the client



Improve JSON support
--------------------

Enhanced JSON encoding
^^^^^^^^^^^^^^^^^^^^^^
sources: extend_json.py_, json_encode_manager.py_

Supports more data types by default, and lets you easily add support for new data types.

.. code-block:: python

    # This is a custom type; you can't directly return a value of this type in Flask or Flask-RESTful.
    class RationalNumber(object):
        def __init__(self, numerator, denominator):
            self.numerator = numerator
            self.denominator = denominator

    api = restful.Api(app)

    # Enable the enhanced JSON encoding feature
    enhance_json_encode(api)


    # Create and register a custom encoder that encodes your custom type into a serializable value.
    def rational_encoder(rational):
        return rational.numerator * 1.0 / rational.denominator

    api.json_encoder.register(rational_encoder, RationalNumber)


    class MyRoutes(Resource):
        def get(self):
            return RationalNumber(1, 5)  # now you can return a value of your custom type directly



Support JSONP
^^^^^^^^^^^^^
source: extend_json.py_

Responds to JSONP requests automatically.

.. code-block:: python

    api = restful.Api(app)
    support_jsonp(api)

    class MyRoutes(Resource):
        def get(self):
            return dict(foo='bar')

    api.add_resource(MyRoutes, '/my_routes')

    # normal request: /my_routes                  response: {"foo": "bar"}
    # jsonp request:  /my_routes?callback=my_cb   response: my_cb({"foo": "bar"})



SQLAlchemy-related extensions
-----------------------------

marshal_with_model
^^^^^^^^^^^^^^^^^^
source: marshal.py_

Extends the behavior of Flask-RESTful's `marshal_with` decorator by defining the fields of an ORM model automatically.

.. code-block:: python

    class MyRoutes(Resource):
        # With `marshal_with_model`, you can return a model instance or a model query
        # from the view function directly.
        @marshal_with_model(MyModel, excludes=['id'])
        def get(self):
            return MyModel.query  # response: [my_model1, my_model2, ...]

        # If you need to return different types of model in different situations, you can use `quick_marshal`
        def post(self):
            if something:
                return quick_marshal(MyModel)(MyModel.query.get(1))  # response: my_model
            else:
                return quick_marshal(HisModel)(HisModel.query)  # response: [his_model1, ...]



Quickly register a URL converter for a model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
source: model_converter.py_

.. code-block:: python

    api = restful.Api(app)


    class Student(db.Model):
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    register_model_converter(Student, app)


    class MyRoutes(Resource):
        def get(self, classmate):
            pass

    api.add_resource(MyRoutes, '/classmates/')

    # request: /classmates/102   response: {"id": 102, "name": "superman"}



Create a RequestParser according to a model's definition
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
source: model_reqparse.py_

Requires: fixed_type (see the next section)

.. code-block:: python

    class Student(db.Model):
        id = Column(Integer, primary_key=True)
        name = Column(String(50))
        age = Column(Integer)


    class MyRoutes(Resource):
        def post(self):
            # use `make_request_parser` to quickly create a `RequestParser`
            parser = make_request_parser(Student)

            # you can update the parser as usual
            parser.add_argument('is_a_boy')

            request_data = parser.parse_args()
            print(request_data['name'], request_data['age'])
            # do something...


    class MyRoutes2(Resource):
        def post(self):
            # if you want to populate a model with the request data,
            # `populate_model` is more convenient.
            model = Student.query.get(1)
            populate_model(model)  # the model is updated with the user's request data



Improve Argument type handling
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
sources: model_reqparse.py_, reqparse_fixed_type.py_

**fix_argument_convert**

Changes the original behavior of `reqparse.Argument.convert`.

You should call this function before using `make_request_parser`, `populate_model` or the fixed types.

**fixed types**

A set of customized type constructors.

Use them instead of int, str, datetime... as the `type` parameter of `Argument`;
they provide some additional features.



Model validates
^^^^^^^^^^^^^^^
source: model_validates.py_

Simplifies and extends SQLAlchemy's attribute validation process.
This feature is independent of Flask-RESTful.




More Details
------------
For more details, please read the docstrings in the source code.




.. _error_handling.py: https://github.com/anjianshi/flask-restful-extend/blob/master/flask_restful_extend/error_handling.py

.. _extend_json.py: https://github.com/anjianshi/flask-restful-extend/blob/master/flask_restful_extend/extend_json.py

.. _json_encode_manager.py: https://github.com/anjianshi/flask-restful-extend/blob/master/flask_restful_extend/json_encode_manager.py

.. _marshal.py: https://github.com/anjianshi/flask-restful-extend/blob/master/flask_restful_extend/marshal.py

.. _model_converter.py: https://github.com/anjianshi/flask-restful-extend/blob/master/flask_restful_extend/model_converter.py

.. _model_reqparse.py: https://github.com/anjianshi/flask-restful-extend/blob/master/flask_restful_extend/model_reqparse.py

.. _reqparse_fixed_type.py: https://github.com/anjianshi/flask-restful-extend/blob/master/flask_restful_extend/reqparse_fixed_type.py

.. _model_validates.py: https://github.com/anjianshi/flask-restful-extend/blob/master/flask_restful_extend/model_validates.py
--------------------------------------------------------------------------------

/flask_restful_extend/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

__version__ = '0.3.7'

from .error_handling import ErrorHandledApi
from .extend_json import enhance_json_encode, support_jsonp
from .marshal import marshal_with_model, quick_marshal
from .model_converter import register_model_converter
from .model_reqparse import make_request_parser, populate_model
from . import reqparse_fixed_type as fixed_type
from .model_validates import complex_validates
--------------------------------------------------------------------------------

/flask_restful_extend/error_handling.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
from flask_restful import Api
from werkzeug.exceptions import HTTPException

class ErrorHandledApi(Api):
    """Usage:
        api = restful_extend.ErrorHandledApi(app)

    instead of:

        api = restful.Api(app)

    todo: support Python3 (under python3, Exception has no "message" attribute)
    """

    def handle_error(self, e):
        """
        Resolve the problem that the error message specified by the programmer is sometimes not output to the user.

        Flask-RESTful's error handler behaves differently for different kinds of exceptions.
        If we raise a normal Exception, it re-raises it.

        If we report an error with `restful.abort()`,
        e.g. `restful.abort(400, message="my_msg", custom_data="value")`,
        it makes a response like this:

            Status 400
            Content {"message": "my_msg", "custom_data": "value"}

        The error message we specified is output.

        But if we raise an HTTPException,
        e.g. `from werkzeug.exceptions import BadRequest; raise BadRequest('my_msg')`,
        it makes a response too, but the error message we specified is lost:

            Status 400
            Content {"status": 400, "message": "Bad Request"}

        The reason is that flask-restful always uses the `data` attribute of the HTTPException to generate the response content,
        but a standard HTTPException object doesn't have this attribute.
        So we use this method to add it manually.


        Some reference material:

        Structure of exceptions raised by restful.abort():
            code: status code
            description: predefined error message for this status code
            data: {
                message: error message
            }

        Structure of python2's standard Exception:
            message: error message
        Exceptions in python3 don't have the `message` attribute, but `str(exception)` returns the message.

        Structure of a standard `werkzeug.exceptions.HTTPException` (same as BadRequest):
            code: status code
            name: the name corresponding to the status code
            description: error message
        """
        if isinstance(e, HTTPException) and not hasattr(e, 'data'):
            e.data = dict(message=e.description)
        return super(ErrorHandledApi, self).handle_error(e)

    def unauthorized(self, response):
        """By default, when a user is unauthorized, Flask-RESTful pops up a login dialog for the user.
        But for a RESTful app this is useless, so I override this method to remove that behavior."""
        return response
--------------------------------------------------------------------------------

/flask_restful_extend/extend_json.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

from flask import request, current_app, make_response
from json_encode_manager import JSONEncodeManager
import json


def enhance_json_encode(api_instance, extra_settings=None):
    """Replace Flask-RESTful's default `output_json` function with one that uses `JSONEncodeManager`.
    For the advantages of `JSONEncodeManager`, see https://github.com/anjianshi/json_encode_manager"""
    api_instance.json_encoder = JSONEncodeManager()

    dumps_settings = {} if extra_settings is None else extra_settings
    dumps_settings['default'] = api_instance.json_encoder
    dumps_settings.setdefault('ensure_ascii', False)

    @api_instance.representation('application/json')
    def output_json(data, code, headers=None):
        if current_app.debug:
            dumps_settings.setdefault('indent', 4)
            dumps_settings.setdefault('sort_keys', True)

        dumped = json.dumps(data, **dumps_settings)
        if 'indent' in dumps_settings:
            dumped += '\n'

        resp = make_response(dumped, code)
        resp.headers.extend(headers or {})
        return resp


def support_jsonp(api_instance, callback_name_source='callback'):
    """Let the API instance respond to JSONP requests automatically.

    `callback_name_source` can be a string or a callable.
    If it is a string, the system looks for the query-string argument with this name.
    If it is found, the request is treated as a JSONP request, and the argument's value is used as the JS callback name.

    If `callback_name_source` is a callable, it should return the JS callback name when the request
    is a JSONP request, and False when it is not.
    The system then handles the request according to its return value.

    Default supported format: url?callback=js_callback_name
    """
    output_json = api_instance.representations['application/json']

    @api_instance.representation('application/json')
    def handle_jsonp(data, code, headers=None):
        resp = output_json(data, code, headers)

        if code == 200:
            callback = request.args.get(callback_name_source, False) if not callable(callback_name_source) \
                else callback_name_source()
            if callback:
                resp.set_data(str(callback) + '(' + resp.get_data().decode("utf-8") + ')')

        return resp
--------------------------------------------------------------------------------

/flask_restful_extend/marshal.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
from flask_restful import fields as _fields, marshal_with as _marshal_with
from functools import wraps
import time
import six


def marshal_with_model(model, excludes=None, only=None, extends=None):
    """With this decorator, you can return an ORM model instance or an ORM query from a view function directly.
    We transform these objects into standard Python data structures, as Flask-RESTful's `marshal_with` decorator does,
    and you don't need to define any fields at all.

    You can specify the columns to be returned with the `excludes` or `only` parameter.
    (Don't use these two parameters at the same time; otherwise only the `excludes` parameter will be used.)

    If you want to return fields outside the model, or override the type of some fields,
    use the `extends` parameter to specify them.

    Notice: this function only supports `Flask-SQLAlchemy`

    Example:
        class Student(db.Model):
            id = Column(Integer, primary_key=True)
            name = Column(String(100))
            age = Column(Integer)

        class SomeApi(Resource):
            @marshal_with_model(Student, excludes=['id'])
            def get(self):
                return Student.query

        # response: [{"name": "student_a", "age": 16}, {"name": "student_b", "age": 18}]

        class AnotherApi(Resource):
            @marshal_with_model(Student, extends={"nice_guy": fields.Boolean, "age": fields.String})
            def get(self):
                student = Student.query.get(1)
                student.nice_guy = True
                student.age = "young" if student.age < 18 else "old"  # transform the int field to a string
                return student
    """
    if isinstance(excludes, six.string_types):
        excludes = [excludes]
    if excludes and only:
        only = None
    elif isinstance(only, six.string_types):
        only = [only]

    field_definition = {}
    for col in model.__table__.columns:
        if only:
            if col.name not in only:
                continue
        elif excludes and col.name in excludes:
            continue

        field_definition[col.name] = _type_map[col.type.python_type.__name__]

    if extends is not None:
        for k, v in extends.items():
            field_definition[k] = v

    def decorated(f):
        @wraps(f)
        @_marshal_with(field_definition)
        def wrapper(*args, **kwargs):
            result = f(*args, **kwargs)
            return result if not _fields.is_indexable_but_not_string(result) else [v for v in result]
        return wrapper
    return decorated


def quick_marshal(*args, **kwargs):
    """In some cases, one view function may return different models in different situations.
    Handling this with `marshal_with_model` is tedious.
    This function simplifies the process.

    Usage:
        quick_marshal(args_to_marshal_with_model)(db_instance_or_query)
    """
    @marshal_with_model(*args, **kwargs)
    def fn(value):
        return value
    return fn


def _wrap_field(field):
    """Improve Flask-RESTful's original field type"""
    class WrappedField(field):
        def output(self, key, obj):
            value = _fields.get_value(key if self.attribute is None else self.attribute, obj)

            # For all fields, when the value is null (None), return null directly
            # instead of the field's default value (e.g. the int type's default value is 0),
            # because sometimes the client **needs** to know whether a field of the model is empty to decide its behavior.
            return None if value is None else self.format(value)
    return WrappedField


class _DateTimeField(_fields.Raw):
    """Transform `datetime` and `date` objects to timestamps before returning them."""
    def format(self, value):
        try:
            return time.mktime(value.timetuple())
        except OverflowError:
            # The `value` was generated in time zone UTC+0,
            # but `time.mktime()` generates the timestamp in the local time zone (e.g. UTC+8 in China).
            # So in some situations we may get a negative timestamp.
            # On Linux this is no problem, but on Windows it causes an `OverflowError`.
            # Since we generally don't need to handle times that far back, we simply return 0 here.
            return 0

        except AttributeError as ae:
            raise _fields.MarshallingException(ae)


class _FloatField(_fields.Raw):
    """Flask-RESTful transforms float values to strings before returning them.
    This is not useful in most situations, so we change it to return the float value directly."""

    def format(self, value):
        try:
            return float(value)
        except ValueError as ve:
            raise _fields.MarshallingException(ve)


_type_map = {
    # python_type: flask-restful field
    'str': _wrap_field(_fields.String),
    'int': _wrap_field(_fields.Integer),
    'float': _wrap_field(_FloatField),
    'bool': _wrap_field(_fields.Boolean),
    'datetime': _wrap_field(_DateTimeField),
    'date': _wrap_field(_DateTimeField)
}
--------------------------------------------------------------------------------

/flask_restful_extend/model_converter.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
from werkzeug.routing import BaseConverter
from werkzeug.exceptions import NotFound


def register_model_converter(model, app):
    """Add a URL converter for a model

    Example:
        class Student(db.Model):
            id = Column(Integer, primary_key=True)
            name = Column(String(50))

        register_model_converter(Student, app)

        @route('/classmates/')
        def get_classmate_info(classmate):
            pass

    This only supports models that have a single primary key.
    You need to call this function before creating the view functions.
    """
    if hasattr(model, 'id'):
        class Converter(_ModelConverter):
            _model = model
        app.url_map.converters[model.__name__] = Converter


class _ModelConverter(BaseConverter):
    _model = None

    def to_python(self, inst_id):
        instance = self._model.query.get(inst_id)
        if instance is None:
            raise NotFound(u'{}(id={}) not exists,request invalid'.format(self._model.__name__, inst_id))
        return instance

    def to_url(self, inst):
        return str(inst.id)
--------------------------------------------------------------------------------

/flask_restful_extend/model_reqparse.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# `fix_argument_convert` was removed in 0.3.5 and is not defined here, so it is not exported.
__all__ = ['make_request_parser', 'populate_model']
from flask_restful import reqparse
from flask import request
from . import reqparse_fixed_type as fixed_type
import six


_type_dict = {
    # python_type_name: fixed_type
    'datetime': fixed_type.fixed_datetime,
    'date': fixed_type.fixed_date,
    'str': six.text_type,
    'int': fixed_type.fixed_int,
    'float': fixed_type.fixed_float
}

def make_request_parser(model_or_inst, excludes=None, only=None, for_populate=False):
    """Pass a `model class` or `model instance` to this function,
    and it generates a `RequestParser` that extracts user request data from `request.json`
    according to the model class's definition.

    The `excludes` and `only` parameters can be a `str` or a list of `str`;
    they specify which columns should be handled.
    If you pass `excludes` and `only` at the same time, only `excludes` will be used.
    Also, the model's primary key is not added to the `RequestParser`'s argument list
    unless you explicitly specify it with the `only` parameter.

    If you pass in a model class rather than a model instance, the function does `required` checking
    for columns that are nullable=False.
    (If you pass in a model instance, the `required` checking is skipped, because in this situation
    the user should be allowed to leave a field unassigned.)
    """
    is_inst = _is_inst(model_or_inst)

    if isinstance(excludes, six.string_types):
        excludes = [excludes]
    if excludes and only:
        only = None
    elif isinstance(only, six.string_types):
        only = [only]

    parser = RequestPopulator() if for_populate else reqparse.RequestParser()
    for col in model_or_inst.__table__.columns:
        if only:
            if col.name not in only:
                continue
        elif (excludes and col.name in excludes) or col.primary_key:
            continue

        col_type = col.type.python_type
        kwargs = {
            "type": _type_dict.get(col_type.__name__, col_type) if hasattr(col_type, '__name__') else col_type
        }
        # When creating a new model instance, if a field has no default value and is not nullable,
        # mark its corresponding argument as `required`.
        if not is_inst and col.default is None and col.server_default is None and not col.nullable:
            kwargs["required"] = True
        parser.add_argument(col.name, **kwargs)
    return parser


def populate_model(model_or_inst, excludes=None, only=None):
    """
    Call `make_request_parser()` to build a `RequestParser`, use it to extract the user's request data,
    and populate the model instance with that data.
    If the caller passes a model class instead of a model instance, a new instance is created from the extracted data.
    """
    inst = model_or_inst if _is_inst(model_or_inst) else model_or_inst()

    parser = make_request_parser(model_or_inst, excludes, only, for_populate=True)
    req_args = parser.parse_args()

    for key, value in req_args.items():
        setattr(inst, key, value)

    return inst


def _is_inst(model_or_inst):
    return hasattr(model_or_inst, '_sa_instance_state')


class RequestPopulator(reqparse.RequestParser):
    """In the original process, the argument's value is None both when the client doesn't assign a value
    and when it assigns a null value.
    Generally that's no problem, but in a populate operation (e.g. updating a model instance's fields)
    it causes trouble:
    when populating, we should not update a field if the client hasn't assigned a value to it,
    and should update it only if the client really assigns a new value.

    The `RequestPopulator` parser is created specifically for the populate operation.
    In this parser, arguments that have not been assigned a value
    do not appear in the argument list (implemented through `PopulatorArgument`),
    so the model fields corresponding to these arguments keep their original values.
    """
    def __init__(self, *args, **kwargs):
        kwargs['argument_class'] = PopulatorArgument
        super(RequestPopulator, self).__init__(*args, **kwargs)

    def parse_args(self, req=None):
        if req is None:
            req = request

        req.unparsed_arguments = {}

        namespace = self.namespace_class()

        for arg in self.args:
            try:
                value = arg.parse(req)
                namespace[arg.dest or arg.name] = value
            except ArgumentNoValue:
                pass

        return namespace


class PopulatorArgument(reqparse.Argument):
    """Argument type created specifically for the populate operation.
    When the argument is not assigned, it raises an exception rather than applying the default value.
123 | (So, the `default` parameter will not be used) 124 | 125 | **关于值类型** 126 | (`arg` 指 Argument 实例,`参数` 指构建 arg 时给出的参数) 127 | 以 QueryString / FormData 形式提交的请求,每个 arg 的值在格式化之前都只能是字符串或空字符串。 128 | 对于 action != store 的 arg,可以指定多个值(?a=1&a=2),通过 type 指定的类型会分别应用到每个值上 129 | 130 | 以 JSON 形式提交的请求,arg 的值在格式化之前就可以是除数组外任意类型, 131 | 如果 arg 的值是一个数组, Flask-RESTFul 会视为对这个参数进行了多次赋值,并将 type 指定的类型会分别应用到每个值上 132 | 例如 json 的 {"a": ["x", "y"]} 相当于 QueryString 的 ?a=x&a=y 133 | 134 | **关于值解析** 135 | 解析前端提交的参数值时,不会对参数值有任何额外的处理(如预先进行一次类型转换),或者额外的行为(如碰到 None 就调用构造器调用), 136 | 一定是直接把它传给参数的构造器。 137 | 因此,只要参数的构造器本身不支持处理给定的值,就好报 400 错误。(例如:int 构造器既不支持空字符串,也不支持 None,那么碰到它们就会报错) 138 | 这样做可以避免歧义,例如 int 本身不支持空字符串,如果特意为了它把空字符串转成 None 或者 0,会使不了解内情的人误解,或者与他们预期的行为不符。 139 | 当然,可以通过自定义一个构造器来进行额外的处理,因为它明摆着是做了额外处理的,所以不会有误解的问题。 140 | P.S. flask-restful 在碰到 text 参数且值为 None 时,会返回 None, 141 | 这个行为不符合上面的规则,可调用 fix_argument_convert() 修复它 142 | """ 143 | def __init__(self, *args, **kwargs): 144 | # 把 action 强制设定为 append,以便解析参数值的时候判断此参数有没有被赋值 145 | # 记录原来的 action 是为了在最后仍能以用户期望的格式返回参数值 146 | self.real_action = kwargs.get('action', 'store') 147 | kwargs['action'] = 'append' 148 | 149 | super(PopulatorArgument, self).__init__(*args, **kwargs) 150 | 151 | def parse(self, req): 152 | results = super(PopulatorArgument, self).parse(req)[0] 153 | 154 | # 因为把 action 强制设定为了 append,因此在提交了参数值的情况下,results 一定是一个数组, 155 | # 不会和 self.default 是同一个值 156 | # (即使 self.default 也是数组,也不会和 results 是同一个数组) 157 | # 因此就可以通过这一点来判断当前请求中,到底有没有提交此参数的值 158 | if results is self.default: 159 | raise ArgumentNoValue() 160 | elif self.real_action == 'store' or (self.real_action != 'append' and len(results) == 1): 161 | return results[0] 162 | else: 163 | return results 164 | 165 | 166 | class ArgumentNoValue(Exception): 167 | pass 168 | -------------------------------------------------------------------------------- /flask_restful_extend/model_validates.py: -------------------------------------------------------------------------------- 1 | # -*- 
coding: utf-8 -*- 2 | """Simplify and extend SQLAlchemy's attribute validates process""" 3 | 4 | __all__ = ['complex_validates'] 5 | 6 | from sqlalchemy.orm import validates 7 | import re 8 | 9 | 10 | predefined_predicates = { 11 | 'min': lambda value, min_val: value >= min_val, 12 | 'max': lambda value, max_val: value <= max_val, 13 | 'min_length': lambda value, min_val: len(value) >= min_val, 14 | 15 | # When define a CHAR-like column, we have generally already specify a max length. 16 | # So, do we really need the max_length predicate? Yes, we need. 17 | # Cause the database will just cut sting when it is too lang, and will not raise an error, 18 | # so we won't noticed. But use this predicate, it will raise and Exception, and stop the whole action, 19 | # so we can handle the sting safety. 20 | 'max_length': lambda value, max_val: len(value) <= max_val, 21 | 22 | 'match': lambda value, pattern: bool(re.match(pattern, value)), 23 | 24 | # If a predicate will change the original value, we'd better add `trans_` prefix to it's name, 25 | # so we can know it's use clearly. 26 | 'trans_upper': lambda value: dict(value=value.upper()) 27 | } 28 | 29 | 30 | class ModelInvalid(Exception): 31 | code = 400 32 | 33 | 34 | def complex_validates(validate_rule): 35 | """Quickly setup attributes validation by one-time, based on `sqlalchemy.orm.validates`. 36 | 37 | Don't like `sqlalchemy.orm.validates`, you don't need create many model method, 38 | as long as pass formatted validate rule. 39 | (Cause of SQLAlchemy's validate mechanism, you need assignment this funciton's return value 40 | to a model property.) 41 | 42 | For simplicity, complex_validates don't support `include_removes` and `include_backrefs` parameters 43 | that in `sqlalchemy.orm.validates`. 44 | 45 | And we don't recommend you use this function multiple times in one model. 46 | Because this will bring many problems, like: 47 | 1. 
The execution order of multiple complex_validates is determined by their model property names, in reverse order. 48 | e.g. predicates in `validator1 = complex_validates(...)` 49 | will be executed **AFTER** predicates in `validator2 = complex_validates(...)` 50 | 2. If you try to validate the same attribute in two (or more) complex_validates, only one of them 51 | will be executed. (Maybe this is a bug in SQLAlchemy?) 52 | `complex_validates` is currently based on `sqlalchemy.orm.validates`, so these problems are difficult to solve. 53 | In the future we may try to use `AttributeEvents` directly to provide a more reliable implementation. 54 | 55 | Rule Format 56 | ----------- 57 | 58 | { 59 | column_name: predicate, # basic format 60 | (column_name2, column_name3): predicate, # you can apply predicates to multiple column names at once 61 | column_name4: (predicate, predicate2), # you can apply multiple predicates to a column 62 | column_name5: [(predicate, arg1, ... argN)], # and you can specify the arguments passed to the predicate 63 | # when it validates 64 | (column_name6, column_name7): [(predicate, arg1, ... argN), predicate2] # another example 65 | } 66 | 67 | Notice: If you want to pass arguments to a predicate, you must wrap the whole (predicate, arg, ...) tuple in another list or tuple. 68 | Otherwise, the arguments will be treated as further predicates. 69 | So, this is wrong: { column_name: (predicate, arg) } 70 | this is right: { column_name: [(predicate, arg)] } 71 | 72 | Predicate 73 | --------- 74 | 75 | There are some `predefined_predicates`; you can just reference them by name in the validate rule.
76 | 77 | {column_name: ['trans_upper']} 78 | 79 | Or you can pass your own predicate function to the rule, like this: 80 | 81 | def custom_predicate(value): 82 | return value_is_legal # return True for a valid value, False for an invalid one 83 | 84 | {column_name: [custom_predicate]} 85 | 86 | If you want to change the value during validation, return a `dict(value=new_value)` instead of a boolean: 87 | 88 | {column_name: lambda value: dict(value=value * 2)} # As you can see, a lambda can be used as a predicate. 89 | 90 | A predicate can also receive extra arguments passed in the rule: 91 | 92 | def multiple(value, target_multiple): 93 | return dict(value=value * target_multiple) 94 | 95 | {column_name: [(multiple, 10)]} 96 | 97 | Complete Example 98 | ---------------- 99 | 100 | class People(db.Model): 101 | name = Column(String(100)) 102 | age = Column(Integer) 103 | IQ = Column(Integer) 104 | has_lover = Column(Boolean) 105 | 106 | validator = complex_validates({ 107 | 'name': [('min_length', 1), ('max_length', 100)], 108 | ('age', 'IQ'): [('min', 0)], 109 | 'has_lover': lambda value: not value # hate you! 110 | })""" 111 | 112 | ref_dict = { 113 | # column_name: ( 114 | # (predicate, arg1, ... argN), 115 | # ...
116 | # ) 117 | } 118 | 119 | for column_names, predicate_refs in validate_rule.items(): 120 | for column_name in _to_tuple(column_names): 121 | ref_dict[column_name] = \ 122 | ref_dict.get(column_name, tuple()) + _normalize_predicate_refs(predicate_refs) 123 | 124 | return validates(*ref_dict.keys())( 125 | lambda self, name, value: _validate_handler(name, value, ref_dict[name])) 126 | 127 | 128 | def _to_tuple(value): 129 | return tuple(value) if type(value) in [tuple, list] else (value,) 130 | 131 | 132 | def _normalize_predicate_refs(predicate_refs): 133 | """ 134 | In Out 135 | 'trans_upper' (('trans_upper',),) 136 | ('trans_upper', 'trans_lower') (('trans_upper',), ('trans_lower',)) 137 | [('min', 1)] (('min', 1),) 138 | (('min', 1), 'trans_lower') (('min', 1), ('trans_lower',)) 139 | """ 140 | return tuple(_to_tuple(predicate_ref) for predicate_ref in _to_tuple(predicate_refs)) 141 | 142 | 143 | def _validate_handler(column_name, value, predicate_refs): 144 | """handle predicate's return value""" 145 | 146 | # only does validate when attribute value is not None 147 | # else, just return it, let sqlalchemy decide if the value was legal according to `nullable` argument's value 148 | if value is not None: 149 | for predicate_ref in predicate_refs: 150 | predicate, predicate_name, predicate_args = _decode_predicate_ref(predicate_ref) 151 | validate_result = predicate(value, *predicate_args) 152 | 153 | if isinstance(validate_result, dict) and 'value' in validate_result: 154 | value = validate_result['value'] 155 | elif type(validate_result) != bool: 156 | raise Exception( 157 | 'predicate (name={}) can only return bool or dict(value=new_value) value'.format(predicate_name)) 158 | elif not validate_result: 159 | raise ModelInvalid(u'db model validate failed: column={}, value={}, predicate={}, arguments={}'.format( 160 | column_name, value, predicate_name, ','.join(map(str, predicate_args)) 161 | )) 162 | return value 163 | 164 | 165 | def 
_decode_predicate_ref(rule): 166 | predicate, predicate_args = rule[0], rule[1:] 167 | 168 | if isinstance(predicate, str): 169 | predicate_name = predicate 170 | predicate = predefined_predicates[predicate] 171 | else: 172 | predicate_name = predicate.__name__ 173 | 174 | return [predicate, predicate_name, predicate_args] 175 | -------------------------------------------------------------------------------- /flask_restful_extend/reqparse_fixed_type.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from datetime import datetime, date 3 | import six 4 | 5 | 6 | def fix_number(target_type): 7 | return lambda value: None if isinstance(value, (str, six.text_type)) and len(value) == 0 else target_type(value) 8 | 9 | 10 | fixed_datetime = lambda time_str: datetime.strptime(time_str, '%Y-%m-%d %H:%M:%S') 11 | fixed_date = lambda time_str: date.fromtimestamp(time_str) 12 | fixed_int = fix_number(int) 13 | fixed_float = fix_number(float) 14 | 15 | 16 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | try: 5 | from setuptools import setup 6 | except ImportError: 7 | from distutils.core import setup 8 | 9 | setup( 10 | name='Flask-RESTful-extend', 11 | version='0.3.7', 12 | url='https://github.com/anjianshi/flask-restful-extend', 13 | license='MIT', 14 | author='anjianshi', 15 | author_email='anjianshi@gmail.com', 16 | description="Improve Flask-RESTFul's behavior. 
Add some new features.", 17 | packages=['flask_restful_extend'], 18 | zip_safe=False, 19 | platforms='any', 20 | install_requires=['Flask>=0.10', 'Flask-RESTful>=0.3', 'Flask-SQLAlchemy', "six", "json_encode_manager"], 21 | keywords=['flask', 'python', 'rest', 'api'], 22 | classifiers=[ 23 | 'Intended Audience :: Developers', 24 | 'License :: OSI Approved :: MIT License', 25 | 'Operating System :: OS Independent', 26 | 'Programming Language :: Python', 27 | 'Programming Language :: Python :: 2.7', 28 | ], 29 | ) 30 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Make sure the tests load the local flask-restful-extend rather than a copy installed globally via pip 3 | import sys 4 | sys.path.insert(1, sys.path[0] + '/../') 5 | 6 | import unittest 7 | 8 | from .error_handle_test import ErrorHandleTestCase 9 | from .json_extend_test import JSONEncoderTestCase, JSONPTestCase 10 | from .model_test import ModelValidateTestCase, MarshalTestCase, ReqparseTestCase 11 | 12 | """ 13 | from flask_restful import Resource 14 | 15 | 16 | class MyRoute(Resource): 17 | def get(self): 18 | from flask_restful_extend.extend_model import ModelInvalid 19 | raise ModelInvalid('abc error') 20 | 21 | api.add_resource(MyRoute, '/test/') 22 | 23 | 24 | class TestTestCase(unittest.TestCase): 25 | def test_route(self): 26 | self.app = app.test_client() 27 | rv = self.app.get('/test/') 28 | print rv.data 29 | """ 30 | 31 | if __name__ == '__main__': 32 | unittest.main() 33 | -------------------------------------------------------------------------------- /tests/error_handle_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from .my_test_case import MyTestCase 3 | from flask import json 4 | from flask_restful import Resource 5 | import flask_restful_extend as restful_extend 6 | from werkzeug.exceptions
import BadRequest, Unauthorized, HTTPException 7 | 8 | 9 | class ErrorHandleTestCase(MyTestCase): 10 | def setUp(self): 11 | self.setup_app() 12 | 13 | testcase = self 14 | testcase.exception_to_raise = None 15 | testcase.error_message = "custom_error_msg" 16 | 17 | class Routes(Resource): 18 | def get(self): 19 | raise testcase.exception_to_raise 20 | 21 | api = restful_extend.ErrorHandledApi(self.app) 22 | api.add_resource(Routes, '/') 23 | 24 | def verify(self, exception_cls, extra_verify=None): 25 | """Check that `ErrorHandledApi` can handle this kind of exception 26 | and returns the right status code and error message.""" 27 | 28 | error_message = "custom_error_msg" 29 | exception = exception_cls(error_message) 30 | status_code = exception.code if hasattr(exception, 'code') else 500 31 | 32 | self.exception_to_raise = exception_cls(error_message) 33 | rv = self.client.get('/') 34 | self.assertEqual(rv.status_code, status_code) 35 | self.assertEqual(json.loads(rv.data)['message'], error_message) 36 | 37 | if extra_verify: 38 | extra_verify(rv) 39 | 40 | def make_exce_request(self, exception_cls): 41 | """Make a request that raises the specified exception""" 42 | self.exception_to_raise = exception_cls(self.error_message) 43 | resp = self.client.get('/') 44 | return [resp.status_code, json.loads(resp.data)["message"]] 45 | 46 | def test_HTTPException(self): 47 | [code, msg] = self.make_exce_request(BadRequest) 48 | self.assertEqual(code, 400) 49 | self.assertEqual(msg, self.error_message) 50 | 51 | def test_custom_HTTPException(self): 52 | class CustomHTTPException(HTTPException): 53 | code = 401 54 | [code, msg] = self.make_exce_request(CustomHTTPException) 55 | self.assertEqual(code, 401) 56 | self.assertEqual(msg, self.error_message) 57 | 58 | def test_HTTPException_that_already_has_data_attr(self): 59 | """If an HTTPException already has a `data` attribute, flask-restful-extend should not fill `data` with the exception's description""" 60 | class CustomHTTPException2(HTTPException): 61 |
code = 403 62 | data = dict(message="another message") 63 | [code, msg] = self.make_exce_request(CustomHTTPException2) 64 | self.assertEqual(code, 403) 65 | self.assertEqual(msg, "another message") 66 | 67 | def test_std_python_exception(self): 68 | with self.assertRaises(Exception) as cm: 69 | self.make_exce_request(Exception) 70 | self.assertEqual(str(cm.exception), self.error_message) 71 | 72 | def test_custom_python_exception_with_code_attr(self): 73 | """Early versions of flask-restful's error handling treated every exception with a `code` attribute as an HTTPException, 74 | but since pull request https://github.com/flask-restful/flask-restful/pull/445 it only handles real HTTPExceptions. 75 | So here we check that flask-restful-extend follows the same behavior.""" 76 | class CustomException1(Exception): 77 | code = 405 78 | 79 | with self.assertRaises(CustomException1) as cm: 80 | self.make_exce_request(CustomException1) 81 | self.assertEqual(str(cm.exception), self.error_message) 82 | 83 | def test_unauthorized_handle(self): 84 | """Test that `ErrorHandledApi` disables the browser's unauthorized (login) dialog""" 85 | self.verify(Unauthorized, 86 | lambda rv: self.assertFalse(rv.headers.get('WWW-Authenticate', False))) 87 | -------------------------------------------------------------------------------- /tests/json_extend_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from .my_test_case import MyTestCase 3 | from flask import request 4 | from flask_restful import Api, Resource 5 | import flask_restful_extend as restful_extend 6 | from datetime import datetime 7 | import time 8 | from decimal import Decimal 9 | 10 | 11 | class JSONEncoderTestCase(MyTestCase): 12 | def setUp(self): 13 | self.setup_app() 14 | 15 | testcase = self 16 | testcase.json_data = None 17 | 18 | class Routes(Resource): 19 | def get(self): 20 | return testcase.json_data 21 | 22 | self.api = Api(self.app) 23 | self.api.add_resource(Routes, '/') 24 | 25 | restful_extend.enhance_json_encode(self.api) 26 | 27 |
def verify(self, data, expect_text_result): 28 | self.json_data = data 29 | rv = self.client.get('/') 30 | self.assertEqual(rv.content_type, 'application/json') 31 | self.assertEqual(expect_text_result, rv.data.decode("utf-8")) 32 | 33 | def test_basic(self): 34 | def gen(): 35 | l = [1, 2, 3] 36 | for i in l: 37 | yield i 38 | 39 | now = datetime.now() 40 | 41 | samples = [ 42 | (105.132, '105.132'), 43 | ('abc', '"abc"'), 44 | (u'你好', u'"你好"'), 45 | (True, 'true'), 46 | (None, 'null'), 47 | 48 | ([1, 'a', 10.5], '[1, "a", 10.5]'), 49 | 50 | (now, str(time.mktime(now.timetuple()))), 51 | (Decimal(10.5), '10.5'), 52 | (gen(), '[1, 2, 3]'), 53 | ] 54 | 55 | for data, result in samples: 56 | self.verify(data, result) 57 | 58 | def test_custom_encoder(self): 59 | class CustomDataType(object): 60 | def __init__(self, a, b): 61 | self.a = a 62 | self.b = b 63 | 64 | self.api.json_encoder.register(lambda obj: obj.a + obj.b, CustomDataType) 65 | self.verify(CustomDataType(Decimal(10.5), 1), '11.5') 66 | 67 | 68 | class JSONPTestCase(MyTestCase): 69 | 70 | callback_arg_name = 'jsonp_callback' 71 | js_callback = 'doIt' 72 | return_data = 'custom_result' 73 | 74 | def setUp(self): 75 | self.setup_app() 76 | 77 | testcase = self 78 | 79 | class Routes(Resource): 80 | def get(self): 81 | return testcase.return_data 82 | 83 | self.api = Api(self.app) 84 | self.api.add_resource(Routes, '/') 85 | 86 | def verify(self): 87 | rv = self.client.get('/?{}={}'.format(self.callback_arg_name, self.js_callback)) 88 | self.assertEqual(rv.content_type, 'application/json') 89 | self.assertEqual(rv.data.decode("utf-8"), '{}("{}")'.format(self.js_callback, self.return_data)) 90 | 91 | rv = self.client.get('/') 92 | self.assertEqual(rv.content_type, 'application/json') 93 | self.assertEqual(rv.data.decode("utf-8"), '"{}"'.format(self.return_data)) 94 | 95 | def test_str_source(self): 96 | restful_extend.support_jsonp(self.api, self.callback_arg_name) 97 | self.verify() 98 | 99 | def 
test_fn_source(self): 100 | restful_extend.support_jsonp(self.api, lambda: request.args.get(self.callback_arg_name, False)) 101 | self.verify() 102 | 103 | 104 | -------------------------------------------------------------------------------- /tests/model_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from .my_test_case import MyTestCase 3 | from sqlalchemy import Column, Integer, String, Float, Boolean, TIMESTAMP, text 4 | from flask_sqlalchemy import SQLAlchemy 5 | from flask_restful_extend.model_validates import complex_validates, ModelInvalid 6 | from flask_restful_extend import register_model_converter, marshal_with_model, quick_marshal 7 | from flask_restful_extend.model_reqparse import make_request_parser, populate_model, \ 8 | RequestPopulator, PopulatorArgument, ArgumentNoValue 9 | from flask_restful_extend.reqparse_fixed_type import * 10 | from flask_restful import Api, Resource 11 | from flask_restful.reqparse import Argument 12 | from flask_restful import fields 13 | from flask import url_for, request 14 | from datetime import datetime 15 | import time 16 | from copy import copy 17 | from copy import deepcopy 18 | import json 19 | import six 20 | 21 | 22 | class ModelValidateTestCase(MyTestCase): 23 | def setUp(self): 24 | self.setup_app() 25 | 26 | self.app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite://' 27 | self.db = SQLAlchemy(self.app) 28 | 29 | def setup_model(self, validate_rule): 30 | class Student(self.db.Model): 31 | id = Column(Integer, primary_key=True) 32 | name = Column(String(100), nullable=False) 33 | age = Column(Integer) 34 | notes = Column(String(100)) 35 | 36 | validator = complex_validates(validate_rule) 37 | 38 | self.Student = Student 39 | self.db.create_all() 40 | 41 | def verify_exception(self, data): 42 | with self.assertRaises(ModelInvalid) as cm: 43 | self.Student(**data) 44 | self.assertRegexpMatches(six.text_type(cm.exception), '^db model validate 
failed:') 45 | 46 | def test_predefined_predicates(self): 47 | self.setup_model({ 48 | # Test both a column that allows NULL and one that doesn't 49 | 'name': [('min_length', 5), ('max_length', 10), ('match', '^abc'), 'trans_upper'], 50 | 'notes': [('min_length', 5), ('max_length', 10), ('match', '^abc'), 'trans_upper'], 51 | 'age': [('min', 15), ('max', 20)] 52 | }) 53 | 54 | # ===== Are valid values accepted? ===== 55 | 56 | valid_data = [ 57 | dict(name='abcdee', age=18, notes='abcdegf'), # all columns filled 58 | dict(name='abcd13', age=16), # required columns plus some optional ones 59 | dict(name='abcdefewe'), # required columns only 60 | 61 | dict(name='abcd1', age=15), # test that each validator handles the lower bound correctly 62 | dict(name='abcde12345', age=20), # test that each validator handles the upper bound correctly 63 | 64 | dict(name=u'abc四五'), # test that each validator handles Chinese characters correctly 65 | dict(name=u'abc四五六'), 66 | dict(name=u'abc四五六七八九十'), 67 | ] 68 | for data in valid_data: 69 | self.db.session.add(self.Student(**data)) 70 | self.db.session.commit() 71 | 72 | # Were all records written successfully? 73 | instances = [i for i in self.Student.query] 74 | self.assertEqual(len(instances), len(valid_data)) 75 | 76 | # Check that trans_upper works 77 | for i, instance in enumerate(instances): 78 | self.assertEqual(instance.name, valid_data[i]['name'].upper()) 79 | 80 | # ===== Are invalid values rejected? ===== 81 | 82 | invalid_data = [ 83 | dict(name='abcd', age=15), # below the lower bound 84 | dict(notes=u'abc四'), 85 | dict(name='abcde123456', age=21), # above the upper bound 86 | dict(name=u'abc四五六七八九十A'), 87 | 88 | dict(name='xabc'), # does not match 89 | dict(notes=u'你好'), 90 | ] 91 | for data in invalid_data: 92 | self.verify_exception(data) 93 | 94 | def test_custom_predicates(self): 95 | def trans_int(value): 96 | """Add 1 to value""" 97 | return dict(value=value + 1) 98 | 99 | def valid_int(value, arg): 100 | """Check whether value is a multiple of arg""" 101 | return value % arg == 0 102 | 103 | self.setup_model({ 104 | 'age': [trans_int, (valid_int, 4)] 105 | }) 106 | 107 | # valid data 108 | self.Student(name='a', age=7) 109 | 110 | # invalid data 111 | self.verify_exception(dict(name='a', age=8)) 112 | 113 | def
test_rule_format(self): 114 | """Check that validate rules in various formats are parsed correctly: 115 | specifying a single column or multiple columns at once; 116 | specifying one or several rules, with or without arguments, at once; 117 | the same column being given rules in several entries""" 118 | 119 | def trans_int(value): 120 | return dict(value=value * 2) 121 | 122 | self.setup_model({ 123 | 'name': [('min_length', 5)], 124 | ('name', 'notes'): [('max_length', 10)], 125 | 'notes': [('trans_upper', )], 126 | 127 | 'age': [trans_int] 128 | }) 129 | 130 | # valid data 131 | self.db.session.add(self.Student(name='abcdefeg', notes='abcde', age=10)) 132 | self.db.session.commit() 133 | 134 | instances = [i for i in self.Student.query] 135 | self.assertEqual(instances[0].notes, 'ABCDE') 136 | self.assertEqual(instances[0].age, 20) 137 | 138 | # invalid data 139 | invalid_data = [ 140 | dict(name='012345678901'), 141 | dict(name='abcdee', notes='012345678901') 142 | ] 143 | for data in invalid_data: 144 | self.verify_exception(data) 145 | 146 | 147 | class _ModelTestCase(MyTestCase): 148 | def setUp(self): 149 | self.setup_app() 150 | self.setup_model() 151 | 152 | def setup_model(self): 153 | time_now = datetime.now() 154 | timestamp = time.mktime(time_now.timetuple()) 155 | float_default_value = 125.225 156 | bool_default_value = True 157 | 158 | self.app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite://' 159 | self.db = SQLAlchemy(self.app) 160 | 161 | class TestModel(self.db.Model): 162 | id = Column(Integer, primary_key=True) 163 | 164 | col_int = Column(Integer, nullable=False) 165 | col_str = Column(String(100), nullable=False) 166 | col_float = Column(Float, nullable=False, default=float_default_value) 167 | col_bool = Column(Boolean, nullable=False, server_default=text(str(int(bool_default_value)))) 168 | col_timestamp = Column(TIMESTAMP, nullable=False) 169 | 170 | col_int_null = Column(Integer) 171 | col_str_null = Column(String(100)) 172 | col_float_null = Column(Float) 173 | col_bool_null = Column(Boolean) 174 | col_timestamp_null = Column(TIMESTAMP) 175 | 176 | self.TestModel = TestModel 177 | self.db.create_all() 178
| 179 | # ===== fill in test data ===== 180 | 181 | self.data1 = dict( 182 | id=1, 183 | col_int=10, col_str="a", col_float=1.5, col_bool=True, col_timestamp=time_now, 184 | col_int_null=25, col_str_null=u"你好", col_float_null=0.001, col_bool_null=False, 185 | col_timestamp_null=time_now) 186 | self.result1 = self.data1.copy() 187 | self.result1['col_timestamp'] = self.result1['col_timestamp_null'] = timestamp 188 | 189 | # By assigning only the columns that are nullable=False and have no default, we can test whether NULL values and column defaults are handled correctly 190 | self.data2 = dict(id=2, 191 | col_int=10, col_str="a", col_timestamp=time_now) 192 | self.result2 = self.data2.copy() 193 | self.result2.update(dict( 194 | col_float=float_default_value, 195 | col_bool=bool_default_value, 196 | col_timestamp=timestamp, 197 | col_int_null=None, col_str_null=None, col_float_null=None, col_bool_null=None, col_timestamp_null=None, 198 | )) 199 | 200 | self.db.session.add(self.TestModel(**self.data1)) 201 | self.db.session.add(self.TestModel(**self.data2)) 202 | self.db.session.commit() 203 | 204 | class FixedRequestContext(object): 205 | def __init__(self, orig_context): 206 | self.context = orig_context 207 | 208 | def __enter__(self, *args, **kwargs): 209 | ret = self.context.__enter__(*args, **kwargs) 210 | request.unparsed_arguments = dict(Argument('').source(request)) 211 | return ret 212 | 213 | def __exit__(self, *args, **kwargs): 214 | return self.context.__exit__(*args, **kwargs) 215 | 216 | def fixed_request_context(self, *args, **kwargs): 217 | return self.FixedRequestContext(self.app.test_request_context(*args, **kwargs)) 218 | 219 | 220 | class MarshalTestCase(_ModelTestCase): 221 | def setUp(self): 222 | super(MarshalTestCase, self).setUp() 223 | self.maxDiff = None 224 | 225 | def verify_marshal(self, model_data, expect_result, excludes=None, only=None, extends=None): 226 | @marshal_with_model(self.TestModel, excludes=excludes, only=only, extends=extends) 227 | def fn(): 228 | return model_data 229 | 230 | expect_result = copy(expect_result)
231 | need_delete = [] 232 | if excludes: 233 | need_delete = excludes 234 | elif only: 235 | need_delete = [k for k in expect_result.keys() if k not in only] 236 | for key in need_delete: 237 | del expect_result[key] 238 | 239 | self.assertEqual(fn(), expect_result) 240 | 241 | def test_normal_marshal(self): 242 | self.verify_marshal(self.TestModel.query.get(1), self.result1) 243 | self.verify_marshal(self.TestModel.query.get(2), self.result2) 244 | 245 | def test_query_marshal(self): 246 | self.verify_marshal(self.TestModel.query, [self.result1, self.result2]) 247 | 248 | def test_excludes_and_only(self): 249 | self.verify_marshal(self.TestModel.query.get(1), self.result1, excludes=['id', 'col_int']) 250 | self.verify_marshal(self.TestModel.query.get(1), self.result1, only=['id', 'col_int']) 251 | self.verify_marshal(self.TestModel.query.get(1), self.result1, 252 | excludes=['id', 'col_int'], only=['col_str_null']) 253 | 254 | def test_extends(self): 255 | extends = { 256 | # test extend column 257 | "extend_col": fields.String, 258 | # test overwrite exists column 259 | "col_int": fields.Boolean 260 | } 261 | 262 | data = deepcopy(self.TestModel.query.get(1)) 263 | result = copy(self.result1) 264 | data.extend_col = result["extend_col"] = "abc" 265 | data.col_int = result["col_int"] = True 266 | 267 | self.verify_marshal(data, result, extends=extends) 268 | 269 | def test_quick_marshal(self): 270 | self.assertEqual( 271 | quick_marshal(self.TestModel)(self.TestModel.query.get(2)), 272 | self.result2 273 | ) 274 | 275 | def test_converter(self): 276 | register_model_converter(self.TestModel, self.app) 277 | 278 | testcase = self 279 | 280 | class Routes(Resource): 281 | def get(self, model): 282 | # test to_python 283 | testcase.assertEquals(model, testcase.TestModel.query.get(1)) 284 | 285 | # test to_url 286 | testcase.assertEqual( 287 | url_for('routes', model=testcase.TestModel.query.get(2)), 288 | '/2' 289 | ) 290 | 291 | api = Api(self.app) 292 | 
api.add_resource(Routes, '/') 293 | self.client.get('/1') 294 | 295 | 296 | class ReqparseTestCase(_ModelTestCase): 297 | def __init__(self, *args, **kwargs): 298 | super(ReqparseTestCase, self).__init__(*args, **kwargs) 299 | 300 | def test_fixed_type(self): 301 | # Test type conversion 302 | time_str = '2013-12-21 14:19:05' 303 | d = fixed_datetime(time_str) 304 | self.assertTrue(isinstance(d, datetime)) 305 | self.assertEqual(str(d), time_str) 306 | 307 | self.assertEqual(fixed_int(987), 987) 308 | self.assertEqual(fixed_int("850"), 850) 309 | self.assertIsNone(fixed_int("")) 310 | self.assertIsNone(fixed_int(u"")) 311 | 312 | self.assertEqual(fixed_float(987.5), 987.5) 313 | self.assertEqual(fixed_float("850.3"), 850.3) 314 | self.assertIsNone(fixed_float("")) 315 | self.assertIsNone(fixed_float(u"")) 316 | 317 | # Test that the types work when actually used by an Argument 318 | with self.fixed_request_context( 319 | method='POST', 320 | data='{"n1": 100, "n2": "100", "n3": "", "n4": null}', 321 | content_type="application/json"): 322 | self.assertEqual(Argument('n1', type=fixed_int).parse(request)[0], 100) 323 | self.assertEqual(Argument('n2', type=fixed_int).parse(request)[0], 100) 324 | self.assertEqual(Argument('n3', type=fixed_int).parse(request)[0], None) 325 | self.assertEqual(Argument('n4', type=fixed_int).parse(request)[0], None) 326 | 327 | def test_populator_argument(self): 328 | # Test JSON requests 329 | with self.fixed_request_context( 330 | method='POST', 331 | data='{"foo": 100, "bar": "abc", "li": [300, 100, 200]}', 332 | content_type="application/json"): 333 | 334 | # Check that argument values are retrieved successfully 335 | # Since the internal implementation touches the Argument's action attribute, also check that this has no side effects 336 | 337 | # 1. With the default action=store, only the first value of the argument is returned 338 | self.assertEqual( 339 | PopulatorArgument('foo', type=int).parse(request), 340 | 100) 341 | self.assertEqual( 342 | PopulatorArgument('bar').parse(request), 343 | u'abc') 344 | self.assertEqual( 345 | PopulatorArgument('li', type=int).parse(request), 346 | 300) 347 | 348 | # 2.
With action=append, the argument's list of values is always returned 349 | self.assertEqual( 350 | PopulatorArgument('bar', action='append').parse(request), 351 | [u'abc']) 352 | self.assertEqual( 353 | PopulatorArgument('li', type=int, action='append').parse(request), 354 | [300, 100, 200]) 355 | 356 | # 3. With any other action, the value itself is returned if there is exactly one; otherwise the list of values 357 | self.assertEqual( 358 | PopulatorArgument('foo', type=int, action='something').parse(request), 359 | 100) 360 | self.assertEqual( 361 | PopulatorArgument('bar', action='something').parse(request), 362 | u'abc') 363 | self.assertEqual( 364 | PopulatorArgument('li', type=int, action='something').parse(request), 365 | [300, 100, 200]) 366 | 367 | # Check that the expected exception is raised when no value is given for the argument 368 | with self.assertRaises(ArgumentNoValue): 369 | PopulatorArgument('no_val_arg').parse(request) 370 | 371 | # Test QueryString / FormData requests 372 | with self.fixed_request_context('/?foo=100&bar=abc&li=300&li=100&li=200', method='GET'): 373 | # Check that argument values are retrieved successfully 374 | # Since the internal implementation touches the Argument's action attribute, also check that this has no side effects 375 | 376 | # 1. With the default action=store, only the first value of the argument is returned 377 | self.assertEqual( 378 | PopulatorArgument('foo', type=int).parse(request), 379 | 100) 380 | self.assertEqual( 381 | PopulatorArgument('bar').parse(request), 382 | u'abc') 383 | self.assertEqual( 384 | PopulatorArgument('li', type=int).parse(request), 385 | 300) 386 | 387 | # 2. With action=append, the argument's list of values is always returned 388 | self.assertListEqual( 389 | PopulatorArgument('foo', type=int, action='append').parse(request), 390 | [100] 391 | ) 392 | self.assertListEqual( 393 | PopulatorArgument('li', type=int, action='append').parse(request), 394 | [300, 100, 200] 395 | ) 396 | # 3.
With any other action, the value itself is returned if there is exactly one; otherwise the list of values 397 | self.assertEqual( 398 | PopulatorArgument('foo', type=int, action='something').parse(request), 399 | 100 400 | ) 401 | self.assertListEqual( 402 | PopulatorArgument('li', type=int, action='something').parse(request), 403 | [300, 100, 200] 404 | ) 405 | 406 | def test_request_populator(self): 407 | with self.fixed_request_context( 408 | method='POST', 409 | data='{"foo": 100, "bar": "abc"}', 410 | content_type="application/json"): 411 | parser = RequestPopulator() 412 | parser.add_argument(name='foo', type=int) 413 | parser.add_argument(name='bar') 414 | # This argument should not appear in the parsed result, because no value was given for it 415 | parser.add_argument(name='xyz', type=int) 416 | self.assertEqual( 417 | parser.parse_args(), 418 | dict(foo=100, bar="abc")) 419 | 420 | def test_make_request_parser(self): 421 | # for model 422 | self._test_make_request_parser(self.TestModel, True) 423 | 424 | # for instance 425 | self._test_make_request_parser(self.TestModel(), False) 426 | 427 | def _test_make_request_parser(self, model_or_inst, is_model): 428 | # Create a temporary parser and take its args 429 | common_args = make_request_parser(model_or_inst).args 430 | 431 | # There should be one fewer parser arg than model columns, because the primary key is excluded 432 | self.assertEqual( 433 | len(common_args), 434 | len(model_or_inst.__mapper__.columns) - 1) 435 | 436 | # Check that each arg type is set to the fixed version of the model column type 437 | expect_types = [ 438 | # arg index, type 439 | (0, fixed_int), 440 | (1, six.text_type), 441 | (2, fixed_float), 442 | (4, fixed_datetime) 443 | ] 444 | for col_index, expect_type in expect_types: 445 | self.assertEqual(common_args[col_index].type, expect_type) 446 | 447 | # If a model is given, args for columns that have no default and do not allow NULL should be required 448 | # If an instance is given, no arg is required 449 | required_arg_index = [0, 1, 4] 450 | for arg, i in zip(common_args, range(len(common_args))): 451 | self.assertEqual(arg.required, is_model and i in required_arg_index) 452 | 453 | # Test the excludes and only parameters 454 | def verify_args(args, remain): 455 |
self.assertEqual(len(args), len(remain)) 456 | for arg, expect_name in zip(args, remain): 457 | self.assertEqual(arg.name, expect_name) 458 | 459 | # 1. excludes 460 | excludes = ['col_str', 'col_bool', 'col_int_null', 'col_timestamp_null'] 461 | exclude_remain = ['col_int', 'col_float', 'col_timestamp', 'col_str_null', 'col_float_null', 'col_bool_null'] 462 | verify_args( 463 | make_request_parser(model_or_inst, excludes=excludes).args, 464 | exclude_remain) 465 | # Test that a plain string (instead of a list) is handled correctly 466 | verify_args( 467 | make_request_parser(model_or_inst, excludes='col_int_null').args, 468 | ['col_int', 'col_str', 'col_float', 'col_bool', 'col_timestamp', 469 | 'col_str_null', 'col_float_null', 'col_bool_null', 'col_timestamp_null', ]) 470 | 471 | # 2. only 472 | only = [ 473 | 'id', # test that the primary key can be explicitly ("forcibly") added 474 | 'col_bool', 'col_str_null' 475 | ] 476 | verify_args( 477 | make_request_parser(model_or_inst, only=only).args, 478 | only) 479 | # Test that a plain string (instead of a list) is handled correctly 480 | verify_args( 481 | make_request_parser(model_or_inst, only='col_int_null').args, 482 | ['col_int_null']) 483 | 484 | # 3.
When both excludes and only are given, test that only excludes takes effect 485 | verify_args( 486 | make_request_parser(model_or_inst, excludes=excludes, only=only).args, 487 | exclude_remain) 488 | 489 | def test_populate_model(self): 490 | data = { 491 | 'col_int': 2, 492 | 'col_float': 10.5, 493 | 'col_bool_null': True, 494 | } 495 | 496 | with self.fixed_request_context( 497 | method='POST', 498 | data=json.dumps(data), 499 | content_type="application/json"): 500 | # model 501 | entity = populate_model(self.TestModel, only=[col for col, val in data.items()]) 502 | for col in entity.__mapper__.columns: 503 | self.assertEqual( 504 | getattr(entity, col.name), 505 | data.get(col.name, None) 506 | ) 507 | 508 | # inst 509 | entity = self.TestModel() 510 | populate_model(entity, only=[col for col, val in data.items()]) 511 | for col in entity.__mapper__.columns: 512 | self.assertEqual( 513 | getattr(entity, col.name), 514 | data.get(col.name, None) 515 | ) 516 | -------------------------------------------------------------------------------- /tests/my_test_case.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import unittest 3 | from flask import Flask 4 | 5 | __all__ = ['MyTestCase'] 6 | 7 | 8 | class MyTestCase(unittest.TestCase): 9 | def setup_app(self): 10 | app = Flask(__name__) 11 | app.config['TESTING'] = True 12 | 13 | self.app = app 14 | self.client = app.test_client() --------------------------------------------------------------------------------
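The rule format accepted by `complex_validates` (documented in the `model_validates.py` docstring) is easiest to understand by running the normalization step on its own. Below is a minimal, framework-free sketch that mirrors what `_to_tuple`, `_normalize_predicate_refs` and `_validate_handler` appear to do. The names `PREDICATES`, `normalize_rule` and `run_predicates` are hypothetical, SQLAlchemy is left out entirely, and `ValueError`/`TypeError` stand in for `ModelInvalid` and the generic `Exception` used by the library.

```python
import re

# A subset of predefined_predicates from model_validates.py
PREDICATES = {
    'min': lambda value, min_val: value >= min_val,
    'max_length': lambda value, max_val: len(value) <= max_val,
    'match': lambda value, pattern: bool(re.match(pattern, value)),
    'trans_upper': lambda value: dict(value=value.upper()),
}


def to_tuple(value):
    """Convert a list/tuple to a tuple; wrap anything else in a 1-tuple."""
    return tuple(value) if isinstance(value, (tuple, list)) else (value,)


def normalize_rule(validate_rule):
    """Expand {columns: predicate_refs} into {column: ((predicate, *args), ...)}."""
    ref_dict = {}
    for column_names, predicate_refs in validate_rule.items():
        refs = tuple(to_tuple(ref) for ref in to_tuple(predicate_refs))
        for column_name in to_tuple(column_names):
            # A column named in several entries accumulates all of its predicates
            ref_dict[column_name] = ref_dict.get(column_name, ()) + refs
    return ref_dict


def run_predicates(value, refs):
    """Apply the predicates in order; None is passed through untouched."""
    if value is None:
        return value
    for ref in refs:
        predicate, args = ref[0], ref[1:]
        if isinstance(predicate, str):
            predicate = PREDICATES[predicate]
        result = predicate(value, *args)
        if isinstance(result, dict) and 'value' in result:
            value = result['value']  # a trans_ predicate replaced the value
        elif not isinstance(result, bool):
            raise TypeError('predicate must return bool or dict(value=new_value)')
        elif not result:
            raise ValueError('validation failed: {!r} on {!r}'.format(ref, value))
    return value


rules = normalize_rule({
    'name': [('max_length', 10), 'trans_upper'],
    ('age', 'IQ'): [('min', 0)],
})
print(rules['IQ'])                           # (('min', 0),)
print(run_predicates('abc', rules['name']))  # ABC
# The "wrong" form from the Notice: 0 becomes a separate, non-callable predicate
print(normalize_rule({'age': ('min', 0)})['age'])  # (('min',), (0,))
```

The last line shows why the docstring insists on `[(predicate, arg)]`: without the outer list, each element of the tuple is normalized as a predicate of its own. Like the original handler, the sketch skips validation for `None` and leaves NULL handling to the column's `nullable` setting.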
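The value-type and `action` rules described in the `PopulatorArgument` docstring can likewise be modelled without Flask. This is a rough sketch under stated assumptions: `collect_values` and `parse_argument` are hypothetical helpers, a missing key stands in for the real class's `results is self.default` identity check, and `None` is returned without calling `type`, following the 0.3.5 note in CHANGES.md. Only `json` and `urllib.parse.parse_qs` from the standard library are used.

```python
import json
from urllib.parse import parse_qs


class ArgumentNoValue(Exception):
    """Raised when the request did not submit the argument at all."""


def collect_values(source):
    """Map each argument name to a list of raw values (hypothetical helper).

    A query string yields strings (parse_qs already groups ?a=1&a=2);
    a decoded JSON object may hold any type, and a JSON array counts
    as repeated assignment, so {"a": ["x", "y"]} behaves like ?a=x&a=y.
    """
    if isinstance(source, dict):
        return {k: v if isinstance(v, list) else [v] for k, v in source.items()}
    return parse_qs(source, keep_blank_values=True)


def parse_argument(values, name, type_=str, action='store'):
    """Mimic PopulatorArgument.parse() on the output of collect_values()."""
    if name not in values:
        raise ArgumentNoValue(name)
    # type_ is applied to each value separately; None skips the constructor
    results = [None if v is None else type_(v) for v in values[name]]
    if action == 'store' or (action != 'append' and len(results) == 1):
        return results[0]
    return results


query = collect_values('foo=100&li=300&li=100&li=200')
body = collect_values(json.loads('{"foo": 100, "li": [300, 100, 200], "n": null}'))
print(parse_argument(query, 'li', int))                      # 300
print(parse_argument(query, 'li', int, action='append'))     # [300, 100, 200]
print(parse_argument(body, 'foo', int, action='something'))  # 100
print(parse_argument(body, 'n', int))                        # None
```

The real class reaches the same result differently: it forces `action='append'` so that any submitted value comes back as a fresh list, and treats "the parse result is the default object itself" as "not submitted". The sketch simply checks for the key, which is easier to follow but is not how the library is implemented.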