├── .gitignore ├── .gitmodules ├── .pylintrc ├── .travis.yml ├── CODE_OF_CONDUCT.md ├── LICENSE ├── MANIFEST.in ├── NOTICE ├── README.rst ├── dev-requirements.txt ├── performance ├── benchmark.proto ├── benchmark.py └── benchmark_pb2.py ├── setup.cfg ├── setup.py ├── tests ├── __init__.py ├── run_marshmallow_tests.sh ├── test_compat.py ├── test_jit.py ├── test_toasted_marshmallow.py └── test_utils.py └── toastedmarshmallow ├── __init__.py ├── compat.py ├── jit.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.py[cod] 2 | 3 | # C extensions 4 | *.so 5 | 6 | # Packages 7 | *.egg 8 | *.egg-info 9 | dist 10 | build 11 | eggs 12 | parts 13 | bin 14 | var 15 | sdist 16 | develop-eggs 17 | .installed.cfg 18 | lib 19 | lib64 20 | 21 | # Installer logs 22 | pip-log.txt 23 | 24 | # Unit test / coverage reports 25 | .coverage 26 | htmlcov 27 | .tox 28 | nosetests.xml 29 | .cache 30 | 31 | # Translations 32 | *.mo 33 | 34 | # Mr Developer 35 | .mr.developer.cfg 36 | 37 | # IDE 38 | .project 39 | .pydevproject 40 | .idea 41 | 42 | # Coverage 43 | cover 44 | .coveragerc 45 | 46 | # Sphinx 47 | docs/_build 48 | README.html 49 | 50 | *.ipynb 51 | .ipynb_checkpoints 52 | 53 | Vagrantfile 54 | .vagrant 55 | 56 | *.db 57 | *.ai 58 | .konchrc 59 | _sandbox 60 | pylintrc 61 | 62 | # Virtualenvs 63 | env 64 | venv 65 | 66 | # Other 67 | .directory 68 | *.pprof 69 | .mypy_cache 70 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "marshmallow"] 2 | path = marshmallow 3 | url = https://github.com/lyft/marshmallow.git 4 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [pylint] 2 | disable = missing-docstring, 3 | import-error, 4 | too-few-public-methods, 5 
| protected-access, 6 | too-many-branches, 7 | too-many-instance-attributes, 8 | too-many-locals, 9 | too-many-statements 10 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | sudo: false 3 | cache: pip 4 | 5 | matrix: 6 | include: 7 | - python: "2.7" 8 | - python: "3.4" 9 | - python: "3.5" 10 | - python: "3.6" 11 | - python: "pypy" 12 | 13 | 14 | before_install: 15 | - pip install pip setuptools --upgrade 16 | 17 | install: 18 | - pip install -r dev-requirements.txt 19 | - python setup.py install 20 | - if [[ $TRAVIS_PYTHON_VERSION == "3.6" ]]; then pip install mypy; fi 21 | 22 | script: 23 | - if [[ $TRAVIS_PYTHON_VERSION == "3.6" ]]; then mypy --ignore-missing-imports --follow-imports=silent --python-version=3.6 --warn-no-return --strict-optional toastedmarshmallow tests; fi 24 | - if [[ $TRAVIS_PYTHON_VERSION == "3.6" ]]; then mypy --ignore-missing-imports --follow-imports=silent --python-version=2.7 --warn-no-return --strict-optional toastedmarshmallow tests; fi 25 | - if [[ $TRAVIS_PYTHON_VERSION != "pypy" ]]; then flake8 performance tests; fi 26 | - if [[ $TRAVIS_PYTHON_VERSION != "pypy" ]]; then pylint toastedmarshmallow; fi 27 | - python -m pytest --cov toastedmarshmallow --cov-report term-missing tests 28 | - python performance/benchmark.py 29 | - ./tests/run_marshmallow_tests.sh 30 | 31 | after_success: 32 | - coveralls || echo "!! intermittent coveralls failure" 33 | 34 | notifications: 35 | email: false 36 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | This project is governed by [Lyft's code of conduct](https://github.com/lyft/code-of-conduct). All contributors and participants agree to abide by its terms. 
2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "{}" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2014 Lyft, Inc. 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include *.rst LICENSE NOTICE 2 | recursive-include toastedmarshmallow * 3 | recursive-exclude docs *.pyc 4 | recursive-exclude docs *.pyo 5 | recursive-exclude tests *.pyc 6 | recursive-exclude tests *.pyo 7 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | toastedmarshmallow 2 | Copyright 2017 Lyft Inc. 3 | 4 | This product includes software developed at Lyft Inc. 
5 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | ************************************************************* 2 | :fire:toastedmarshmallow:fire:: Makes Marshmallow Toasty Fast 3 | ************************************************************* 4 | 5 | Toasted Marshmallow implements a JIT for marshmallow that speeds up dumping 6 | objects 10-25X (depending on your schema). Toasted Marshmallow allows you to 7 | have the great API that 8 | `Marshmallow <https://marshmallow.readthedocs.io/>`_ provides 9 | without having to sacrifice performance! 10 | 11 | :: 12 | 13 | Benchmark Result: 14 | Original Time: 2682.61 usec/dump 15 | Optimized Time: 176.38 usec/dump 16 | Speed up: 15.21x 17 | 18 | Even ``PyPy`` benefits from ``toastedmarshmallow``! 19 | 20 | :: 21 | 22 | Benchmark Result: 23 | Original Time: 189.78 usec/dump 24 | Optimized Time: 20.03 usec/dump 25 | Speed up: 9.48x 26 | 27 | Installing toastedmarshmallow 28 | ----------------------------- 29 | 30 | .. code-block:: bash 31 | 32 | pip install toastedmarshmallow 33 | 34 | This will *also* install a slightly-forked ``marshmallow`` that includes some 35 | hooks Toastedmarshmallow needs to enable the JIT to run before falling back 36 | to the original marshmallow code. These changes are minimal, making it easier 37 | to track upstream. You can find the changes 38 | `Here <https://github.com/lyft/marshmallow>`_. 39 | 40 | This means you should **remove** ``marshmallow`` from your requirements and 41 | replace it with ``toastedmarshmallow``. By default there is no 42 | difference unless you explicitly enable Toasted Marshmallow. 43 | 44 | Enabling Toasted Marshmallow 45 | ---------------------------- 46 | 47 | Enabling Toasted Marshmallow on an existing Schema is just one line of code: 48 | set the ``jit`` property on any ``Schema`` instance to 49 | ``toastedmarshmallow.Jit``. For example: 50 | 51 | ..
code-block:: python 52 | 53 | from datetime import date 54 | import toastedmarshmallow 55 | from marshmallow import Schema, fields, pprint 56 | 57 | class ArtistSchema(Schema): 58 | name = fields.Str() 59 | 60 | class AlbumSchema(Schema): 61 | title = fields.Str() 62 | release_date = fields.Date() 63 | artist = fields.Nested(ArtistSchema()) 64 | 65 | schema = AlbumSchema() 66 | # Specify the jit method as toastedmarshmallow's jit 67 | schema.jit = toastedmarshmallow.Jit 68 | # And that's it! Your dump methods are 15x faster! 69 | 70 | It's also possible to use the ``Meta`` class on the ``Marshmallow`` schema 71 | to specify that all instances of a given ``Schema`` should be optimized: 72 | 73 | .. code-block:: python 74 | 75 | import toastedmarshmallow 76 | from marshmallow import Schema, fields, pprint 77 | 78 | class ArtistSchema(Schema): 79 | class Meta: 80 | jit = toastedmarshmallow.Jit 81 | name = fields.Str() 82 | 83 | You can also enable Toasted Marshmallow globally by setting the environment 84 | variable ``MARSHMALLOW_SCHEMA_DEFAULT_JIT`` to ``toastedmarshmallow.Jit``. 85 | Future versions of Toasted Marshmallow may make this the default. 86 | 87 | How it works 88 | ------------ 89 | 90 | Toasted Marshmallow works by generating code at runtime to optimize dumping 91 | objects without going through layers and layers of reflection. The generated 92 | code optimistically assumes the objects being passed in are schematically valid, 93 | falling back to the original marshmallow code on failure. 94 | 95 | For example, taking ``AlbumSchema`` from above, Toastedmarshmallow will 96 | generate the following 3 methods: 97 | 98 | ..
code-block:: python 99 | 100 | def InstanceSerializer(obj): 101 | res = {} 102 | value = obj.release_date; value = value() if callable(value) else value; res["release_date"] = _field_release_date__serialize(value, "release_date", obj) 103 | value = obj.artist; value = value() if callable(value) else value; res["artist"] = _field_artist__serialize(value, "artist", obj) 104 | value = obj.title; value = value() if callable(value) else value; value = str(value) if value is not None else None; res["title"] = value 105 | return res 106 | 107 | def DictSerializer(obj): 108 | res = {} 109 | if "release_date" in obj: 110 | value = obj["release_date"]; value = value() if callable(value) else value; res["release_date"] = _field_release_date__serialize(value, "release_date", obj) 111 | if "artist" in obj: 112 | value = obj["artist"]; value = value() if callable(value) else value; res["artist"] = _field_artist__serialize(value, "artist", obj) 113 | if "title" in obj: 114 | value = obj["title"]; value = value() if callable(value) else value; value = str(value) if value is not None else None; res["title"] = value 115 | return res 116 | 117 | def HybridSerializer(obj): 118 | res = {} 119 | try: 120 | value = obj["release_date"] 121 | except (KeyError, AttributeError, IndexError, TypeError): 122 | value = obj.release_date 123 | value = value; value = value() if callable(value) else value; res["release_date"] = _field_release_date__serialize(value, "release_date", obj) 124 | try: 125 | value = obj["artist"] 126 | except (KeyError, AttributeError, IndexError, TypeError): 127 | value = obj.artist 128 | value = value; value = value() if callable(value) else value; res["artist"] = _field_artist__serialize(value, "artist", obj) 129 | try: 130 | value = obj["title"] 131 | except (KeyError, AttributeError, IndexError, TypeError): 132 | value = obj.title 133 | value = value; value = value() if callable(value) else value; value = str(value) if value is not None else None; res["title"] = 
value 134 | return res 135 | 136 | Toastedmarshmallow will invoke the proper serializer based upon the input. 137 | 138 | Since Toastedmarshmallow is generating code at runtime, it's critical you 139 | re-use Schema objects. If you're creating a new Schema object every time you 140 | serialize/deserialize an object you'll likely have much worse performance. 141 | -------------------------------------------------------------------------------- /dev-requirements.txt: -------------------------------------------------------------------------------- 1 | coverage==4.4.1 2 | coveralls==1.1 3 | flake8==3.3.0 4 | protobuf==3.3.0 5 | pytest==3.0.7 6 | pytest-cov==2.5.1 7 | python-coveralls==2.9.1 8 | pylint==1.7.1 9 | astroid==1.5.3 10 | python-dateutil==2.6.1 11 | pytz==2017.2 12 | simplejson==3.11.1 13 | -------------------------------------------------------------------------------- /performance/benchmark.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | message Author { 4 | int64 id = 1; 5 | string first = 2; 6 | string last = 3; 7 | float book_count = 4; 8 | float age = 5; 9 | string address = 6; 10 | } 11 | 12 | message Quote { 13 | int64 id = 1; 14 | Author author = 2; 15 | string content = 3; 16 | int64 posted_at = 4; 17 | string book_name = 5; 18 | float page_number = 6; 19 | float line_number = 7; 20 | float col_number = 8; 21 | } 22 | 23 | message Quotes { 24 | repeated Quote quotes = 1; 25 | } 26 | -------------------------------------------------------------------------------- /performance/benchmark.py: -------------------------------------------------------------------------------- 1 | """Benchmark for Marshmallow serialization of a moderately complex object. 
2 | 3 | """ 4 | 5 | from __future__ import print_function, unicode_literals, division 6 | 7 | import argparse 8 | import cProfile 9 | import gc 10 | import time 11 | import timeit 12 | from marshmallow import Schema, fields, ValidationError 13 | from toastedmarshmallow import Jit 14 | 15 | 16 | # Custom validator 17 | def must_not_be_blank(data): 18 | if not data: 19 | raise ValidationError('Data not provided.') 20 | 21 | 22 | class AuthorSchema(Schema): 23 | class Meta: 24 | jit_options = { 25 | 'no_callable_fields': True, 26 | } 27 | id = fields.Int(dump_only=True) 28 | first = fields.Str() 29 | last = fields.Str() 30 | book_count = fields.Float() 31 | age = fields.Float() 32 | address = fields.Str() 33 | deceased = fields.Boolean() 34 | 35 | def full_name(self, obj): 36 | return obj.first + ' ' + obj.last 37 | 38 | def format_name(self, author): 39 | return "{0}, {1}".format(author.last, author.first) 40 | 41 | 42 | class QuoteSchema(Schema): 43 | class Meta: 44 | jit_options = { 45 | 'no_callable_fields': True, 46 | 'expected_marshal_type': 'object', 47 | } 48 | 49 | id = fields.Int(dump_only=True) 50 | author = fields.Nested(AuthorSchema) 51 | content = fields.Str(required=True) 52 | posted_at = fields.Int(dump_only=True) 53 | book_name = fields.Str() 54 | page_number = fields.Float() 55 | line_number = fields.Float() 56 | col_number = fields.Float() 57 | is_verified = fields.Boolean() 58 | 59 | 60 | class Author(object): 61 | def __init__(self, id, first, last, book_count, age, address, deceased): 62 | self.id = id 63 | self.first = first 64 | self.last = last 65 | self.book_count = book_count 66 | self.age = age 67 | self.address = address 68 | self.deceased = deceased 69 | 70 | 71 | class Quote(object): 72 | def __init__(self, id, author, content, posted_at, book_name, page_number, 73 | line_number, col_number, is_verified): 74 | self.id = id 75 | self.author = author 76 | self.content = content 77 | self.posted_at = posted_at 78 | self.book_name = 
book_name 79 | self.page_number = page_number 80 | self.line_number = line_number 81 | self.col_number = col_number 82 | self.is_verified = is_verified 83 | 84 | 85 | def run_timeit(quotes, iterations, repeat, jit=False, load=False, 86 | profile=False): 87 | quotes_schema = QuoteSchema(many=True) 88 | if jit: 89 | quotes_schema.jit = Jit 90 | if profile: 91 | profile = cProfile.Profile() 92 | profile.enable() 93 | dumped_quotes = quotes_schema.dump(quotes).data 94 | gc.collect() 95 | 96 | if load: 97 | def marshmallow_func(): 98 | quotes_schema.load(dumped_quotes, many=True) 99 | else: 100 | def marshmallow_func(): 101 | quotes_schema.dump(quotes) 102 | 103 | best = min(timeit.repeat(marshmallow_func, 104 | 'gc.enable()', 105 | number=iterations, 106 | repeat=repeat)) 107 | if profile: 108 | profile.disable() 109 | file_name = 'optimized.pprof' if jit else 'original.pprof' 110 | profile.dump_stats(file_name) 111 | 112 | usec = best * 1e6 / iterations 113 | return usec 114 | 115 | 116 | def main(): 117 | parser = argparse.ArgumentParser( 118 | description='Runs a benchmark of Marshmallow.') 119 | parser.add_argument('--iterations', type=int, default=1000, 120 | help='Number of iterations to run per test.') 121 | parser.add_argument('--repeat', type=int, default=5, 122 | help='Number of times to repeat the performance test. 
' 123 | 'The minimum will be used.') 124 | parser.add_argument('--object-count', type=int, default=20, 125 | help='Number of objects to dump.') 126 | parser.add_argument('--profile', action='store_true', 127 | help='Whether or not to profile Marshmallow while ' 128 | 'running the benchmark.') 129 | args = parser.parse_args() 130 | 131 | quotes = [] 132 | for i in range(args.object_count): 133 | quotes.append( 134 | Quote(i, Author(i, 'Foo', 'Bar', 42, 66, '123 Fake St', False), 135 | 'Hello World', time.time(), 'The World', 34, 3, 70, False) 136 | ) 137 | 138 | original_dump_time = run_timeit(quotes, args.iterations, args.repeat, 139 | load=False, jit=False, 140 | profile=args.profile) 141 | original_load_time = run_timeit(quotes, args.iterations, args.repeat, 142 | load=True, jit=False, profile=args.profile) 143 | optimized_dump_time = run_timeit(quotes, args.iterations, args.repeat, 144 | load=False, jit=True, 145 | profile=args.profile) 146 | optimized_load_time = run_timeit(quotes, args.iterations, args.repeat, 147 | load=True, jit=True, profile=args.profile) 148 | print('Benchmark Result:') 149 | print('\tOriginal Dump Time: {0:.2f} usec/dump'.format(original_dump_time)) 150 | print('\tOptimized Dump Time: {0:.2f} usec/dump'.format( 151 | optimized_dump_time)) 152 | print('\tOriginal Load Time: {0:.2f} usec/load'.format(original_load_time)) 153 | print('\tOptimized Load Time: {0:.2f} usec/load'.format( 154 | optimized_load_time)) 155 | print('\tSpeed up for dump: {0:.2f}x'.format( 156 | original_dump_time / optimized_dump_time)) 157 | print('\tSpeed up for load: {0:.2f}x'.format( 158 | original_load_time / optimized_load_time)) 159 | 160 | 161 | if __name__ == '__main__': 162 | main() 163 | -------------------------------------------------------------------------------- /performance/benchmark_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: performance/benchmark.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='performance/benchmark.proto', 20 | package='', 21 | syntax='proto3', 22 | serialized_pb=_b('\n\x1bperformance/benchmark.proto\"c\n\x06\x41uthor\x12\n\n\x02id\x18\x01 \x01(\x03\x12\r\n\x05\x66irst\x18\x02 \x01(\t\x12\x0c\n\x04last\x18\x03 \x01(\t\x12\x12\n\nbook_count\x18\x04 \x01(\x02\x12\x0b\n\x03\x61ge\x18\x05 \x01(\x02\x12\x0f\n\x07\x61\x64\x64ress\x18\x06 \x01(\t\"\xa1\x01\n\x05Quote\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x17\n\x06\x61uthor\x18\x02 \x01(\x0b\x32\x07.Author\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\t\x12\x11\n\tposted_at\x18\x04 \x01(\x03\x12\x11\n\tbook_name\x18\x05 \x01(\t\x12\x13\n\x0bpage_number\x18\x06 \x01(\x02\x12\x13\n\x0bline_number\x18\x07 \x01(\x02\x12\x12\n\ncol_number\x18\x08 \x01(\x02\" \n\x06Quotes\x12\x16\n\x06quotes\x18\x01 \x03(\x0b\x32\x06.Quoteb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _AUTHOR = _descriptor.Descriptor( 29 | name='Author', 30 | full_name='Author', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='id', full_name='Author.id', index=0, 37 | number=1, type=3, cpp_type=2, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='first', full_name='Author.first', index=1, 44 | number=2, type=9, 
cpp_type=9, label=1, 45 | has_default_value=False, default_value=_b("").decode('utf-8'), 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='last', full_name='Author.last', index=2, 51 | number=3, type=9, cpp_type=9, label=1, 52 | has_default_value=False, default_value=_b("").decode('utf-8'), 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | _descriptor.FieldDescriptor( 57 | name='book_count', full_name='Author.book_count', index=3, 58 | number=4, type=2, cpp_type=6, label=1, 59 | has_default_value=False, default_value=float(0), 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None), 63 | _descriptor.FieldDescriptor( 64 | name='age', full_name='Author.age', index=4, 65 | number=5, type=2, cpp_type=6, label=1, 66 | has_default_value=False, default_value=float(0), 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None), 70 | _descriptor.FieldDescriptor( 71 | name='address', full_name='Author.address', index=5, 72 | number=6, type=9, cpp_type=9, label=1, 73 | has_default_value=False, default_value=_b("").decode('utf-8'), 74 | message_type=None, enum_type=None, containing_type=None, 75 | is_extension=False, extension_scope=None, 76 | options=None), 77 | ], 78 | extensions=[ 79 | ], 80 | nested_types=[], 81 | enum_types=[ 82 | ], 83 | options=None, 84 | is_extendable=False, 85 | syntax='proto3', 86 | extension_ranges=[], 87 | oneofs=[ 88 | ], 89 | serialized_start=31, 90 | serialized_end=130, 91 | ) 92 | 93 | 94 | _QUOTE = _descriptor.Descriptor( 95 | name='Quote', 96 | full_name='Quote', 97 | filename=None, 98 | file=DESCRIPTOR, 99 | containing_type=None, 100 | fields=[ 101 | _descriptor.FieldDescriptor( 102 | name='id', 
full_name='Quote.id', index=0, 103 | number=1, type=3, cpp_type=2, label=1, 104 | has_default_value=False, default_value=0, 105 | message_type=None, enum_type=None, containing_type=None, 106 | is_extension=False, extension_scope=None, 107 | options=None), 108 | _descriptor.FieldDescriptor( 109 | name='author', full_name='Quote.author', index=1, 110 | number=2, type=11, cpp_type=10, label=1, 111 | has_default_value=False, default_value=None, 112 | message_type=None, enum_type=None, containing_type=None, 113 | is_extension=False, extension_scope=None, 114 | options=None), 115 | _descriptor.FieldDescriptor( 116 | name='content', full_name='Quote.content', index=2, 117 | number=3, type=9, cpp_type=9, label=1, 118 | has_default_value=False, default_value=_b("").decode('utf-8'), 119 | message_type=None, enum_type=None, containing_type=None, 120 | is_extension=False, extension_scope=None, 121 | options=None), 122 | _descriptor.FieldDescriptor( 123 | name='posted_at', full_name='Quote.posted_at', index=3, 124 | number=4, type=3, cpp_type=2, label=1, 125 | has_default_value=False, default_value=0, 126 | message_type=None, enum_type=None, containing_type=None, 127 | is_extension=False, extension_scope=None, 128 | options=None), 129 | _descriptor.FieldDescriptor( 130 | name='book_name', full_name='Quote.book_name', index=4, 131 | number=5, type=9, cpp_type=9, label=1, 132 | has_default_value=False, default_value=_b("").decode('utf-8'), 133 | message_type=None, enum_type=None, containing_type=None, 134 | is_extension=False, extension_scope=None, 135 | options=None), 136 | _descriptor.FieldDescriptor( 137 | name='page_number', full_name='Quote.page_number', index=5, 138 | number=6, type=2, cpp_type=6, label=1, 139 | has_default_value=False, default_value=float(0), 140 | message_type=None, enum_type=None, containing_type=None, 141 | is_extension=False, extension_scope=None, 142 | options=None), 143 | _descriptor.FieldDescriptor( 144 | name='line_number', 
full_name='Quote.line_number', index=6, 145 | number=7, type=2, cpp_type=6, label=1, 146 | has_default_value=False, default_value=float(0), 147 | message_type=None, enum_type=None, containing_type=None, 148 | is_extension=False, extension_scope=None, 149 | options=None), 150 | _descriptor.FieldDescriptor( 151 | name='col_number', full_name='Quote.col_number', index=7, 152 | number=8, type=2, cpp_type=6, label=1, 153 | has_default_value=False, default_value=float(0), 154 | message_type=None, enum_type=None, containing_type=None, 155 | is_extension=False, extension_scope=None, 156 | options=None), 157 | ], 158 | extensions=[ 159 | ], 160 | nested_types=[], 161 | enum_types=[ 162 | ], 163 | options=None, 164 | is_extendable=False, 165 | syntax='proto3', 166 | extension_ranges=[], 167 | oneofs=[ 168 | ], 169 | serialized_start=133, 170 | serialized_end=294, 171 | ) 172 | 173 | 174 | _QUOTES = _descriptor.Descriptor( 175 | name='Quotes', 176 | full_name='Quotes', 177 | filename=None, 178 | file=DESCRIPTOR, 179 | containing_type=None, 180 | fields=[ 181 | _descriptor.FieldDescriptor( 182 | name='quotes', full_name='Quotes.quotes', index=0, 183 | number=1, type=11, cpp_type=10, label=3, 184 | has_default_value=False, default_value=[], 185 | message_type=None, enum_type=None, containing_type=None, 186 | is_extension=False, extension_scope=None, 187 | options=None), 188 | ], 189 | extensions=[ 190 | ], 191 | nested_types=[], 192 | enum_types=[ 193 | ], 194 | options=None, 195 | is_extendable=False, 196 | syntax='proto3', 197 | extension_ranges=[], 198 | oneofs=[ 199 | ], 200 | serialized_start=296, 201 | serialized_end=328, 202 | ) 203 | 204 | _QUOTE.fields_by_name['author'].message_type = _AUTHOR 205 | _QUOTES.fields_by_name['quotes'].message_type = _QUOTE 206 | DESCRIPTOR.message_types_by_name['Author'] = _AUTHOR 207 | DESCRIPTOR.message_types_by_name['Quote'] = _QUOTE 208 | DESCRIPTOR.message_types_by_name['Quotes'] = _QUOTES 209 | 
def find_version(fname):
    """Return the version string declared in the file named ``fname``.

    The file is scanned line by line for the first line that starts with
    a ``__version__ = '<value>'`` assignment (single or double quotes).

    :param fname: Path of the file to scan.
    :return: The text between the quotes of the first matching line.
    :raises RuntimeError: If no line matches, or the first matching line
        declares an empty version.
    """
    pattern = re.compile(r'__version__ = [\'"]([^\'"]*)[\'"]')
    with open(fname, 'r') as handle:
        hits = (pattern.match(text) for text in handle)
        # Only the first matching line counts; later ones are never read.
        found = next((hit.group(1) for hit in hits if hit), '')
    if found:
        return found
    raise RuntimeError('Cannot find version information')
class Base(object):
    # Reference class: its ``foo`` is the implementation the test compares
    # against.
    def foo(self):
        pass


class NoOverride(Base):
    # Inherits ``foo`` from Base without redefining it.
    pass


class HasOverride(Base):
    # Defines its own ``foo`` in place of the one inherited from Base.
    def foo(self):
        pass


def test_is_overridden():
    """is_overridden is true only for the subclass that redefines ``foo``."""
    assert is_overridden(HasOverride().foo, Base.foo)
    assert not is_overridden(NoOverride().foo, Base.foo)
InstanceSchema() 18 | 19 | 20 | @pytest.fixture() 21 | def nested_circular_ref_schema(): 22 | class NestedStringSchema(Schema): 23 | key = fields.String() 24 | me = fields.Nested('NestedStringSchema') 25 | return NestedStringSchema() 26 | 27 | 28 | @pytest.fixture() 29 | def nested_schema(): 30 | class GrandChildSchema(Schema): 31 | bar = fields.String() 32 | raz = fields.String() 33 | 34 | class SubSchema(Schema): 35 | name = fields.String() 36 | value = fields.Nested(GrandChildSchema) 37 | 38 | class NestedSchema(Schema): 39 | key = fields.String() 40 | value = fields.Nested(SubSchema, only=('name', 'value.bar')) 41 | values = fields.Nested(SubSchema, exclude=('value', ), many=True) 42 | return NestedSchema() 43 | 44 | 45 | @pytest.fixture() 46 | def optimized_schema(): 47 | class OptimizedSchema(Schema): 48 | class Meta: 49 | jit_options = { 50 | 'no_callable_fields': True, 51 | 'expected_marshal_type': 'object' 52 | } 53 | key = fields.String() 54 | value = fields.Integer(default=0, as_string=True) 55 | return OptimizedSchema() 56 | 57 | 58 | @pytest.fixture() 59 | def simple_object(): 60 | class InstanceObject(object): 61 | def __init__(self): 62 | self.key = 'some_key' 63 | self.value = 42 64 | return InstanceObject() 65 | 66 | 67 | @pytest.fixture() 68 | def simple_dict(): 69 | return { 70 | 'key': u'some_key', 71 | 'value': 42 72 | } 73 | 74 | 75 | @pytest.fixture() 76 | def simple_hybrid(): 77 | class HybridObject(object): 78 | def __init__(self): 79 | self.key = 'some_key' 80 | 81 | def __getitem__(self, item): 82 | if item == 'value': 83 | return 42 84 | raise KeyError() 85 | return HybridObject() 86 | 87 | 88 | @pytest.fixture() 89 | def schema(): 90 | class BasicSchema(Schema): 91 | class Meta: 92 | ordered = True 93 | foo = fields.Integer(attribute='@#') 94 | bar = fields.String() 95 | raz = fields.Method('raz_') 96 | meh = fields.String(load_only=True) 97 | blargh = fields.Boolean() 98 | 99 | def raz_(self, obj): 100 | return 'Hello!' 
class RoundedFloat(fields.Float):
    """Float field whose ``num_type`` is replaced by a rounding callable.

    :param places: Number of digits passed to ``round`` for each value.
    :param kwargs: Forwarded unchanged to the parent initializer.
    """
    def __init__(self, places, **kwargs):
        # super() is anchored at fields.Float rather than RoundedFloat, so
        # fields.Float.__init__ itself is skipped and the next initializer
        # in the MRO runs instead.
        # NOTE(review): presumably deliberate - confirm.
        super(fields.Float, self).__init__(**kwargs)
        # Per-instance override: a plain callable, not a builtin type.
        self.num_type = lambda x: round(x, places)
HybridSerializer() 162 | field = fields.Integer() 163 | result = str(serializer.serialize('foo', 'bar', 'result["foo"] = {0}', 164 | field)) 165 | expected = ('try:\n' 166 | ' value = obj["foo"]\n' 167 | 'except (KeyError, AttributeError, IndexError, TypeError):\n' 168 | ' value = obj.foo\n' 169 | 'result["foo"] = value') 170 | assert expected == result 171 | 172 | 173 | def test_generate_marshall_method_body(schema): 174 | expected_start = '''\ 175 | def InstanceSerializer(obj): 176 | res = dict_class() 177 | ''' 178 | raz_assignment = ('value = None; ' 179 | 'value = value() if callable(value) else value; ' 180 | 'res["raz"] = _field_raz__serialize(value, "raz", obj)') 181 | 182 | foo_assignment = ( 183 | 'if "@#" in obj:\n' 184 | ' value = obj["@#"]; ' 185 | 'value = value() if callable(value) else value; ' 186 | 'value = int(value) if value is not None else None; ' 187 | 'res["foo"] = value') 188 | bar_assignment = ( 189 | 'value = obj.bar; ' 190 | 'value = value() if callable(value) else value; ' 191 | 'value = {text_type}(value) if value is not None else None; ' 192 | 'res["bar"] = value').format(text_type=text_type.__name__) 193 | blargh_assignment = ( 194 | 'value = obj.blargh; ' 195 | 'value = value() if callable(value) else value; ' 196 | 'value = ((value in __blargh_truthy) or ' 197 | '(False if value in __blargh_falsy else dict()["error"])) ' 198 | 'if value is not None else None; ' 199 | 'res["blargh"] = value') 200 | 201 | context = JitContext() 202 | result = str(generate_transform_method_body(schema, 203 | InstanceSerializer(), 204 | context)) 205 | assert result.startswith(expected_start) 206 | assert raz_assignment in result 207 | assert foo_assignment in result 208 | assert bar_assignment in result 209 | assert blargh_assignment in result 210 | assert 'meh' not in result 211 | assert result.endswith('return res') 212 | 213 | 214 | def test_generate_marshall_method_bodies(): 215 | class OneFieldSchema(Schema): 216 | foo = fields.Integer() 217 | 
context = JitContext() 218 | result = generate_method_bodies(OneFieldSchema(), context) 219 | expected = '''\ 220 | def InstanceSerializer(obj): 221 | res = {} 222 | value = obj.foo; value = value() if callable(value) else value; \ 223 | value = int(value) if value is not None else None; res["foo"] = value 224 | return res 225 | def DictSerializer(obj): 226 | res = {} 227 | if "foo" in obj: 228 | value = obj["foo"]; value = value() if callable(value) else value; \ 229 | value = int(value) if value is not None else None; res["foo"] = value 230 | return res 231 | def HybridSerializer(obj): 232 | res = {} 233 | try: 234 | value = obj["foo"] 235 | except (KeyError, AttributeError, IndexError, TypeError): 236 | value = obj.foo 237 | value = value; value = value() if callable(value) else value; \ 238 | value = int(value) if value is not None else None; res["foo"] = value 239 | return res''' 240 | assert expected == result 241 | 242 | 243 | def test_generate_unmarshall_method_bodies(): 244 | class OneFieldSchema(Schema): 245 | foo = fields.Integer() 246 | context = JitContext(is_serializing=False, use_inliners=False) 247 | result = generate_method_bodies(OneFieldSchema(), context) 248 | expected = '''\ 249 | def InstanceSerializer(obj): 250 | res = {} 251 | __res_get = res.get 252 | res["foo"] = _field_foo__deserialize(obj.foo, "foo", obj) 253 | if __res_get("foo", res) is None: 254 | raise ValueError() 255 | return res 256 | def DictSerializer(obj): 257 | res = {} 258 | __res_get = res.get 259 | if "foo" in obj: 260 | res["foo"] = _field_foo__deserialize(obj["foo"], "foo", obj) 261 | if __res_get("foo", res) is None: 262 | raise ValueError() 263 | return res 264 | def HybridSerializer(obj): 265 | res = {} 266 | __res_get = res.get 267 | try: 268 | value = obj["foo"] 269 | except (KeyError, AttributeError, IndexError, TypeError): 270 | value = obj.foo 271 | res["foo"] = _field_foo__deserialize(value, "foo", obj) 272 | if __res_get("foo", res) is None: 273 | raise 
ValueError() 274 | return res''' 275 | assert expected == result 276 | 277 | 278 | def test_generate_unmarshall_method_bodies_with_load_from(): 279 | class OneFieldSchema(Schema): 280 | foo = fields.Integer(load_from='bar', allow_none=True) 281 | context = JitContext(is_serializing=False, use_inliners=False) 282 | result = str(generate_transform_method_body(OneFieldSchema(), 283 | DictSerializer(context), 284 | context)) 285 | expected = '''\ 286 | def DictSerializer(obj): 287 | res = {} 288 | __res_get = res.get 289 | if "foo" in obj: 290 | res["foo"] = _field_foo__deserialize(obj["foo"], "bar", obj) 291 | if "foo" not in res: 292 | if "bar" in obj: 293 | res["foo"] = _field_foo__deserialize(obj["bar"], "bar", obj) 294 | return res''' 295 | assert expected == result 296 | 297 | 298 | def test_generate_unmarshall_method_bodies_required(): 299 | class OneFieldSchema(Schema): 300 | foo = fields.Integer(required=True) 301 | context = JitContext(is_serializing=False, use_inliners=False) 302 | result = str(generate_transform_method_body(OneFieldSchema(), 303 | DictSerializer(context), 304 | context)) 305 | expected = '''\ 306 | def DictSerializer(obj): 307 | res = {} 308 | __res_get = res.get 309 | res["foo"] = _field_foo__deserialize(obj["foo"], "foo", obj) 310 | if "foo" not in res: 311 | raise ValueError() 312 | if __res_get("foo", res) is None: 313 | raise ValueError() 314 | return res''' 315 | assert expected == result 316 | 317 | 318 | def test_jit_bails_with_get_attribute(): 319 | class DynamicSchema(Schema): 320 | def get_attribute(self, obj, attr, default): 321 | pass 322 | marshal_method = generate_marshall_method(DynamicSchema()) 323 | assert marshal_method is None 324 | 325 | 326 | def test_jit_bails_nested_attribute(): 327 | class DynamicSchema(Schema): 328 | foo = fields.String(attribute='foo.bar') 329 | 330 | marshal_method = generate_marshall_method(DynamicSchema()) 331 | assert marshal_method is None 332 | 333 | 334 | def 
test_jitted_marshal_method(schema): 335 | context = JitContext() 336 | marshal_method = generate_marshall_method(schema, threshold=1, 337 | context=context) 338 | result = marshal_method({ 339 | '@#': 32, 340 | 'bar': 'Hello', 341 | 'meh': 'Foo' 342 | }) 343 | expected = { 344 | 'bar': u'Hello', 345 | 'foo': 32, 346 | 'raz': 'Hello!' 347 | } 348 | assert expected == result 349 | # Test specialization 350 | result = marshal_method({ 351 | '@#': 32, 352 | 'bar': 'Hello', 353 | 'meh': 'Foo' 354 | }) 355 | assert expected == result 356 | assert marshal_method.proxy._call == marshal_method.proxy.dict_serializer 357 | 358 | 359 | def test_non_primitive_num_type_schema(non_primitive_num_type_schema): 360 | context = JitContext() 361 | marshall_method = generate_marshall_method( 362 | non_primitive_num_type_schema, threshold=1, context=context 363 | ) 364 | result = marshall_method({}) 365 | expected = {} 366 | assert expected == result 367 | 368 | 369 | def test_jitted_unmarshal_method(schema): 370 | context = JitContext() 371 | unmarshal_method = generate_unmarshall_method(schema, context=context) 372 | result = unmarshal_method({ 373 | 'foo': 32, 374 | 'bar': 'Hello', 375 | 'meh': 'Foo' 376 | }) 377 | expected = { 378 | 'bar': u'Hello', 379 | '@#': 32, 380 | 'meh': 'Foo' 381 | } 382 | assert expected == result 383 | 384 | assert not hasattr(unmarshal_method, 'proxy') 385 | 386 | 387 | def test_jitted_marshal_method_bails_on_specialize(simple_schema, 388 | simple_object, 389 | simple_dict, 390 | simple_hybrid): 391 | marshal_method = generate_marshall_method(simple_schema, threshold=2) 392 | assert simple_dict == marshal_method(simple_dict) 393 | assert marshal_method.proxy._call == marshal_method.proxy.tracing_call 394 | assert simple_dict == marshal_method(simple_object) 395 | assert marshal_method.proxy._call == marshal_method.proxy.no_tracing_call 396 | assert simple_dict == marshal_method(simple_object) 397 | assert marshal_method.proxy._call == 
marshal_method.proxy.no_tracing_call 398 | assert simple_dict == marshal_method(simple_dict) 399 | assert marshal_method.proxy._call == marshal_method.proxy.no_tracing_call 400 | assert simple_dict == marshal_method(simple_hybrid) 401 | assert marshal_method.proxy._call == marshal_method.proxy.no_tracing_call 402 | 403 | 404 | def test_dict_jitted_marshal_method(simple_schema): 405 | marshal_method = generate_marshall_method(simple_schema) 406 | result = marshal_method({'key': 'some_key'}) 407 | expected = { 408 | 'key': 'some_key', 409 | 'value': 0 410 | } 411 | assert expected == result 412 | 413 | 414 | def test_jitted_marshal_method_no_threshold(simple_schema, simple_dict): 415 | marshal_method = generate_marshall_method(simple_schema, threshold=0) 416 | assert marshal_method.proxy._call == marshal_method.proxy.no_tracing_call 417 | result = marshal_method(simple_dict) 418 | assert simple_dict == result 419 | assert marshal_method.proxy._call == marshal_method.proxy.no_tracing_call 420 | 421 | 422 | def test_hybrid_jitted_marshal_method(simple_schema, 423 | simple_hybrid, 424 | simple_dict): 425 | marshal_method = generate_marshall_method(simple_schema, threshold=1) 426 | result = marshal_method(simple_hybrid) 427 | assert simple_dict == result 428 | result = marshal_method(simple_hybrid) 429 | assert simple_dict == result 430 | assert marshal_method.proxy._call == marshal_method.proxy.hybrid_serializer 431 | 432 | 433 | def test_instance_jitted_instance_marshal_method(simple_schema, 434 | simple_object, 435 | simple_dict): 436 | marshal_method = generate_marshall_method(simple_schema, threshold=1) 437 | result = marshal_method(simple_object) 438 | assert simple_dict == result 439 | result = marshal_method(simple_object) 440 | assert simple_dict == result 441 | assert (marshal_method.proxy._call == 442 | marshal_method.proxy.instance_serializer) 443 | 444 | 445 | def test_instance_jitted_instance_marshal_method_supports_none_int( 446 | simple_schema, 
simple_object 447 | ): 448 | simple_object.value = None 449 | marshal_method = generate_marshall_method(simple_schema) 450 | result = marshal_method(simple_object) 451 | expected = { 452 | 'key': 'some_key', 453 | 'value': None 454 | } 455 | assert expected == result 456 | 457 | 458 | def test_optimized_jitted_marshal_method(optimized_schema, simple_object): 459 | marshal_method = generate_marshall_method(optimized_schema) 460 | result = marshal_method(simple_object) 461 | expected = { 462 | 'key': 'some_key', 463 | 'value': '42' 464 | } 465 | assert expected == result 466 | 467 | 468 | def test_nested_marshal_method_circular_ref(nested_circular_ref_schema): 469 | marshal_method = generate_marshall_method(nested_circular_ref_schema) 470 | result = marshal_method({ 471 | 'key': 'some_key', 472 | 'me': { 473 | 'key': 'sub_key' 474 | } 475 | }) 476 | expected = { 477 | 'key': 'some_key', 478 | 'me': { 479 | 'key': 'sub_key' 480 | } 481 | } 482 | assert expected == result 483 | 484 | 485 | def test_nested_marshal_method(nested_schema): 486 | marshal_method = generate_marshall_method(nested_schema) 487 | result = marshal_method({ 488 | 'key': 'some_key', 489 | 'value': { 490 | 'name': 'sub_key', 491 | 'value': { 492 | 'bar': 'frob', 493 | 'raz': 'blah' 494 | } 495 | }, 496 | 'values': [ 497 | { 498 | 'name': 'first_child', 499 | 'value': 'foo' 500 | }, 501 | { 502 | 'name': 'second_child', 503 | 'value': 'bar' 504 | } 505 | ] 506 | }) 507 | expected = { 508 | 'key': 'some_key', 509 | 'value': { 510 | 'name': 'sub_key', 511 | 'value': { 512 | 'bar': 'frob' 513 | } 514 | }, 515 | 'values': [ 516 | { 517 | 'name': 'first_child', 518 | }, 519 | { 520 | 'name': 'second_child' 521 | } 522 | ] 523 | } 524 | assert expected == result 525 | -------------------------------------------------------------------------------- /tests/test_toasted_marshmallow.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from marshmallow 
@pytest.fixture()
def schema():
    """Build a schema with a defaulted ``key`` and a ``value`` with ``missing``.
    """
    class TestSchema(Schema):
        key = fields.String(default='world')
        value = fields.Integer(missing=42)
    return TestSchema()


def test_marshmallow_integration_dump(schema):
    """Dumping with the Jit attached yields the data and applies ``default``.
    """
    schema.jit = toastedmarshmallow.Jit
    assert schema._jit_instance is not None

    result = schema.dump({'key': 'hello', 'value': 32})
    assert not result.errors
    assert result.data == {'key': 'hello', 'value': 32}

    # 'key' is absent from the input, so its default ('world') is expected.
    result = schema.dump({'value': 32})
    assert not result.errors
    assert result.data == {'key': 'world', 'value': 32}

    # The jit instance must still be attached after dumping.
    assert schema._jit_instance is not None


def test_marshmallow_integration_load(schema):
    """Loading with the Jit attached yields the data and applies ``missing``.
    """
    schema.jit = toastedmarshmallow.Jit
    assert schema._jit_instance is not None

    result = schema.load({'key': 'hello', 'value': 32})
    assert not result.errors
    assert result.data == {'key': 'hello', 'value': 32}

    # many=True path; 'value' is absent, so its ``missing`` (42) is expected.
    result = schema.load([{'key': 'hello'}], many=True)
    assert not result.errors
    assert result.data == [{'key': 'hello', 'value': 42}]
    assert schema._jit_instance is not None


def test_marshmallow_integration_invalid_data(schema):
    """Invalid integers are reported as field errors on both dump and load.
    """
    schema.jit = toastedmarshmallow.Jit
    assert schema._jit_instance is not None
    result = schema.dump({'key': 'hello', 'value': 'foo'})
    assert {'value': ['Not a valid integer.']} == result.errors
    result = schema.load({'key': 'hello', 'value': 'foo'})
    assert {'value': ['Not a valid integer.']} == result.errors
    # The jit instance must still be attached after the failures.
    assert schema._jit_instance is not None
def is_overridden(instance_func, class_func):
    # type: (MethodType, MethodType) -> bool
    """Tell whether a bound method differs from a class-level function.

    :param instance_func: Method looked up on an instance; its ``__func__``
        is the function being compared.
    :param class_func: The same attribute looked up on the reference class.
    :return: True when the two do not share the same underlying function.
    """
    if sys.version_info[0] >= 3:
        # On Python 3 the class-level lookup is used as-is.
        reference = class_func
    else:
        # On Python 2 the class-level lookup is unwrapped via ``__func__``.
        reference = class_func.__func__
    return instance_func.__func__ is not reference
# Regular Expression for identifying a valid Python identifier name.
# NOTE: the pattern has no end anchor, so ``.match`` alone only validates a
# prefix of the string.
_VALID_IDENTIFIER = re.compile(r'[a-zA-Z_][a-zA-Z0-9_]*')


def field_symbol_name(field_name):
    # type: (str) -> str
    """Generates the symbol name to be used when accessing a field in generated
    code.

    If the field name isn't a valid identifier name, synthesizes a name by
    base64 encoding the fieldname (UTF-8 encoded, ``=`` padding stripped).

    :param field_name: The name of the schema field.
    :return: ``_field_`` followed by the field name or its encoded form.
    """
    match = _VALID_IDENTIFIER.match(field_name)
    # ``match`` anchors only at the start, so a name such as "foo-bar" would
    # pass on its "foo" prefix.  Require the match to consume the whole name.
    if match is None or match.end() != len(field_name):
        # NOTE: standard base64 output may itself contain '+' or '/'.
        field_name = str(base64.b64encode(
            field_name.encode('utf-8')).decode('utf-8').strip('='))
    return '_field_{field_name}'.format(field_name=field_name)


def attr_str(attr_name):
    # type: (str) -> str
    """Gets the string to use when accessing an attribute on an object.

    Handles case where the attribute name collides with a keyword and would
    therefore be illegal to access with dot notation.

    :param attr_name: The attribute to read from the variable ``obj``.
    :return: ``getattr(obj, "<name>")`` for keywords, ``obj.<name>`` otherwise.
    """
    if keyword.iskeyword(attr_name):
        return 'getattr(obj, "{0}")'.format(attr_name)
    return 'obj.{0}'.format(attr_name)
@add_metaclass(ABCMeta)
class FieldSerializer(object):
    """Base class for generating code to serialize a field.
    """
    def __init__(self, context=None):
        # type: (JitContext) -> None
        """
        :param context: The context for the current Jit. A new JitContext is
            created when a falsy value is passed.
        """
        self.context = context or JitContext()

    @abstractmethod
    def serialize(self, attr_name, field_symbol,
                  assignment_template, field_obj):
        # type: (str, str, str, fields.Field) -> IndentedString
        """Generates the code to pull a field off of an object into the result.

        :param attr_name: The name of the attribute being accessed.
        :param field_symbol: The symbol to use when accessing the field.  Should
            be generated via field_symbol_name.
        :param assignment_template: A string template to use when generating
            code.  The assignment template is passed into the serializer and
            has a single positional placeholder for string formatting.  An
            example of a value that may be passed into assignment_template is:
            `res['some_field'] = {0}`
        :param field_obj: The instance of the Marshmallow field being
            serialized.
        :return: The code to pull a field off of the object passed in.
        """
        pass  # pragma: no cover


class InstanceSerializer(FieldSerializer):
    """Generates code for accessing fields as if they were instance variables.

    For example, generates:

    res['some_value'] = obj.some_value
    """
    def serialize(self, attr_name, field_symbol,
                  assignment_template, field_obj):
        # type: (str, str, str, fields.Field) -> IndentedString
        return IndentedString(assignment_template.format(attr_str(attr_name)))


class DictSerializer(FieldSerializer):
    """Generates code for accessing fields as if they were a dict, generating
    the proper code for handling missing fields as well.  For example, generates:

    # Required field with no default
    res['some_value'] = obj['some_value']

    # Field with a default.  some_value__default will be injected at exec time.
    res['some_value'] = obj.get('some_value', some_value__default)

    # Non required field:
    if 'some_value' in obj:
        res['some_value'] = obj['some_value']
    """
    def serialize(self, attr_name, field_symbol,
                  assignment_template, field_obj):
        # type: (str, str, str, fields.Field) -> IndentedString
        body = IndentedString()
        if self.context.is_serializing:
            default_str = 'default'
            default_value = field_obj.default
        else:
            default_str = 'missing'
            default_value = field_obj.missing
            # NOTE(review): the required shortcut is applied on the
            # deserializing path only - confirm against the upstream file.
            if field_obj.required:
                # Direct indexing: no guard and no fallback are generated.
                body += assignment_template.format('obj["{attr_name}"]'.format(
                    attr_name=attr_name))
                return body
        # ``missing`` is a sentinel: compare by identity so that a
        # user-supplied default with a custom ``__eq__`` is never invoked.
        if default_value is missing:
            body += 'if "{attr_name}" in obj:'.format(attr_name=attr_name)
            with body.indent():
                body += assignment_template.format('obj["{attr_name}"]'.format(
                    attr_name=attr_name))
        else:
            # Callable defaults are invoked by the generated code itself.
            if callable(default_value):
                default_str += '()'

            body += assignment_template.format(
                'obj.get("{attr_name}", {field_symbol}__{default_str})'.format(
                    attr_name=attr_name, field_symbol=field_symbol,
                    default_str=default_str))
        return body
145 | 146 | For example, generates: 147 | 148 | try: 149 | value = obj['some_value'] 150 | except (KeyError, AttributeError, IndexError, TypeError): 151 | value = obj.some_value 152 | res['some_value'] = value 153 | """ 154 | def serialize(self, attr_name, field_symbol, 155 | assignment_template, field_obj): 156 | # type: (str, str, str, fields.Field) -> IndentedString 157 | body = IndentedString() 158 | body += 'try:' 159 | with body.indent(): 160 | body += 'value = obj["{attr_name}"]'.format(attr_name=attr_name) 161 | body += 'except (KeyError, AttributeError, IndexError, TypeError):' 162 | with body.indent(): 163 | body += 'value = {attr_str}'.format(attr_str=attr_str(attr_name)) 164 | body += assignment_template.format('value') 165 | return body 166 | 167 | 168 | @attr.s 169 | class JitContext(object): 170 | """ Bag of properties to keep track of the context of what's being jitted. 171 | 172 | """ 173 | namespace = attr.ib(default={}) # type: Dict[str, Any] 174 | use_inliners = attr.ib(default=True) # type: bool 175 | schema_stack = attr.ib(default=attr.Factory(set)) # type: Set[str] 176 | only = attr.ib(default=None) # type: Optional[Set[str]] 177 | exclude = attr.ib(default=set()) # type: Set[str] 178 | is_serializing = attr.ib(default=True) # type: bool 179 | 180 | 181 | @add_metaclass(ABCMeta) 182 | class FieldInliner(object): 183 | """Base class for generating code to serialize a field. 184 | 185 | Inliners are used to generate the code to validate/parse fields without 186 | having to bounce back into the underlying marshmallow code. While this is 187 | somewhat fragile as it requires the inliners to be kept in sync with the 188 | underlying implementation, it's good for a >2X speedup on benchmarks. 
189 | """ 190 | @abstractmethod 191 | def inline(self, field, context): 192 | # type: (fields.Field, JitContext) -> Optional[str] 193 | pass # pragma: no cover 194 | 195 | 196 | class StringInliner(FieldInliner): 197 | def inline(self, field, context): 198 | # type: (fields.Field, JitContext) -> Optional[str] 199 | """Generates a template for inlining string serialization. 200 | 201 | For example, generates "unicode(value) if value is not None else None" 202 | to serialize a string in Python 2.7 203 | """ 204 | if is_overridden(field._serialize, fields.String._serialize): 205 | return None 206 | result = text_type.__name__ + '({0})' 207 | result += ' if {0} is not None else None' 208 | if not context.is_serializing: 209 | string_type_strings = ','.join([x.__name__ for x in string_types]) 210 | result = ('(' + result + ') if ' 211 | '(isinstance({0}, (' + string_type_strings + 212 | ')) or {0} is None) else dict()["error"]') 213 | return result 214 | 215 | 216 | class BooleanInliner(FieldInliner): 217 | def inline(self, field, context): 218 | # type: (fields.Field, JitContext) -> Optional[str] 219 | """Generates a template for inlining boolean serialization. 220 | 221 | For example, generates: 222 | 223 | ( 224 | (value in __some_field_truthy) or 225 | (False if value in __some_field_falsy else bool(value)) 226 | ) 227 | 228 | This is somewhat fragile but it tracks what Marshmallow does. 
229 | """ 230 | if is_overridden(field._serialize, fields.Boolean._serialize): 231 | return None 232 | truthy_symbol = '__{0}_truthy'.format(field.name) 233 | falsy_symbol = '__{0}_falsy'.format(field.name) 234 | context.namespace[truthy_symbol] = field.truthy 235 | context.namespace[falsy_symbol] = field.falsy 236 | result = ('(({0} in ' + truthy_symbol + 237 | ') or (False if {0} in ' + falsy_symbol + 238 | ' else dict()["error"]))') 239 | return result + ' if {0} is not None else None' 240 | 241 | 242 | class NumberInliner(FieldInliner): 243 | def inline(self, field, context): 244 | # type: (fields.Field, JitContext) -> Optional[str] 245 | """Generates a template for inlining string serialization. 246 | 247 | For example, generates "float(value) if value is not None else None" 248 | to serialize a float. If `field.as_string` is `True` the result will 249 | be coerced to a string if not None. 250 | """ 251 | if (is_overridden(field._validated, fields.Number._validated) or 252 | is_overridden(field._serialize, fields.Number._serialize) or 253 | field.num_type not in (int, float)): 254 | return None 255 | result = field.num_type.__name__ + '({0})' 256 | if field.as_string and context.is_serializing: 257 | result = 'str({0})'.format(result) 258 | if field.allow_none is True or context.is_serializing: 259 | # Only emit the Null checking code if nulls are allowed. If they 260 | # aren't allowed casting `None` to an integer will throw and the 261 | # slow path will take over. 262 | result += ' if {0} is not None else None' 263 | return result 264 | 265 | 266 | class NestedInliner(FieldInliner): # pragma: no cover 267 | def inline(self, field, context): 268 | """Generates a template for inlining nested field. 269 | 270 | This doesn't pass tests yet in Marshmallow, namely due to issues around 271 | code expecting the context of nested schema to be populated on first 272 | access, so disabling for now. 
273 | """ 274 | if is_overridden(field._serialize, fields.Nested._serialize): 275 | return None 276 | 277 | if not (isinstance(field.nested, type) and 278 | issubclass(field.nested, SchemaABC)): 279 | return None 280 | 281 | if field.nested.__class__ in context.schema_stack: 282 | return None 283 | 284 | method_name = '__nested_{}_serialize'.format( 285 | field_symbol_name(field.name)) 286 | 287 | old_only = context.only 288 | old_exclude = context.exclude 289 | old_namespace = context.namespace 290 | 291 | context.only = set(field.only) if field.only else None 292 | context.exclude = set(field.exclude) 293 | context.namespace = {} 294 | 295 | for only_field in old_only or []: 296 | if only_field.startswith(field.name + '.'): 297 | if not context.only: 298 | context.only = set() 299 | context.only.add(only_field[len(field.name + '.'):]) 300 | for only_field in list((context.only or [])): 301 | if '.' in only_field: 302 | if not context.only: 303 | context.only = set() 304 | context.only.add(only_field.split('.')[0]) 305 | 306 | for exclude_field in old_exclude: 307 | if exclude_field.startswith(field.name + '.'): 308 | context.exclude.add(exclude_field[len(field.name + '.'):]) 309 | 310 | serialize_method = generate_marshall_method(field.schema, context) 311 | if serialize_method is None: 312 | return None 313 | 314 | context.namespace = old_namespace 315 | context.only = old_only 316 | context.exclude = old_exclude 317 | 318 | context.namespace[method_name] = serialize_method 319 | 320 | if field.many: 321 | return ('[' + method_name + 322 | '(_x) for _x in {0}] if {0} is not None else None') 323 | return method_name + '({0}) if {0} is not None else None' 324 | 325 | 326 | INLINERS = { 327 | fields.String: StringInliner(), 328 | fields.Number: NumberInliner(), 329 | fields.Boolean: BooleanInliner(), 330 | } 331 | 332 | EXPECTED_TYPE_TO_CLASS = { 333 | 'object': InstanceSerializer, 334 | 'dict': DictSerializer, 335 | 'hybrid': HybridSerializer 336 | } 337 | 338 | 
339 | def _should_skip_field(field_name, field_obj, context): 340 | # type: (str, fields.Field, JitContext) -> bool 341 | load_only = getattr(field_obj, 'load_only', False) 342 | dump_only = getattr(field_obj, 'dump_only', False) 343 | # Marshmallow 2.x.x doesn't properly set load_only or 344 | # dump_only on Method objects. This is fixed in 3.0.0 345 | # https://github.com/marshmallow-code/marshmallow/commit/1b676dd36cbb5cf040da4f5f6d43b0430684325c 346 | if isinstance(field_obj, fields.Method): 347 | load_only = ( 348 | bool(field_obj.deserialize_method_name) and 349 | not bool(field_obj.serialize_method_name) 350 | ) 351 | dump_only = ( 352 | bool(field_obj.serialize_method_name) and 353 | not bool(field_obj.deserialize_method_name) 354 | ) 355 | 356 | if load_only and context.is_serializing: 357 | return True 358 | if dump_only and not context.is_serializing: 359 | return True 360 | if context.only and field_name not in context.only: 361 | return True 362 | if context.exclude and field_name in context.exclude: 363 | return True 364 | return False 365 | 366 | 367 | def generate_transform_method_body(schema, on_field, context): 368 | # type: (Schema, FieldSerializer, JitContext) -> IndentedString 369 | """Generates the method body for a schema and a given field serialization 370 | strategy. 371 | """ 372 | body = IndentedString() 373 | body += 'def {method_name}(obj):'.format( 374 | method_name=on_field.__class__.__name__) 375 | with body.indent(): 376 | if schema.dict_class is dict: 377 | # Declaring dictionaries via `{}` is faster than `dict()` since it 378 | # avoids the global lookup. 379 | body += 'res = {}' 380 | else: 381 | # dict_class will be injected before `exec` is called. 
382 | body += 'res = dict_class()' 383 | if not context.is_serializing: 384 | body += '__res_get = res.get' 385 | for field_name, field_obj in iteritems(schema.fields): 386 | if _should_skip_field(field_name, field_obj, context): 387 | continue 388 | 389 | attr_name, destination = _get_attr_and_destination(context, 390 | field_name, 391 | field_obj) 392 | 393 | result_key = ''.join( 394 | [schema.prefix or '', destination]) 395 | 396 | field_symbol = field_symbol_name(field_name) 397 | assignment_template = '' 398 | value_key = '{0}' 399 | 400 | # If we have to assume any field can be callable we always have to 401 | # check to see if we need to invoke the method first. 402 | # We can investigate tracing this as well. 403 | jit_options = getattr(schema.opts, 'jit_options', {}) 404 | no_callable_fields = (jit_options.get('no_callable_fields') or 405 | not context.is_serializing) 406 | if not no_callable_fields: 407 | assignment_template = ( 408 | 'value = {0}; ' 409 | 'value = value() if callable(value) else value; ') 410 | value_key = 'value' 411 | 412 | # Attempt to see if this field type can be inlined. 413 | inliner = inliner_for_field(context, field_obj) 414 | 415 | if inliner: 416 | assignment_template += _generate_inlined_access_template( 417 | inliner, result_key, no_callable_fields) 418 | 419 | else: 420 | assignment_template += _generate_fallback_access_template( 421 | context, field_name, field_obj, result_key, value_key) 422 | if not field_obj._CHECK_ATTRIBUTE: 423 | # fields like 'Method' expect to have `None` passed in when 424 | # invoking their _serialize method. 
425 | body += assignment_template.format('None') 426 | context.namespace['__marshmallow_missing'] = missing 427 | body += 'if res["{key}"] is __marshmallow_missing:'.format( 428 | key=result_key) 429 | with body.indent(): 430 | body += 'del res["{key}"]'.format(key=result_key) 431 | 432 | else: 433 | serializer = on_field 434 | if not _VALID_IDENTIFIER.match(attr_name): 435 | # If attr_name is not a valid python identifier, it can only 436 | # be accessed via key lookups. 437 | serializer = DictSerializer(context) 438 | 439 | body += serializer.serialize( 440 | attr_name, field_symbol, assignment_template, field_obj) 441 | 442 | if not context.is_serializing and field_obj.load_from: 443 | # Marshmallow has a somewhat counter intuitive behavior. 444 | # It will first load from the name of the field, then, 445 | # should that fail, will load from the field specified in 446 | # 'load_from'. 447 | # 448 | # For example: 449 | # 450 | # class TestSchema(Schema): 451 | # foo = StringField(load_from='bar') 452 | # TestSchema().load({'foo': 'haha'}).result 453 | # 454 | # Works just fine with no errors. 455 | # 456 | # class TestSchema(Schema): 457 | # foo = StringField(load_from='bar') 458 | # TestSchema().load({'foo': 'haha', 'bar': 'value'}).result 459 | # 460 | # Results in {'foo': 'haha'} 461 | # 462 | # Therefore, we generate code to mimic this behavior in 463 | # cases where `load_from` is specified. 
464 | body += 'if "{key}" not in res:'.format(key=result_key) 465 | with body.indent(): 466 | body += serializer.serialize( 467 | field_obj.load_from, field_symbol, 468 | assignment_template, field_obj) 469 | if not context.is_serializing: 470 | if field_obj.required: 471 | body += 'if "{key}" not in res:'.format(key=result_key) 472 | with body.indent(): 473 | body += 'raise ValueError()' 474 | if field_obj.allow_none is not True: 475 | body += 'if __res_get("{key}", res) is None:'.format( 476 | key=result_key) 477 | with body.indent(): 478 | body += 'raise ValueError()' 479 | if (field_obj.validators or 480 | is_overridden(field_obj._validate, 481 | fields.Field._validate)): 482 | body += 'if "{key}" in res:'.format(key=result_key) 483 | with body.indent(): 484 | body += '{field_symbol}__validate(res["{result_key}"])'.format( 485 | field_symbol=field_symbol, result_key=result_key 486 | ) 487 | 488 | body += 'return res' 489 | return body 490 | 491 | 492 | def _generate_fallback_access_template(context, field_name, field_obj, 493 | result_key, value_key): 494 | field_symbol = field_symbol_name(field_name) 495 | transform_method_name = 'serialize' 496 | if not context.is_serializing: 497 | transform_method_name = 'deserialize' 498 | key_name = field_name 499 | if not context.is_serializing: 500 | key_name = field_obj.load_from or field_name 501 | return ( 502 | 'res["{key}"] = {field_symbol}__{transform}(' 503 | '{value_key}, "{key_name}", obj)'.format( 504 | key=result_key, field_symbol=field_symbol, 505 | transform=transform_method_name, 506 | key_name=key_name, value_key=value_key)) 507 | 508 | 509 | def _get_attr_and_destination(context, field_name, field_obj): 510 | # type: (JitContext, str, fields.Field) -> Tuple[str, str] 511 | # The name of the attribute to pull off the incoming object 512 | attr_name = field_name 513 | # The destination of the field in the result dictionary. 
514 | destination = field_name 515 | if context.is_serializing: 516 | destination = field_obj.dump_to or field_name 517 | if field_obj.attribute: 518 | if context.is_serializing: 519 | attr_name = field_obj.attribute 520 | else: 521 | destination = field_obj.attribute 522 | return attr_name, destination 523 | 524 | 525 | def _generate_inlined_access_template(inliner, key, no_callable_fields): 526 | # type: (str, str, bool) -> str 527 | """Generates the code to access a field with an inliner.""" 528 | value_key = 'value' 529 | assignment_template = '' 530 | if not no_callable_fields: 531 | assignment_template += 'value = {0}; '.format( 532 | inliner.format(value_key)) 533 | else: 534 | assignment_template += 'value = {0}; ' 535 | value_key = inliner.format('value') 536 | assignment_template += 'res["{key}"] = {value_key}'.format( 537 | key=key, value_key=value_key) 538 | return assignment_template 539 | 540 | 541 | def inliner_for_field(context, field_obj): 542 | # type: (JitContext, fields.Field) -> Optional[str] 543 | if context.use_inliners: 544 | inliner = None 545 | for field_type, inliner_class in iteritems(INLINERS): 546 | if isinstance(field_obj, field_type): 547 | inliner = inliner_class.inline(field_obj, context) 548 | if inliner: 549 | break 550 | return inliner 551 | return None 552 | 553 | 554 | def generate_method_bodies(schema, context): 555 | # type: (Schema, JitContext) -> str 556 | """Generate 3 method bodies for marshalling objects, dictionaries, or hybrid 557 | objects. 558 | """ 559 | result = IndentedString() 560 | 561 | result += generate_transform_method_body(schema, 562 | InstanceSerializer(context), 563 | context) 564 | result += generate_transform_method_body(schema, 565 | DictSerializer(context), 566 | context) 567 | result += generate_transform_method_body(schema, 568 | HybridSerializer(context), 569 | context) 570 | return str(result) 571 | 572 | 573 | class SerializeProxy(object): 574 | """Proxy object for calling serializer methods. 
575 | 576 | Initially trace calls to serialize and if the number of calls 577 | of a specific type crosses `threshold` swaps out the implementation being 578 | used for the most specialized one available. 579 | """ 580 | def __init__(self, dict_serializer, hybrid_serializer, 581 | instance_serializer, 582 | threshold=100): 583 | # type: (Callable, Callable, Callable, int) -> None 584 | self.dict_serializer = dict_serializer 585 | self.hybrid_serializer = hybrid_serializer 586 | self.instance_serializer = instance_serializer 587 | self.threshold = threshold 588 | self.dict_count = 0 589 | self.hybrid_count = 0 590 | self.instance_count = 0 591 | self._call = self.tracing_call 592 | 593 | if not threshold: 594 | self._call = self.no_tracing_call 595 | 596 | def __call__(self, obj): 597 | return self._call(obj) 598 | 599 | def tracing_call(self, obj): 600 | # type: (Any) -> Any 601 | """Dispatcher which traces calls and specializes if possible. 602 | """ 603 | try: 604 | if isinstance(obj, Mapping): 605 | self.dict_count += 1 606 | return self.dict_serializer(obj) 607 | elif hasattr(obj, '__getitem__'): 608 | self.hybrid_count += 1 609 | return self.hybrid_serializer(obj) 610 | self.instance_count += 1 611 | return self.instance_serializer(obj) 612 | finally: 613 | non_zeros = [x for x in 614 | [self.dict_count, 615 | self.hybrid_count, 616 | self.instance_count] if x > 0] 617 | if len(non_zeros) > 1: 618 | self._call = self.no_tracing_call 619 | elif self.dict_count >= self.threshold: 620 | self._call = self.dict_serializer 621 | elif self.hybrid_count >= self.threshold: 622 | self._call = self.hybrid_serializer 623 | elif self.instance_count >= self.threshold: 624 | self._call = self.instance_serializer 625 | 626 | def no_tracing_call(self, obj): 627 | # type: (Any) -> Any 628 | """Dispatcher with no tracing. 
629 | """ 630 | if isinstance(obj, Mapping): 631 | return self.dict_serializer(obj) 632 | elif hasattr(obj, '__getitem__'): 633 | return self.hybrid_serializer(obj) 634 | return self.instance_serializer(obj) 635 | 636 | 637 | def generate_marshall_method(schema, context=missing, threshold=100): 638 | # type: (Schema, JitContext, int) -> Union[SerializeProxy, Callable, None] 639 | """Generates a function to marshall objects for a given schema. 640 | 641 | :param schema: The Schema to generate a marshall method for. 642 | :param threshold: The number of calls of the same type to observe before 643 | specializing the marshal method for that type. 644 | :return: A Callable that can be used to marshall objects for the schema 645 | """ 646 | if is_overridden(schema.get_attribute, Schema.get_attribute): 647 | # Bail if get_attribute is overridden. This provides the schema author 648 | # too much control to reasonably JIT. 649 | return None 650 | 651 | if context is missing: 652 | context = JitContext() 653 | 654 | context.namespace = {} 655 | context.namespace['dict_class'] = lambda: schema.dict_class() # pylint: disable=unnecessary-lambda 656 | 657 | jit_options = getattr(schema.opts, 'jit_options', {}) 658 | 659 | context.schema_stack.add(schema.__class__) 660 | 661 | result = generate_method_bodies(schema, context) 662 | 663 | context.schema_stack.remove(schema.__class__) 664 | 665 | namespace = context.namespace 666 | 667 | for key, value in iteritems(schema.fields): 668 | if value.attribute and '.' in value.attribute: 669 | # We're currently unable to handle dotted attributes. These don't 670 | # seem to be widely used so punting for now. 
For more information 671 | # see 672 | # https://github.com/marshmallow-code/marshmallow/issues/450 673 | return None 674 | namespace[field_symbol_name(key) + '__serialize'] = value._serialize 675 | namespace[field_symbol_name(key) + '__deserialize'] = value._deserialize 676 | namespace[field_symbol_name(key) + '__validate_missing'] = value._validate_missing 677 | namespace[field_symbol_name(key) + '__validate'] = value._validate 678 | 679 | if value.default is not missing: 680 | namespace[field_symbol_name(key) + '__default'] = value.default 681 | if value.missing is not missing: 682 | namespace[field_symbol_name(key) + '__missing'] = value.missing 683 | 684 | exec_(result, namespace) 685 | 686 | proxy = None # type: Optional[SerializeProxy] 687 | marshall_method = None # type: Union[SerializeProxy, Callable, None] 688 | if not context.is_serializing: 689 | # Deserialization always expects a dictionary. 690 | marshall_method = namespace[DictSerializer.__name__] 691 | elif jit_options.get('expected_marshal_type') in EXPECTED_TYPE_TO_CLASS: 692 | marshall_method = namespace[EXPECTED_TYPE_TO_CLASS[ 693 | jit_options['expected_marshal_type']].__name__] 694 | else: 695 | marshall_method = SerializeProxy( 696 | namespace[DictSerializer.__name__], 697 | namespace[HybridSerializer.__name__], 698 | namespace[InstanceSerializer.__name__], 699 | threshold=threshold) 700 | proxy = marshall_method 701 | 702 | def marshall(obj, many=False): 703 | if many: 704 | return [marshall_method(x) for x in obj] 705 | return marshall_method(obj) 706 | 707 | if proxy: 708 | # Used to allow tests to introspect the proxy. 
709 | marshall.proxy = proxy # type: ignore 710 | marshall._source = result # type: ignore 711 | return marshall 712 | 713 | 714 | def generate_unmarshall_method(schema, context=missing): 715 | context = context or JitContext() 716 | context.is_serializing = False 717 | return generate_marshall_method(schema, context) 718 | -------------------------------------------------------------------------------- /toastedmarshmallow/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from contextlib import contextmanager 3 | 4 | if False: # pylint: disable=using-constant-test 5 | # pylint: disable=unused-import 6 | from typing import List, Union 7 | 8 | 9 | class IndentedString(object): 10 | """Utility class for printing indented strings via a context manager. 11 | 12 | """ 13 | def __init__(self, content='', indent=4): 14 | # type: (Union[str, IndentedString], int) -> None 15 | self.result = [] # type: List[str] 16 | self._indent = indent 17 | self.__indents = [''] 18 | if content: 19 | self.__iadd__(content) 20 | 21 | @contextmanager 22 | def indent(self): 23 | self.__indents.append(self.__indents[-1] + (self._indent * ' ')) 24 | try: 25 | yield 26 | finally: 27 | self.__indents.pop() 28 | 29 | def __iadd__(self, other): 30 | # type: (Union[str, IndentedString]) -> IndentedString 31 | if isinstance(other, IndentedString): 32 | for line in other.result: 33 | self.result.append(self.__indents[-1] + line) 34 | else: 35 | self.result.append(self.__indents[-1] + other) 36 | return self 37 | 38 | def __str__(self): 39 | # type: () -> str 40 | return os.linesep.join(self.result) 41 | --------------------------------------------------------------------------------