├── .coveragerc ├── .gitignore ├── .travis.yml ├── LICENSE ├── README.rst ├── setup.py ├── straitlets ├── __init__.py ├── builtin_models.py ├── compat.py ├── dispatch.py ├── ext │ ├── __init__.py │ ├── click.py │ └── tests │ │ └── test_click.py ├── py3.py ├── serializable.py ├── test_utils.py ├── tests │ ├── __init__.py │ ├── conftest.py │ ├── test_builtin_serializables.py │ ├── test_compat.py │ ├── test_cross_validation.py │ ├── test_examples.py │ ├── test_serializable.py │ ├── test_test_utils.py │ ├── test_to_primitive.py │ ├── test_traits.py │ └── test_utilities.py ├── to_primitive.py ├── traits.py └── utils.py └── tox.ini /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = straitlets/py3.py 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | env/ 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .coverage.* 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | *,cover 45 | 46 | # Translations 47 | *.mo 48 | *.pot 49 | 50 | # Django stuff: 51 | *.log 52 | 53 | # Sphinx documentation 54 | docs/_build/ 55 | 56 | # PyBuilder 57 | target/ 58 | 59 | # Emacs 60 | *~ -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | sudo: false 3 | python: 4 | - "2.7" 5 | - "3.5" 6 | - "3.6" 7 | 8 | # Python 3.7 requires OpenSSL 1.0.2+, which is only available on Travis 9 | # via xenial and sudo. Require them for only the build that needs them. 10 | matrix: 11 | include: 12 | - python: "3.7" 13 | dist: xenial 14 | sudo: true 15 | 16 | before_script: 17 | - pip install tox 18 | 19 | script: 20 | - if [[ $TRAVIS_PYTHON_VERSION = '2.7' ]]; then tox -e py27; fi 21 | - if [[ $TRAVIS_PYTHON_VERSION = '3.5' ]]; then tox -e py35; fi 22 | - if [[ $TRAVIS_PYTHON_VERSION = '3.6' ]]; then tox -e py36; fi 23 | - if [[ $TRAVIS_PYTHON_VERSION = '3.7' ]]; then tox -e py37; fi 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | 203 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | ====================== 2 | serializable-traitlets 3 | ====================== 4 | Serializable IPython Traitlets 5 | 6 | ``serializable-traitlets`` (imported as ``straitlets``) is a Python 2/3 7 | compatible library providing a restricted subset of the classes from `IPython 8 | Traitlets`_. Within our restricted subset, we inherit all the benefits of 9 | using regular ``traitlets``, including static type declarations, `dynamic 10 | default generators`_, and `attribute observers/validators`_. 11 | 12 | By supporting only a limited (though still expressive) subset of Python 13 | objects, however, we gain the ability to serialize and deserialize instances of 14 | ``Serializable`` to and from various formats, including: 15 | 16 | #. JSON 17 | #. YAML 18 | #. base64-encoded strings 19 | 20 | These properties make ``Serializables`` well-suited for configuration in 21 | environments where objects need to be transferred between processes. 
22 | 23 | ``straitlets`` also provides users the ability to specify ``example`` values 24 | for traits. If all traits of a ``Serializable`` class have examples (or 25 | default values) provided, then we can auto-generate an example for the parent 26 | class, and we can recursively generate examples for nested classes. 27 | 28 | Usage 29 | ----- 30 | 31 | **Basic Usage:** 32 | 33 | .. code-block:: python 34 | 35 | In [1]: from straitlets import Serializable, Integer, Dict, List 36 | In [2]: class Foo(Serializable): 37 | ...: my_int = Integer() 38 | ...: my_dict = Dict() 39 | ...: my_list = List() 40 | 41 | In [3]: instance = Foo(my_int=3, my_dict={'a': [1, 2], 'b': (3, 4)}, my_list=[5, None]) 42 | 43 | In [4]: print(instance.to_json()) 44 | {"my_int": 3, "my_dict": {"a": [1, 2], "b": [3, 4]}, "my_list": [5, null]} 45 | 46 | In [5]: print(instance.to_yaml()) 47 | my_dict: 48 | a: 49 | - 1 50 | - 2 51 | b: 52 | - 3 53 | - 4 54 | my_int: 3 55 | my_list: 56 | - 5 57 | - null 58 | 59 | **Autogenerating Example Values:** 60 | 61 | .. code-block:: python 62 | 63 | from straitlets import Serializable, Integer, Instance 64 | 65 | class Point(Serializable): 66 | x = Integer().example(0) 67 | y = Integer().example(0) 68 | 69 | 70 | class Vector(Serializable): 71 | # We can automatically generate example values for attributes 72 | # declared as Instances of Serializable. 73 | head = Instance(Point) 74 | 75 | # Per-attribute overrides are still supported. 76 | tail = Instance(Point).example(Point(x=1, y=3)) 77 | 78 | print(Vector.example_instance().to_yaml()) 79 | # head: 80 | # x: 0 81 | # y: 0 82 | # tail: 83 | # x: 1 84 | # y: 3 85 | 86 | .. _`IPython Traitlets` : http://traitlets.readthedocs.org 87 | .. _`dynamic default generators` : http://traitlets.readthedocs.org/en/stable/using_traitlets.html#dynamic-default-values 88 | ..
_`attribute observers/validators` : http://traitlets.readthedocs.org/en/stable/using_traitlets.html#callbacks-when-trait-attributes-change 89 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from sys import version_info 3 | 4 | 5 | def install_requires(): 6 | requires = [ 7 | 'traitlets>=4.1', 8 | 'six>=1.9.0', 9 | 'pyyaml>=3.11', 10 | ] 11 | if (version_info.major, version_info.minor) < (3, 4): 12 | requires.append('singledispatch>=3.4.0') 13 | return requires 14 | 15 | 16 | def extras_require(): 17 | return { 18 | 'test': [ 19 | 'tox', 20 | 'pytest>=2.8.5', 21 | 'pytest-cov>=1.8.1', 22 | 'pytest-pep8>=1.0.6', 23 | 'click>=6.0', 24 | ], 25 | } 26 | 27 | 28 | def main(): 29 | setup( 30 | name='straitlets', 31 | # remember to update straitlets/__init__.py! 32 | version='0.3.3', 33 | description="Serializable IPython Traitlets", 34 | author="Quantopian Team", 35 | author_email="opensource@quantopian.com", 36 | packages=find_packages(include='straitlets.*'), 37 | include_package_data=True, 38 | zip_safe=True, 39 | url="https://github.com/quantopian/serializable-traitlets", 40 | classifiers=[ 41 | 'Development Status :: 3 - Alpha', 42 | 'Framework :: IPython', 43 | 'Natural Language :: English', 44 | 'Operating System :: OS Independent', 45 | 'Programming Language :: Python :: 2.7', 46 | 'Programming Language :: Python :: 3.4', 47 | 'Programming Language :: Python', 48 | ], 49 | install_requires=install_requires(), 50 | extras_require=extras_require() 51 | ) 52 | 53 | 54 | if __name__ == '__main__': 55 | main() 56 | -------------------------------------------------------------------------------- /straitlets/__init__.py: -------------------------------------------------------------------------------- 1 | from .serializable import Serializable, StrictSerializable, MultipleTraitErrors 2 | from 
.traits import ( 3 | Bool, 4 | Dict, 5 | Enum, 6 | Float, 7 | Instance, 8 | Integer, 9 | LengthBoundedUnicode, 10 | List, 11 | Set, 12 | Tuple, 13 | Unicode, 14 | ) 15 | 16 | # remember to update setup.py! 17 | __version__ = '0.3.3' 18 | 19 | __all__ = ( 20 | 'Bool', 21 | 'Dict', 22 | 'Enum', 23 | 'Float', 24 | 'Instance', 25 | 'Integer', 26 | 'LengthBoundedUnicode', 27 | 'List', 28 | 'MultipleTraitErrors', 29 | 'Set', 30 | 'Tuple', 31 | 'Unicode', 32 | 'Serializable', 33 | 'StrictSerializable', 34 | ) 35 | -------------------------------------------------------------------------------- /straitlets/builtin_models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Built-In Serializables 3 | """ 4 | from six.moves.urllib.parse import urlencode, urlparse 5 | from traitlets import TraitError, validate 6 | 7 | from .serializable import StrictSerializable 8 | from .traits import Bool, Integer, List, Unicode, Dict 9 | 10 | 11 | def join_filter_empty(sep, *elems): 12 | """ 13 | Join a sequence of elements by ``sep``, filtering out empty elements. 14 | 15 | Example 16 | ------- 17 | >>> join_filter_empty(':', 'a', None, 'c') 18 | 'a:c' 19 | >>> join_filter_empty(':', 'a', None) 20 | 'a' 21 | """ 22 | return sep.join(map(str, filter(None, elems))) 23 | 24 | 25 | class PostgresConfig(StrictSerializable): 26 | """ 27 | Configuration for a PostgreSQL connection. 
28 | """ 29 | username = Unicode(help="Username for postgres login") 30 | password = Unicode( 31 | allow_none=True, 32 | default_value=None, 33 | help="Password for postgres login", 34 | ) 35 | hostname = Unicode( 36 | allow_none=True, 37 | default_value=None, 38 | help="Postgres server hostname", 39 | ) 40 | port = Integer( 41 | allow_none=True, 42 | default_value=None, 43 | help="Postgres server port", 44 | ) 45 | 46 | @validate('port') 47 | def _port_requires_hostname(self, proposal): 48 | value = proposal['value'] 49 | if value is not None and self.hostname is None: 50 | raise TraitError("Received port %s but no hostname." % value) 51 | return value 52 | 53 | database = Unicode(help="Database name") 54 | 55 | @property 56 | def netloc(self): 57 | user_pass = join_filter_empty(':', self.username, self.password) 58 | host_port = join_filter_empty(':', self.hostname, self.port) 59 | return '@'.join([user_pass, host_port]) 60 | 61 | query_params = Dict( 62 | default_value={}, 63 | help="Connection parameters", 64 | ) 65 | 66 | @property 67 | def url(self): 68 | return join_filter_empty( 69 | '?', 70 | "postgresql://{netloc}/{db}".format( 71 | netloc=self.netloc, 72 | db=self.database, 73 | ), 74 | urlencode(self.query_params), 75 | ) 76 | 77 | @classmethod 78 | def from_url(cls, url): 79 | """ 80 | Construct a PostgresConfig from a URL. 81 | """ 82 | parsed = urlparse(url) 83 | return cls( 84 | username=parsed.username, 85 | password=parsed.password, 86 | hostname=parsed.hostname, 87 | port=parsed.port, 88 | database=parsed.path.lstrip('/'), 89 | # Like parse_qs, but produces a scalar per key, instead of a list: 90 | query_params=dict(param.split('=') 91 | for param in parsed.query.split('&')) 92 | if parsed.query else {}, 93 | ) 94 | 95 | 96 | class MongoConfig(StrictSerializable): 97 | """ 98 | Configuration for a MongoDB connection. 
99 | """ 100 | username = Unicode( 101 | allow_none=True, 102 | default_value=None, 103 | help="Username for Database Authentication", 104 | ) 105 | 106 | password = Unicode( 107 | allow_none=True, 108 | default_value=None, 109 | help="Password for Database Authentication", 110 | ) 111 | 112 | @validate('username') 113 | def _user_requires_password(self, proposal): 114 | new = proposal['value'] 115 | # Must supply both or neither. 116 | if new and not self.password: 117 | raise TraitError("Username '%s' supplied without password." % new) 118 | return new 119 | 120 | @validate('password') 121 | def _password_requires_user(self, proposal): 122 | # Must supply both or neither. 123 | new = proposal['value'] 124 | if new and not self.username: 125 | # Intentionally not printing a password here. 126 | raise TraitError("Password supplied without username.") 127 | return new 128 | 129 | hosts = List( 130 | trait=Unicode(), 131 | minlen=1, 132 | help=( 133 | "List of hosts in the replicaset. " 134 | "To specify a port, postfix with :{portnum}." 
135 | ) 136 | ) 137 | database = Unicode(help="Database Name") 138 | replicaset = Unicode( 139 | default_value=None, 140 | help="Replicaset Name", 141 | allow_none=True, 142 | ) 143 | slave_ok = Bool( 144 | default_value=True, 145 | help="Okay to connect to non-primary?", 146 | ) 147 | prefer_secondary = Bool( 148 | default_value=True, 149 | help="Prefer to connect to non-primary?", 150 | ) 151 | ssl = Bool(default_value=False, help="Connect via SSL?") 152 | ssl_ca_certs = Unicode( 153 | allow_none=True, 154 | default_value=None, 155 | help="Path to concatenated CA certificates.", 156 | ) 157 | -------------------------------------------------------------------------------- /straitlets/compat.py: -------------------------------------------------------------------------------- 1 | from six import PY3 2 | 3 | if PY3: # pragma: no cover 4 | from inspect import getfullargspec as argspec # noqa 5 | long = int 6 | unicode = str 7 | else: # pragma: no cover 8 | from inspect import getargspec as argspec # noqa 9 | long = long 10 | unicode = unicode 11 | 12 | 13 | def ensure_bytes(s, encoding='utf-8'): 14 | if isinstance(s, bytes): 15 | return s 16 | elif isinstance(s, unicode): 17 | return s.encode(encoding=encoding) 18 | raise TypeError("Expected bytes or unicode, got %s." % type(s)) 19 | 20 | 21 | def ensure_unicode(s, encoding='utf-8'): 22 | if isinstance(s, unicode): 23 | return s 24 | elif isinstance(s, bytes): 25 | return s.decode(encoding=encoding) 26 | raise TypeError("Expected bytes or unicode, got %s." % type(s)) 27 | 28 | 29 | __all__ = [ 30 | 'argspec', 31 | 'ensure_bytes', 32 | 'ensure_unicode', 33 | 'long', 34 | ] 35 | -------------------------------------------------------------------------------- /straitlets/dispatch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Python <= 3.4 compat for singledispatch. 
3 | """ 4 | from sys import version_info 5 | if (version_info.major, version_info.minor) < (3, 4): # pragma: no cover 6 | from singledispatch import singledispatch 7 | else: # pragma: no cover 8 | from functools import singledispatch 9 | 10 | __all__ = ['singledispatch'] 11 | -------------------------------------------------------------------------------- /straitlets/ext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantopian/serializable-traitlets/f7de75507978e08446a15894a8417997940ea7a6/straitlets/ext/__init__.py -------------------------------------------------------------------------------- /straitlets/ext/click.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from operator import itemgetter 4 | 5 | import click 6 | 7 | from straitlets import MultipleTraitErrors 8 | from traitlets import TraitError 9 | 10 | 11 | def _indent(msg): 12 | return '\n'.join(' ' + line for line in msg.splitlines()) 13 | 14 | 15 | class _ConfigFile(click.File): 16 | def __init__(self, config_type, encoding=None): 17 | super(_ConfigFile, self).__init__( 18 | mode='r', 19 | encoding=None, 20 | errors='strict', 21 | lazy=False, 22 | atomic=False, 23 | ) 24 | self.config = config_type 25 | 26 | def read(self, f): # pragma: no cover 27 | raise NotImplementedError('read') 28 | 29 | @property 30 | def name(self): # pragma: no cover 31 | raise NotImplementedError('name') 32 | 33 | def convert(self, value, param, ctx): 34 | f = super(_ConfigFile, self).convert(value, param, ctx) 35 | try: 36 | return self.read(f) 37 | except MultipleTraitErrors as e: 38 | self.fail( 39 | 'Failed to validate the schema:\n\n' + '\n'.join( 40 | '%s:\n%s' % (key, _indent(str(err))) 41 | for key, err in sorted(e.errors.items(), key=itemgetter(0)) 42 | ), 43 | param, 44 | ctx, 45 | ) 46 | except TraitError as e: 47 | self.fail( 48 | 'Failed to 
validate the schema:\n%s' % _indent(str(e)), 49 | param, 50 | ctx, 51 | ) 52 | 53 | 54 | class JsonConfigFile(_ConfigFile): 55 | """A click parameter type for reading a :class:`~straitlets.Serializable` 56 | object out of json file. 57 | 58 | Parameters 59 | ---------- 60 | config_type : type[Serializable] 61 | A subclass of :class:`~straitlets.Serializable`. 62 | 63 | Notes 64 | ----- 65 | Normal :class:`~straitlets.Serializable` are not eagerly validated, if you 66 | want to check the schema of the file at read time you should use a 67 | :class:`~straitlets.StrictSerializable`. 68 | """ 69 | name = 'JSON-FILE' 70 | 71 | def read(self, f): 72 | return self.config.from_json(f.read()) 73 | 74 | 75 | class YamlConfigFile(_ConfigFile): 76 | """A click parameter type for reading a :class:`~straitlets.Serializable` 77 | object out of yaml file. 78 | 79 | Parameters 80 | ---------- 81 | config_type : type[Serializable] 82 | A subclass of :class:`~straitlets.Serializable`. 83 | 84 | Notes 85 | ----- 86 | Normal :class:`~straitlets.Serializable` are not eagerly validated, if you 87 | want to check the schema of the file at read time you should use a 88 | :class:`~straitlets.StrictSerializable`. 
89 | """ 90 | name = 'YAML-FILE' 91 | 92 | def read(self, f): 93 | return self.config.from_yaml(f) 94 | -------------------------------------------------------------------------------- /straitlets/ext/tests/test_click.py: -------------------------------------------------------------------------------- 1 | import re 2 | from textwrap import dedent 3 | 4 | import click 5 | from click.testing import CliRunner 6 | import pytest 7 | 8 | from straitlets import ( 9 | Serializable, 10 | StrictSerializable, 11 | Bool, 12 | Unicode, 13 | Integer, 14 | ) 15 | from straitlets.ext.click import ( 16 | JsonConfigFile, 17 | YamlConfigFile, 18 | ) 19 | from straitlets.test_utils import assert_serializables_equal 20 | 21 | 22 | @pytest.fixture 23 | def runner(): 24 | return CliRunner() 25 | 26 | 27 | class Config(Serializable): 28 | bool = Bool() 29 | unicode = Unicode() 30 | int = Integer() 31 | 32 | 33 | class MissingAttr(Serializable): 34 | bool = Bool() 35 | unicode = Unicode() 36 | 37 | 38 | class StrictConfig(Config, StrictSerializable): 39 | pass 40 | 41 | 42 | @pytest.fixture 43 | def expected_instance(): 44 | return Config( 45 | bool=True, 46 | unicode='ayy', 47 | int=1, 48 | ) 49 | 50 | 51 | @pytest.fixture 52 | def missing_attr_instance(): 53 | return MissingAttr( 54 | bool=True, 55 | unicode='ayy', 56 | ) 57 | 58 | 59 | multi_error_output = re.compile( 60 | dedent( 61 | """\ 62 | Failed to validate the schema: 63 | 64 | bool: 65 | No default value found for bool trait of <.+?> 66 | int: 67 | No default value found for int trait of <.+?> 68 | unicode: 69 | No default value found for unicode trait of <.+?> 70 | """, 71 | ), 72 | ) 73 | 74 | single_error_output = re.compile( 75 | dedent( 76 | """\ 77 | Failed to validate the schema: 78 | No default value found for int trait of <.+?> 79 | """, 80 | ), 81 | ) 82 | 83 | 84 | def test_json_file(runner, expected_instance): 85 | instance = [None] # nonlocal 86 | 87 | @click.command() 88 | @click.option('--config', 
type=JsonConfigFile(Config)) 89 | def main(config): 90 | instance[0] = config 91 | 92 | with runner.isolated_filesystem(): 93 | with open('f.json', 'w') as f: 94 | f.write(expected_instance.to_json()) 95 | 96 | result = runner.invoke( 97 | main, 98 | ['--config', 'f.json'], 99 | input='not-json', 100 | catch_exceptions=False, 101 | ) 102 | assert result.output == '' 103 | assert result.exit_code == 0 104 | assert_serializables_equal( 105 | instance[0], 106 | expected_instance, 107 | ) 108 | 109 | 110 | def test_json_multiple_errors(runner): 111 | @click.command() 112 | @click.option('--config', type=JsonConfigFile(StrictConfig)) 113 | def main(config): # pragma: no cover 114 | pass 115 | 116 | with runner.isolated_filesystem(): 117 | with open('f.json', 'w') as f: 118 | f.write('{}') 119 | 120 | result = runner.invoke( 121 | main, 122 | ['--config', 'f.json'], 123 | input='not-json', 124 | catch_exceptions=False, 125 | ) 126 | assert result.exit_code 127 | assert multi_error_output.search(result.output) 128 | 129 | 130 | def test_json_single_error(runner, missing_attr_instance): 131 | @click.command() 132 | @click.option('--config', type=JsonConfigFile(StrictConfig)) 133 | def main(config): # pragma: no cover 134 | pass 135 | 136 | with runner.isolated_filesystem(): 137 | with open('f.json', 'w') as f: 138 | f.write(missing_attr_instance.to_json()) 139 | 140 | result = runner.invoke( 141 | main, 142 | ['--config', 'f.json'], 143 | input='not-json', 144 | catch_exceptions=False, 145 | ) 146 | assert result.exit_code 147 | assert single_error_output.search(result.output) 148 | 149 | 150 | def test_yaml_file(runner, expected_instance): 151 | instance = [None] # nonlocal 152 | 153 | @click.command() 154 | @click.option('--config', type=YamlConfigFile(Config)) 155 | def main(config): 156 | instance[0] = config 157 | 158 | with runner.isolated_filesystem(): 159 | with open('f.yml', 'w') as f: 160 | f.write(expected_instance.to_yaml()) 161 | 162 | result = runner.invoke( 
163 | main, 164 | ['--config', 'f.yml'], 165 | input='not-yaml', 166 | catch_exceptions=False, 167 | ) 168 | assert result.output == '' 169 | assert result.exit_code == 0 170 | assert_serializables_equal( 171 | instance[0], 172 | expected_instance, 173 | ) 174 | 175 | 176 | def test_yaml_multiple_errors(runner): 177 | @click.command() 178 | @click.option('--config', type=YamlConfigFile(StrictConfig)) 179 | def main(config): # pragma: no cover 180 | pass 181 | 182 | with runner.isolated_filesystem(): 183 | with open('f.yml', 'w') as f: 184 | f.write('{}') 185 | 186 | result = runner.invoke( 187 | main, 188 | ['--config', 'f.yml'], 189 | input='not-yaml', 190 | catch_exceptions=False, 191 | ) 192 | assert result.exit_code 193 | assert multi_error_output.search(result.output) 194 | 195 | 196 | def test_yaml_single_error(runner, missing_attr_instance): 197 | @click.command() 198 | @click.option('--config', type=YamlConfigFile(StrictConfig)) 199 | def main(config): # pragma: no cover 200 | pass 201 | 202 | with runner.isolated_filesystem(): 203 | with open('f.yml', 'w') as f: 204 | f.write(missing_attr_instance.to_yaml()) 205 | 206 | result = runner.invoke( 207 | main, 208 | ['--config', 'f.yml'], 209 | input='not-yaml', 210 | catch_exceptions=False, 211 | ) 212 | assert result.exit_code 213 | assert single_error_output.search(result.output) 214 | -------------------------------------------------------------------------------- /straitlets/py3.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from .to_primitive import to_primitive 4 | from .traits import SerializableTrait 5 | 6 | if sys.version_info.major < 3: 7 | # raise a more explicit error message if this is imported in Python 2 8 | raise ImportError('%s is only available in Python 3' % __name__) 9 | 10 | # noqa on the import because it is not at the top of the module. We cannot 11 | # import this module until we know that we are in Python 3. 
12 | import pathlib # noqa 13 | 14 | 15 | @to_primitive.register(pathlib.Path) 16 | def _path_to_primitive(path): 17 | return str(path) 18 | 19 | 20 | class Path(SerializableTrait): 21 | def validate(self, obj, value): 22 | # ``pathlib.Path`` is a nop when called on a ``pathlib.Path`` object 23 | return pathlib.Path(value) 24 | -------------------------------------------------------------------------------- /straitlets/serializable.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines a Serializable subclass for extended traitlets. 3 | """ 4 | import base64 5 | import json 6 | from operator import itemgetter 7 | from textwrap import dedent 8 | import yaml 9 | 10 | from traitlets import ( 11 | HasTraits, 12 | MetaHasTraits, 13 | TraitError, 14 | TraitType, 15 | Undefined, 16 | ) 17 | from six import with_metaclass, iteritems, viewkeys 18 | 19 | from .compat import ensure_bytes, ensure_unicode 20 | from .traits import SerializableTrait 21 | from .to_primitive import to_primitive 22 | 23 | 24 | class SerializableMeta(MetaHasTraits): 25 | 26 | def __new__(mcls, name, bases, classdict): 27 | # Check that all TraitType instances are all. 28 | for maybe_trait_name, maybe_trait_instance in iteritems(classdict): 29 | if isinstance(maybe_trait_instance, TraitType): 30 | if not isinstance(maybe_trait_instance, SerializableTrait): 31 | raise TypeError( 32 | "Got non-serializable trait {name}={type}".format( 33 | name=maybe_trait_name, 34 | type=type(maybe_trait_instance).__name__, 35 | ) 36 | ) 37 | 38 | return super(SerializableMeta, mcls).__new__( 39 | mcls, name, bases, classdict 40 | ) 41 | 42 | 43 | _DID_YOU_MEAN_INSTANCE_TEMPLATE = dedent( 44 | """ 45 | {type}.__init__() got unexpected keyword argument {name!r}. 46 | {type} (or a parent) has a class attribute with the same name. 47 | Did you mean to write `{name} = Instance({instance_type})`? 
48 | """ 49 | ) 50 | 51 | 52 | class MultipleTraitErrors(TraitError): 53 | def __new__(cls, errors): 54 | if len(errors) == 1: 55 | # If only one error is passed, pass it through unmodified. 56 | return list(errors.items())[0][1] 57 | return super(MultipleTraitErrors, cls).__new__(cls, errors) 58 | 59 | def __init__(self, errors): 60 | self.errors = errors 61 | 62 | def __str__(self): 63 | return '\n' + ('\n%s\n' % ('-' * 20)).join( 64 | ': '.join((name, str(e))) 65 | for name, e in sorted(self.errors.items(), key=itemgetter(0)) 66 | ) 67 | 68 | 69 | class Serializable(with_metaclass(SerializableMeta, HasTraits)): 70 | """ 71 | Base class for HasTraits instances that can be serialized into Python 72 | primitives. 73 | 74 | The traitlets set on Serializables must be instances of 75 | straitlets.traits.SerializableTrait. 76 | """ 77 | 78 | def __init__(self, **metadata): 79 | unexpected = viewkeys(metadata) - self.trait_names() 80 | if unexpected: 81 | raise TypeError(self._unexpected_kwarg_msg(unexpected)) 82 | super(Serializable, self).__init__(**metadata) 83 | 84 | def validate_all_attributes(self): 85 | """ 86 | Force validation of all traits. 87 | 88 | Useful for circumstances where an attribute won't be accessed until 89 | well after construction, but we want to fail eagerly if that attribute 90 | is passed incorrectly. 91 | 92 | Consider using ``StrictSerializable`` for classes where you always want 93 | this called on construction. 
94 | 95 | See Also 96 | -------- 97 | StrictSerializable 98 | """ 99 | errors = {} 100 | for name in self.trait_names(): 101 | try: 102 | getattr(self, name) 103 | except TraitError as e: 104 | errors[name] = e 105 | if errors: 106 | raise MultipleTraitErrors(errors) 107 | 108 | @classmethod 109 | def _unexpected_kwarg_msg(cls, unexpected): 110 | # Provide a more useful error is the user did: 111 | # 112 | # class SomeSerializable(Serializable): 113 | # sub_serial = SomeOtherSerializable() 114 | # 115 | # When what they actually meant was: 116 | # class SomeSerializable(Serializable): 117 | # sub_serial = Instance(SomeOtherSerializable) 118 | for name in unexpected: 119 | not_there = object() 120 | maybe_attr = getattr(cls, name, not_there) 121 | if maybe_attr is not not_there: 122 | return _DID_YOU_MEAN_INSTANCE_TEMPLATE.format( 123 | type=cls.__name__, 124 | name=name, 125 | instance_type=type(maybe_attr).__name__, 126 | ) 127 | return ( 128 | "{type}.__init__() got unexpected" 129 | " keyword arguments {unexpected}.".format( 130 | type=cls.__name__, 131 | unexpected=tuple(unexpected), 132 | ) 133 | ) 134 | 135 | @classmethod 136 | def example_instance(cls, skip=()): 137 | """ 138 | Generate an example instance of a Serializable subclass. 139 | 140 | If traits have been tagged with an `example` value, then we use that 141 | value. Otherwise we fall back the default_value for the instance. 142 | 143 | Traits with names in ``skip`` will not have example values set. 144 | """ 145 | kwargs = {} 146 | for name, trait in iteritems(cls.class_traits()): 147 | if name in skip: 148 | continue 149 | value = trait.example_value 150 | if value is Undefined: 151 | continue 152 | kwargs[name] = value 153 | 154 | return cls(**kwargs) 155 | 156 | @classmethod 157 | def example_yaml(cls, skip=()): 158 | """ 159 | Generate an example yaml string for a Serializable subclass. 160 | 161 | If traits have been tagged with an `example` value, then we use that 162 | value. 
Otherwise we fall back the default_value for the instance. 163 | """ 164 | return cls.example_instance(skip=skip).to_yaml(skip=skip) 165 | 166 | @classmethod 167 | def write_example_yaml(cls, dest, skip=()): 168 | """ 169 | Write a file containing an example yaml string for a Serializable 170 | subclass. 171 | """ 172 | # Make sure we can make an instance before we open a file. 173 | inst = cls.example_instance(skip=skip) 174 | with open(dest, 'w') as f: 175 | inst.to_yaml(stream=f, skip=skip) 176 | 177 | def to_dict(self, skip=()): 178 | out_dict = {} 179 | for key in self.trait_names(): 180 | if key in skip: 181 | continue 182 | out_dict[key] = to_primitive(getattr(self, key)) 183 | return out_dict 184 | 185 | @classmethod 186 | def from_dict(cls, dict_): 187 | return cls(**dict_) 188 | 189 | def to_json(self, skip=()): 190 | return json.dumps(self.to_dict(skip=skip)) 191 | 192 | @classmethod 193 | def from_json(cls, s): 194 | return cls.from_dict(json.loads(s)) 195 | 196 | def to_yaml(self, stream=None, skip=()): 197 | return yaml.safe_dump( 198 | self.to_dict(skip=skip), 199 | stream=stream, 200 | default_flow_style=False, 201 | ) 202 | 203 | @classmethod 204 | def from_yaml(cls, stream): 205 | return cls.from_dict(yaml.safe_load(stream)) 206 | 207 | @classmethod 208 | def from_yaml_file(cls, path): 209 | with open(path, 'r') as f: 210 | return cls.from_yaml(f) 211 | 212 | @classmethod 213 | def from_base64(cls, s): 214 | """ 215 | Construct from base64-encoded JSON. 216 | """ 217 | return cls.from_json(ensure_unicode(base64.b64decode(s))) 218 | 219 | def to_base64(self, skip=()): 220 | """ 221 | Construct from base64-encoded JSON. 222 | """ 223 | return base64.b64encode( 224 | ensure_bytes( 225 | self.to_json(skip=skip), 226 | encoding='utf-8', 227 | ) 228 | ) 229 | 230 | @classmethod 231 | def from_environ(cls, environ): 232 | """ 233 | Deserialize an instance that was written to the environment via 234 | ``to_environ``. 
235 | 236 | Parameters 237 | ---------- 238 | environ : dict-like 239 | Dict-like object (e.g. os.environ) from which to read ``self``. 240 | """ 241 | return cls.from_base64(environ[cls.__name__]) 242 | 243 | def to_environ(self, environ, skip=()): 244 | """ 245 | Serialize and write self to environ[self._envvar]. 246 | 247 | Parameters 248 | ---------- 249 | environ : dict-like 250 | Dict-like object (e.g. os.environ) into which to write ``self``. 251 | """ 252 | environ[ensure_unicode(type(self).__name__)] = ( 253 | ensure_unicode(self.to_base64(skip=skip)) 254 | ) 255 | 256 | 257 | @to_primitive.register(Serializable) 258 | def _serializable_to_primitive(s): 259 | return s.to_dict() 260 | 261 | 262 | class StrictSerializable(Serializable): 263 | """ 264 | Serializable subclass that eagerly evaluates traited attributes after 265 | construction. 266 | 267 | Useful in circumstances where you want to fail as early as possible when an 268 | object is malformed. 269 | """ 270 | 271 | def __init__(self, **metadata): 272 | super(StrictSerializable, self).__init__(**metadata) 273 | self.validate_all_attributes() 274 | -------------------------------------------------------------------------------- /straitlets/test_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Testing utilities. 3 | """ 4 | from contextlib import contextmanager 5 | 6 | import pytest 7 | from six import iteritems, string_types 8 | from six.moves.urllib.parse import parse_qs 9 | 10 | from .serializable import Serializable 11 | 12 | 13 | def multifixture(g): 14 | """ 15 | Decorator for turning a generator into a "parameterized" fixture that emits 16 | the generated values. 
17 | """ 18 | fixture_values = list(g()) 19 | 20 | @pytest.fixture(params=fixture_values) 21 | def _fixture(request): 22 | return request.param 23 | 24 | return _fixture 25 | 26 | 27 | def check_attributes(obj, attrs): 28 | for key, value in iteritems(attrs): 29 | assert getattr(obj, key) == value 30 | 31 | 32 | def assert_serializables_equal(left, right, skip=()): 33 | assert type(left) == type(right) 34 | assert set(left.trait_names()) == set(right.trait_names()) 35 | for name in left.trait_names(): 36 | if name in skip: 37 | continue 38 | left_attr = getattr(left, name) 39 | right_attr = getattr(right, name) 40 | assert type(left_attr) == type(right_attr) 41 | if isinstance(left_attr, Serializable): 42 | assert_serializables_equal(left_attr, right_attr) 43 | else: 44 | assert left_attr == right_attr 45 | 46 | 47 | def assert_urls_equal(left, right): 48 | assert isinstance(left, string_types) 49 | assert isinstance(right, string_types) 50 | 51 | left_parts = left.split('?', 1) 52 | right_parts = right.split('?', 1) 53 | 54 | left_url = left_parts[0] 55 | right_url = right_parts[0] 56 | assert left_url == right_url 57 | 58 | left_params = left_parts[1:] 59 | right_params = right_parts[1:] 60 | assert len(left_params) == len(right_params) 61 | if left_params: 62 | assert parse_qs(left_params[0]) == parse_qs(right_params[0]) 63 | 64 | 65 | @contextmanager 66 | def removed_keys(dict_, keys): 67 | popped = {} 68 | for key in keys: 69 | popped[key] = dict_.pop(key) 70 | try: 71 | yield 72 | finally: 73 | dict_.update(popped) 74 | -------------------------------------------------------------------------------- /straitlets/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quantopian/serializable-traitlets/f7de75507978e08446a15894a8417997940ea7a6/straitlets/tests/__init__.py -------------------------------------------------------------------------------- /straitlets/tests/conftest.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from ..test_utils import multifixture 3 | 4 | 5 | def _roundtrip_to_dict(traited, skip=()): 6 | return type(traited).from_dict(traited.to_dict(skip=skip)) 7 | 8 | 9 | def _roundtrip_to_json(traited, skip=()): 10 | return type(traited).from_json(traited.to_json(skip=skip)) 11 | 12 | 13 | def _roundtrip_to_yaml(traited, skip=()): 14 | return type(traited).from_yaml(traited.to_yaml(skip=skip)) 15 | 16 | 17 | def _roundtrip_to_base64(traited, skip=()): 18 | return type(traited).from_base64(traited.to_base64(skip=skip)) 19 | 20 | 21 | def _roundtrip_to_environ_dict(traited, skip=()): 22 | environ = {} 23 | traited.to_environ(environ, skip=skip) 24 | return type(traited).from_environ(environ) 25 | 26 | 27 | def _roundtrip_to_os_environ(traited, skip=()): 28 | environ = os.environ 29 | orig = dict(environ) 30 | 31 | traited.to_environ(environ, skip=skip) 32 | try: 33 | return type(traited).from_environ(environ) 34 | finally: 35 | environ.clear() 36 | environ.update(orig) 37 | 38 | 39 | @multifixture 40 | def roundtrip_func(): 41 | yield _roundtrip_to_dict 42 | yield _roundtrip_to_json 43 | yield _roundtrip_to_yaml 44 | yield _roundtrip_to_base64 45 | yield _roundtrip_to_environ_dict 46 | yield _roundtrip_to_os_environ 47 | -------------------------------------------------------------------------------- /straitlets/tests/test_builtin_serializables.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | 3 | import pytest 4 | 5 | from traitlets import TraitError 6 | 7 | from straitlets.utils import merge 8 | from straitlets.test_utils import ( 9 | assert_serializables_equal, 10 | check_attributes, 11 | multifixture, 12 | removed_keys, 13 | assert_urls_equal, 14 | ) 15 | 16 | from ..builtin_models import MongoConfig, PostgresConfig 17 | 18 | 19 | @pytest.fixture 20 | def pg_required_kwargs(): 21 | 
return { 22 | 'username': 'user', 23 | 'database': 'db', 24 | } 25 | 26 | 27 | @pytest.fixture 28 | def pg_optional_kwargs(): 29 | return { 30 | 'hostname': 'localhost', 31 | 'port': 5432, 32 | 'password': 'password', 33 | 'query_params': {'connect_timeout': '10', 'sslmode': 'require'}, 34 | } 35 | 36 | 37 | @multifixture 38 | def mongo_hosts_lists(): 39 | yield ['web'] 40 | yield ['web', 'scale'] 41 | yield ['web:10421', 'scale:10474'] 42 | 43 | 44 | @pytest.fixture 45 | def mongo_optional_kwargs(): 46 | return { 47 | 'replicaset': "secret_ingredient_in_the_webscale_sauce", 48 | 'slave_ok': False, 49 | 'prefer_secondary': False, 50 | 'ssl': False, 51 | 'ssl_ca_certs': '/path/to/ca_cert.pem' 52 | } 53 | 54 | 55 | @pytest.fixture 56 | def mongo_required_kwargs(mongo_hosts_lists): 57 | return { 58 | 'username': 'user', 59 | 'password': 'pass', 60 | 'hosts': mongo_hosts_lists, 61 | 'database': "webscale", 62 | } 63 | 64 | 65 | def test_postgres_config_required(pg_required_kwargs, roundtrip_func): 66 | cfg = PostgresConfig(**pg_required_kwargs) 67 | check_attributes( 68 | cfg, 69 | merge(pg_required_kwargs, {'port': None, 'password': None}), 70 | ) 71 | assert_urls_equal(cfg.url, "postgresql://user@/db") 72 | rounded = roundtrip_func(cfg) 73 | assert_serializables_equal(cfg, rounded, skip=['url']) 74 | assert_urls_equal(rounded.url, cfg.url) 75 | 76 | from_url = PostgresConfig.from_url(cfg.url) 77 | assert_serializables_equal(cfg, from_url, skip=['url']) 78 | assert_urls_equal(from_url.url, cfg.url) 79 | 80 | 81 | def test_postgres_config_optional(pg_required_kwargs, 82 | pg_optional_kwargs, 83 | roundtrip_func): 84 | kwargs = merge(pg_required_kwargs, pg_optional_kwargs) 85 | cfg = PostgresConfig(**kwargs) 86 | check_attributes(cfg, kwargs) 87 | 88 | assert_urls_equal( 89 | cfg.url, 90 | "postgresql://user:password@localhost:5432/db?" 
91 | "connect_timeout=10&sslmode=require") 92 | 93 | rounded = roundtrip_func(cfg) 94 | assert_serializables_equal(cfg, rounded) 95 | assert_urls_equal(rounded.url, cfg.url) 96 | 97 | from_url = PostgresConfig.from_url(cfg.url) 98 | assert_serializables_equal(cfg, from_url, skip=['url']) 99 | assert_urls_equal(from_url.url, cfg.url) 100 | 101 | 102 | def test_all_pg_kwargs_required(pg_required_kwargs): 103 | 104 | kwargs = pg_required_kwargs.copy() 105 | for key in kwargs: 106 | with removed_keys(kwargs, [key]), pytest.raises(TraitError) as e: 107 | PostgresConfig(**kwargs) 108 | assert str(e.value).startswith('No default value found for %s' % key) 109 | 110 | 111 | def test_pg_port_requires_hostname(pg_required_kwargs): 112 | 113 | # Hostname without port is ok. 114 | cfg = PostgresConfig(hostname='localhost', **pg_required_kwargs) 115 | check_attributes( 116 | cfg, 117 | merge(pg_required_kwargs, {'hostname': 'localhost'}) 118 | ) 119 | assert cfg.url == "postgresql://user@localhost/db" 120 | 121 | # Port without hostname is an error. 122 | with pytest.raises(TraitError) as e: 123 | PostgresConfig(port=5432, **pg_required_kwargs) 124 | assert str(e.value) == "Received port 5432 but no hostname." 
125 | 126 | 127 | def test_mongo_config(mongo_required_kwargs, 128 | mongo_optional_kwargs, 129 | roundtrip_func): 130 | 131 | with pytest.raises(TraitError): 132 | MongoConfig(**mongo_optional_kwargs) 133 | 134 | optional_kwarg_defaults = { 135 | 'replicaset': None, 136 | 'slave_ok': True, 137 | 'prefer_secondary': True, 138 | 'ssl': False, 139 | 'ssl_ca_certs': None 140 | } 141 | 142 | without_optionals = MongoConfig(**mongo_required_kwargs) 143 | check_attributes( 144 | without_optionals, 145 | merge(mongo_required_kwargs, optional_kwarg_defaults), 146 | ) 147 | assert_serializables_equal( 148 | without_optionals, 149 | roundtrip_func(without_optionals) 150 | ) 151 | 152 | full_kwargs = merge(mongo_required_kwargs, mongo_optional_kwargs) 153 | with_optionals = MongoConfig(**full_kwargs) 154 | check_attributes(with_optionals, full_kwargs) 155 | assert_serializables_equal(with_optionals, roundtrip_func(with_optionals)) 156 | 157 | 158 | def test_mongo_config_username_password_both_or_neither(mongo_required_kwargs): 159 | 160 | kwargs = mongo_required_kwargs.copy() 161 | 162 | with removed_keys(kwargs, ['username']), pytest.raises(TraitError) as e: 163 | MongoConfig(**kwargs) 164 | assert str(e.value) == "Password supplied without username." 165 | 166 | with removed_keys(kwargs, ['password']), pytest.raises(TraitError) as e: 167 | MongoConfig(**kwargs) 168 | assert str(e.value) == "Username 'user' supplied without password." 
169 | 170 | with removed_keys(kwargs, ['username', 'password']): 171 | cfg = MongoConfig(**kwargs) 172 | 173 | check_attributes( 174 | cfg, 175 | merge(kwargs, {'username': None, 'password': None}), 176 | ) 177 | -------------------------------------------------------------------------------- /straitlets/tests/test_compat.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import sys 3 | 4 | import pytest 5 | 6 | from ..compat import ensure_bytes, ensure_unicode 7 | 8 | 9 | def test_ensure_bytes(): 10 | 11 | b = b"asdfas" 12 | u = u'unicodé' 13 | 14 | assert ensure_bytes(b) is b 15 | assert ensure_bytes(u) == b"unicod\xc3\xa9" 16 | 17 | with pytest.raises(TypeError): 18 | ensure_bytes(1) 19 | 20 | 21 | def test_ensure_unicode(): 22 | 23 | b = b"asdfas" 24 | u = u'unicodé' 25 | 26 | assert ensure_unicode(b) == u"asdfas" 27 | assert ensure_unicode(u) is u 28 | 29 | with pytest.raises(TypeError): 30 | ensure_unicode(1) 31 | 32 | 33 | @pytest.mark.skipif( 34 | sys.version_info.major > 2, 35 | reason='we can import straitlets.py3 in Python 3', 36 | ) 37 | def test_py3_import_error(): # pragma: no cover 38 | with pytest.raises(ImportError) as e: 39 | import straitlets.py3 # noqa 40 | 41 | assert str(e.value) == 'straitlets.py3 is only available in Python 3' 42 | -------------------------------------------------------------------------------- /straitlets/tests/test_cross_validation.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from straitlets import Serializable, Unicode 4 | from traitlets import validate, TraitError 5 | 6 | 7 | class SomeConfig(Serializable): 8 | username = Unicode(allow_none=True).example('username') 9 | password = Unicode(allow_none=True).example('password') 10 | 11 | @validate('username', 'password') 12 | def _both_or_neither(self, proposal): 13 | new_value = proposal['value'] 14 | name = proposal['trait'].name 15 | other_name = 
{ 16 | 'username': 'password', 17 | 'password': 'username', 18 | }[name] 19 | other_value = getattr(self, other_name) 20 | if new_value is not None and other_value is None: 21 | raise TraitError("%s supplied without %s" % (name, other_name)) 22 | if new_value is None and other_value is not None: 23 | raise TraitError("%s supplied without %s" % (other_name, name)) 24 | return new_value 25 | 26 | 27 | def test_example_instance(): 28 | example = SomeConfig.example_instance() 29 | assert example.username == 'username' 30 | assert example.password == 'password' 31 | 32 | 33 | def test_both_or_neither(): 34 | cfg = SomeConfig(username='u', password='p') 35 | assert cfg.username == 'u' 36 | assert cfg.password == 'p' 37 | 38 | cfg = SomeConfig(username=None, password=None) 39 | assert cfg.username is None 40 | assert cfg.password is None 41 | 42 | with pytest.raises(TraitError): 43 | SomeConfig(username='username', password=None) 44 | 45 | with pytest.raises(TraitError): 46 | SomeConfig(username=None, password='password') 47 | -------------------------------------------------------------------------------- /straitlets/tests/test_examples.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for example generation. 
3 | """ 4 | from __future__ import unicode_literals 5 | from textwrap import dedent 6 | 7 | import pytest 8 | from yaml import safe_load 9 | 10 | from traitlets import TraitError 11 | 12 | from straitlets.serializable import Serializable 13 | from straitlets.test_utils import assert_serializables_equal, multifixture 14 | from straitlets.traits import ( 15 | Bool, 16 | Dict, 17 | Instance, 18 | Integer, 19 | Unicode, 20 | ) 21 | 22 | 23 | class Point(Serializable): 24 | x = Integer().example(100) 25 | y = Integer().example(101) 26 | 27 | 28 | class ExampleClass(Serializable): 29 | 30 | bool_tag = Bool().example(True) 31 | bool_default = Bool(default_value=False) 32 | bool_both = Bool(default_value=False).example(True) 33 | 34 | int_tag = Integer().example(1) 35 | int_default = Integer(default_value=2) 36 | int_both = Integer(default_value=100).example(3) 37 | 38 | dict_tag = Dict().example({'a': 'b'}) 39 | dict_default = Dict(default_value={'c': 'd'}) 40 | dict_both = Dict(default_value={'z': 'z'}).example({'e': 'f', 'g': 'h'}) 41 | 42 | instance_tag = Instance(Point).example(Point(x=0, y=1)) 43 | instance_default = Instance(Point, default_value=Point(x=2, y=3)) 44 | instance_both = Instance(Point, default_value=Point(x=4, y=5)).example( 45 | Point(x=6, y=7), 46 | ) 47 | # Instance provides a default example if the value is a Serializable 48 | # that can provide an example. 
49 | instance_neither = Instance(Point) 50 | 51 | 52 | @pytest.fixture 53 | def expected_instance(): 54 | return ExampleClass( 55 | bool_tag=True, 56 | bool_default=False, 57 | bool_both=True, 58 | int_tag=1, 59 | int_default=2, 60 | int_both=3, 61 | dict_tag={'a': 'b'}, 62 | dict_default={'c': 'd'}, 63 | dict_both={'e': 'f', 'g': 'h'}, 64 | instance_tag=Point(x=0, y=1), 65 | instance_default=Point(x=2, y=3), 66 | instance_both=Point(x=6, y=7), 67 | instance_neither=Point(x=100, y=101), 68 | ) 69 | 70 | 71 | @pytest.fixture 72 | def expected_yaml(): 73 | return dedent( 74 | """ 75 | bool_tag: true 76 | bool_default: false 77 | bool_both: true 78 | int_tag: 1 79 | int_default: 2 80 | int_both: 3 81 | dict_tag: 82 | a: b 83 | dict_default: 84 | c: d 85 | dict_both: 86 | e: f 87 | g: h 88 | instance_tag: 89 | x: 0 90 | y: 1 91 | instance_default: 92 | x: 2 93 | y: 3 94 | instance_both: 95 | x: 6 96 | y: 7 97 | instance_neither: 98 | x: 100 99 | y: 101 100 | """ 101 | ) 102 | 103 | 104 | @multifixture 105 | def skip_names(): 106 | yield () 107 | yield ('bool_tag',) 108 | yield ('bool_tag', 'int_tag') 109 | yield ('dict_tag', 'instance_tag', 'instance_neither') 110 | yield [ 111 | name for name in ExampleClass.class_trait_names() 112 | if name.endswith('tag') 113 | ] 114 | 115 | 116 | def test_example_instance(expected_yaml, expected_instance): 117 | instance = ExampleClass.example_instance() 118 | 119 | assert_serializables_equal(instance, expected_instance) 120 | assert_serializables_equal(instance, ExampleClass.from_yaml(expected_yaml)) 121 | 122 | assert ExampleClass.example_yaml() == expected_instance.to_yaml() 123 | 124 | 125 | def test_example_skip_names(expected_instance, skip_names): 126 | instance = ExampleClass.example_instance(skip=skip_names) 127 | assert_serializables_equal(instance, expected_instance, skip=skip_names) 128 | 129 | for name in skip_names: 130 | with pytest.raises(TraitError): 131 | getattr(instance, name) 132 | 133 | 134 | def 
test_write_example_yaml(tmpdir, expected_instance, skip_names): 135 | 136 | path = tmpdir.join("test.yaml").strpath 137 | ExampleClass.write_example_yaml(path, skip=skip_names) 138 | 139 | from_file = ExampleClass.from_yaml_file(path) 140 | assert_serializables_equal( 141 | from_file, 142 | expected_instance, 143 | skip=skip_names, 144 | ) 145 | 146 | for name in skip_names: 147 | with pytest.raises(TraitError): 148 | getattr(from_file, name) 149 | 150 | 151 | def test_nested_example(): 152 | 153 | class C(Serializable): 154 | point = Instance(Point) 155 | unicode_ = Unicode().tag(example='foo') 156 | 157 | class B(Serializable): 158 | value = Integer().tag(example=ord('b')) 159 | next_ = Instance(C) 160 | 161 | class A(Serializable): 162 | value = Integer().tag(example=ord('a')) 163 | next_ = Instance(B) 164 | 165 | expected = A( 166 | value=ord('a'), 167 | next_=B( 168 | value=ord('b'), 169 | next_=C( 170 | point=Point.example_instance(), 171 | unicode_='foo', 172 | ), 173 | ), 174 | ) 175 | 176 | assert_serializables_equal(expected, A.example_instance()) 177 | 178 | 179 | def test_readme_example(): 180 | 181 | class Point(Serializable): 182 | x = Integer().example(0) 183 | y = Integer().example(0) 184 | 185 | class Vector(Serializable): 186 | head = Instance(Point) 187 | tail = Instance(Point).example(Point(x=1, y=3)) 188 | 189 | example_yaml = Vector.example_instance().to_yaml() 190 | expected = { 191 | "head": {"x": 0, "y": 0}, 192 | "tail": {"x": 1, "y": 3}, 193 | } 194 | assert safe_load(example_yaml) == expected 195 | -------------------------------------------------------------------------------- /straitlets/tests/test_serializable.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | Tests for serializable.py. 
4 | """ 5 | from __future__ import unicode_literals 6 | 7 | import re 8 | from textwrap import dedent 9 | 10 | import pytest 11 | from traitlets import TraitError, Unicode as base_Unicode 12 | 13 | from straitlets.compat import unicode 14 | from straitlets.test_utils import ( 15 | assert_serializables_equal, 16 | check_attributes, 17 | multifixture, 18 | ) 19 | from ..serializable import ( 20 | MultipleTraitErrors, 21 | Serializable, 22 | StrictSerializable, 23 | ) 24 | from ..traits import ( 25 | Bool, 26 | Dict, 27 | Enum, 28 | Float, 29 | Instance, 30 | Integer, 31 | List, 32 | Set, 33 | Unicode, 34 | Tuple, 35 | ) 36 | 37 | 38 | not_ascii = 'unicodé' 39 | 40 | 41 | class Foo(Serializable): 42 | 43 | bool_ = Bool() 44 | float_ = Float() 45 | int_ = Integer() 46 | unicode_ = Unicode() 47 | enum = Enum(values=(1, 2, not_ascii)) 48 | 49 | dict_ = Dict() 50 | list_ = List() 51 | set_ = Set() 52 | tuple_ = Tuple() 53 | 54 | 55 | @multifixture 56 | def foo_kwargs(): 57 | for seed in (1, 2): 58 | list_elems = list(range(seed)) + list(map(unicode, range(seed))) 59 | yield { 60 | 'bool_': bool(seed), 61 | 'float_': float(seed), 62 | 'int_': int(seed), 63 | 'unicode_': unicode(seed), 64 | 'enum': seed, 65 | 66 | 'dict_': { 67 | 'a': list_elems, 68 | not_ascii: list(reversed(list_elems)) 69 | }, 70 | 'list_': list_elems, 71 | 'set_': set(list_elems), 72 | 'tuple_': tuple(list_elems), 73 | } 74 | 75 | 76 | @multifixture 77 | def skip_names(): 78 | yield () 79 | yield ('enum', 'unicode_', 'int_', 'float_', 'bool_') 80 | yield ('dict_', 'list_', 'set_', 'tuple_') 81 | yield Foo.class_trait_names() 82 | 83 | 84 | def test_construct_from_kwargs(foo_kwargs): 85 | instance = Foo(**foo_kwargs) 86 | check_attributes(instance, foo_kwargs) 87 | 88 | 89 | def test_roundtrip(foo_kwargs, roundtrip_func, skip_names): 90 | foo = Foo(**foo_kwargs) 91 | roundtripped = roundtrip_func(foo, skip=skip_names) 92 | assert isinstance(roundtripped, Foo) 93 | assert foo is not roundtripped 94 | 
assert_serializables_equal(roundtripped, foo, skip=skip_names) 95 | 96 | for name in skip_names: 97 | with pytest.raises(TraitError): 98 | getattr(roundtripped, name) 99 | 100 | 101 | class DynamicDefaults(Serializable): 102 | 103 | non_dynamic = Integer() 104 | 105 | d = Dict() 106 | # This is non-idiomatic usage, but it makes testing simpler. 107 | DEFAULT_D = {not_ascii: 1} 108 | 109 | def _d_default(self): 110 | return self.DEFAULT_D 111 | 112 | l = List() # noqa 113 | DEFAULT_L = [1, 2, not_ascii, 3] 114 | 115 | def _l_default(self): 116 | return self.DEFAULT_L 117 | 118 | 119 | @multifixture 120 | def non_dynamic_val(): 121 | yield 1 122 | yield 2 123 | 124 | 125 | @multifixture 126 | def d_val(): 127 | yield {'x': 1} 128 | yield None 129 | 130 | 131 | @multifixture 132 | def l_val(): 133 | yield [1, 2, not_ascii] 134 | yield None 135 | 136 | 137 | def test_dynamic_defaults(non_dynamic_val, d_val, l_val, roundtrip_func): 138 | expected = { 139 | 'non_dynamic': non_dynamic_val, 140 | 'd': d_val if d_val is not None else DynamicDefaults.DEFAULT_D, 141 | 'l': l_val if l_val is not None else DynamicDefaults.DEFAULT_L, 142 | } 143 | kwargs = {'non_dynamic': non_dynamic_val} 144 | if d_val is not None: 145 | kwargs['d'] = d_val 146 | if l_val is not None: 147 | kwargs['l'] = l_val 148 | 149 | instance = DynamicDefaults(**kwargs) 150 | check_attributes(instance, expected) 151 | check_attributes(roundtrip_func(instance), expected) 152 | 153 | # Do a check without forcing all the attributes via check_attributes. 
154 | check_attributes(roundtrip_func(DynamicDefaults(**kwargs)), expected) 155 | 156 | 157 | @pytest.fixture 158 | def foo_instance(): 159 | return Foo( 160 | bool_=True, 161 | float_=5.0, 162 | int_=2, 163 | enum=1, 164 | unicode_="foo", 165 | dict_={"foo": "foo"}, 166 | list_=["foo"], 167 | set_={"foo"}, 168 | tuple_=("foo",), 169 | ) 170 | 171 | 172 | @pytest.fixture 173 | def different_foo_instance(): 174 | return Foo( 175 | bool_=False, 176 | float_=4.0, 177 | int_=3, 178 | enum=2, 179 | unicode_=not_ascii, 180 | dict_={not_ascii: not_ascii}, 181 | list_=["not_foo", not_ascii, 3], 182 | set_={"not_foo", not_ascii}, 183 | tuple_=(), 184 | ) 185 | 186 | 187 | class Nested(Serializable): 188 | unicode_ = Unicode() 189 | dict_ = Dict() 190 | 191 | foo1 = Instance(Foo) 192 | foo2 = Instance(Foo) 193 | 194 | 195 | @multifixture 196 | def unicode_val(): 197 | yield "" 198 | yield "ascii" 199 | yield not_ascii 200 | 201 | 202 | @multifixture 203 | def dict_val(): 204 | yield {} 205 | yield {'foo': {'buzz': ['bar']}} 206 | 207 | 208 | def test_nested(unicode_val, 209 | dict_val, 210 | foo_instance, 211 | different_foo_instance, 212 | roundtrip_func): 213 | 214 | instance = Nested( 215 | unicode_=unicode_val, 216 | dict_=dict_val, 217 | foo1=foo_instance, 218 | foo2=different_foo_instance, 219 | ) 220 | 221 | check_attributes( 222 | instance, 223 | { 224 | "unicode_": unicode_val, 225 | "dict_": dict_val, 226 | } 227 | ) 228 | assert_serializables_equal(instance.foo1, foo_instance) 229 | assert_serializables_equal(instance.foo2, different_foo_instance) 230 | 231 | roundtripped = roundtrip_func(instance) 232 | assert_serializables_equal(instance, roundtripped) 233 | 234 | 235 | def test_double_nested(roundtrip_func): 236 | 237 | class Bottom(Serializable): 238 | x = List() 239 | y = Unicode() 240 | 241 | class Middle(Serializable): 242 | x = Integer() 243 | bottom = Instance(Bottom) 244 | 245 | class Top(Serializable): 246 | x = Unicode() 247 | y = Tuple() 248 | middle 
= Instance(Middle) 249 | 250 | top = Top( 251 | x="asdf", 252 | y=(1, 2), 253 | middle=Middle( 254 | x=3, 255 | bottom=Bottom( 256 | x=[1, 2], 257 | y="foo", 258 | ) 259 | ) 260 | ) 261 | 262 | assert_serializables_equal(roundtrip_func(top), top) 263 | 264 | 265 | def test_inheritance(roundtrip_func, foo_instance): 266 | 267 | class Parent(Serializable): 268 | a = Integer() 269 | 270 | def _a_default(self): 271 | return 3 272 | 273 | b = Unicode() 274 | 275 | check_attributes(Parent(b="b"), {"a": 3, "b": "b"}) 276 | 277 | class Child(Parent): 278 | x = Instance(Foo) 279 | y = Dict() 280 | 281 | def _a_default(self): 282 | return 4 283 | 284 | child = Child(b="b", x=foo_instance, y={}) 285 | check_attributes(child, {'a': 4, 'b': 'b', 'y': {}}) 286 | assert child.x is foo_instance 287 | 288 | assert_serializables_equal(roundtrip_func(child), child) 289 | 290 | 291 | def test_barf_on_unexpected_input(): 292 | 293 | class MyClass(Serializable): 294 | x = Integer() 295 | # This should be Instance(Foo). 296 | foo = Foo() 297 | 298 | with pytest.raises(TypeError) as e: 299 | MyClass(x=1, y=5) 300 | assert str(e.value) == ( 301 | "MyClass.__init__() got unexpected keyword arguments ('y',)." 302 | ) 303 | 304 | with pytest.raises(TypeError) as e: 305 | MyClass(x=1, foo=Foo()) 306 | assert str(e.value) == ( 307 | dedent( 308 | """ 309 | MyClass.__init__() got unexpected keyword argument 'foo'. 310 | MyClass (or a parent) has a class attribute with the same name. 311 | Did you mean to write `foo = Instance(Foo)`? 
312 | """ 313 | ) 314 | ) 315 | 316 | 317 | @pytest.fixture 318 | def foo_yaml(): 319 | return dedent( 320 | """ 321 | bool_: true 322 | float_: 1.0 323 | int_: 2 324 | unicode_: {not_ascii} 325 | enum: {not_ascii} 326 | dict_: 327 | a: 3 328 | b: 4 329 | c: 330 | - 5 331 | - 6 332 | list_: 333 | - 7 334 | - 8 335 | set_: 336 | - 9 337 | - 10 338 | tuple_: 339 | - 11 340 | - 12 341 | """ 342 | ).format(not_ascii=not_ascii) 343 | 344 | 345 | @pytest.fixture 346 | def foo_yaml_expected_result(): 347 | return Foo( 348 | bool_=True, 349 | float_=1.0, 350 | int_=2, 351 | unicode_=not_ascii, 352 | enum=not_ascii, 353 | dict_=dict( 354 | a=3, 355 | b=4, 356 | c=[5, 6], 357 | ), 358 | list_=[7, 8], 359 | set_={9, 10}, 360 | tuple_=(11, 12), 361 | ) 362 | 363 | 364 | def test_from_yaml(foo_yaml, foo_yaml_expected_result): 365 | assert_serializables_equal( 366 | Foo.from_yaml(foo_yaml), 367 | foo_yaml_expected_result 368 | ) 369 | 370 | 371 | def test_from_yaml_file(tmpdir, foo_yaml, foo_yaml_expected_result): 372 | fileobj = tmpdir.join("test.yaml") 373 | fileobj.write_text(foo_yaml, encoding='utf-8') 374 | 375 | assert_serializables_equal( 376 | Foo.from_yaml_file(fileobj.strpath), 377 | foo_yaml_expected_result, 378 | ) 379 | 380 | 381 | def test_reject_non_serializable_traitlet(): 382 | 383 | with pytest.raises(TypeError) as e: 384 | class F(Serializable): 385 | u = base_Unicode() 386 | assert str(e.value) == "Got non-serializable trait u=Unicode" 387 | 388 | 389 | def test_reject_unknown_object(): 390 | 391 | class SomeRandomClass(object): 392 | pass 393 | 394 | with pytest.raises(TypeError) as e: 395 | class F(Serializable): 396 | i = Instance(SomeRandomClass) 397 | 398 | assert str(e.value) == ( 399 | "Can't convert instances of SomeRandomClass to primitives." 
400 | ) 401 | 402 | 403 | def test_allow_none(foo_instance, different_foo_instance, roundtrip_func): 404 | 405 | class MyClass(Serializable): 406 | required = Instance(Foo) 407 | explicit_optional = Instance(Foo, allow_none=True) 408 | implicit_optional = Instance(Foo, allow_none=True, default_value=None) 409 | 410 | with pytest.raises(TraitError): 411 | MyClass().required 412 | 413 | with pytest.raises(TraitError): 414 | # This should still raise because there's no default value. 415 | MyClass().explicit_optional 416 | 417 | assert MyClass().implicit_optional is None 418 | 419 | without_optional = MyClass(required=foo_instance, explicit_optional=None) 420 | assert without_optional.required is foo_instance 421 | assert without_optional.explicit_optional is None 422 | assert_serializables_equal( 423 | without_optional, 424 | roundtrip_func(without_optional), 425 | ) 426 | 427 | with_optional = MyClass( 428 | required=foo_instance, 429 | explicit_optional=different_foo_instance, 430 | ) 431 | assert with_optional.required is foo_instance 432 | assert with_optional.explicit_optional is different_foo_instance 433 | assert_serializables_equal( 434 | with_optional, 435 | roundtrip_func(with_optional), 436 | ) 437 | 438 | 439 | def test_lazy_attribute_access(): 440 | 441 | class MyClass(Serializable): 442 | x = Integer() 443 | 444 | m = MyClass() 445 | with pytest.raises(TraitError): 446 | m.x 447 | 448 | assert MyClass(x=1).x == 1 449 | 450 | 451 | def test_validate_all_attributes(): 452 | 453 | class MyClass(Serializable): 454 | x = Integer() 455 | 456 | m = MyClass() 457 | with pytest.raises(TraitError) as validate_err: 458 | m.validate_all_attributes() 459 | 460 | with pytest.raises(TraitError) as touch_err: 461 | m.x 462 | 463 | assert str(validate_err.value) == str(touch_err.value) 464 | 465 | 466 | def test_strict_serializable(): 467 | 468 | class Strict(StrictSerializable): 469 | x = Integer() 470 | 471 | with pytest.raises(TraitError): 472 | Strict() 473 | 474 | 
with pytest.raises(TraitError): 475 | Strict.from_dict({}) 476 | 477 | with pytest.raises(TypeError): 478 | Strict(x=1, y=2) 479 | 480 | assert Strict(x=1).x == 1 481 | assert Strict(x=1).to_dict() == {"x": 1} 482 | 483 | 484 | class MultipleErrorsStrict(StrictSerializable): 485 | """This is not defined in the body of a function because it breaks my regex! 486 | """ 487 | x = Integer() 488 | y = Integer() 489 | 490 | 491 | def test_multiple_trait_errors(): 492 | 493 | with pytest.raises(MultipleTraitErrors) as e: 494 | MultipleErrorsStrict() 495 | 496 | assert re.match( 497 | dedent( 498 | """\ 499 | ^ 500 | x: No default value found for x trait of <.+?> 501 | -------------------- 502 | y: No default value found for y trait of <.+?>$""" 503 | ), 504 | str(e.value) 505 | ) 506 | -------------------------------------------------------------------------------- /straitlets/tests/test_test_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the test utils. 
3 | """ 4 | import pytest 5 | 6 | from straitlets import Serializable, Integer 7 | from straitlets.test_utils import assert_serializables_equal 8 | 9 | 10 | def test_assert_serializables_equal(): 11 | 12 | class Foo(Serializable): 13 | x = Integer() 14 | y = Integer() 15 | 16 | class Bar(Serializable): 17 | x = Integer() 18 | y = Integer() 19 | 20 | assert_serializables_equal(Foo(x=1, y=1), Foo(x=1, y=1)) 21 | 22 | with pytest.raises(AssertionError): 23 | assert_serializables_equal(Foo(x=1, y=1), Bar(x=1, y=1)) 24 | 25 | with pytest.raises(AssertionError): 26 | assert_serializables_equal(Foo(x=1, y=1), Foo(x=1, y=2)) 27 | with pytest.raises(AssertionError): 28 | assert_serializables_equal( 29 | Foo(x=1, y=1), 30 | Foo(x=1, y=2), 31 | skip=('x',), 32 | ) 33 | 34 | assert_serializables_equal(Foo(x=1), Foo(x=1), skip=('y',)) 35 | assert_serializables_equal(Foo(y=1), Foo(y=1), skip=('x',)) 36 | -------------------------------------------------------------------------------- /straitlets/tests/test_to_primitive.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from ..to_primitive import to_primitive 3 | 4 | 5 | def test_base_case(): 6 | class SomeRandomClass(object): 7 | pass 8 | 9 | with pytest.raises(TypeError) as e: 10 | to_primitive(SomeRandomClass()) 11 | 12 | assert str(e.value) == ( 13 | "Don't know how to convert instances of SomeRandomClass to primitives." 
14 | ) 15 | -------------------------------------------------------------------------------- /straitlets/tests/test_traits.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import pytest 4 | import traitlets as tr 5 | 6 | from ..serializable import Serializable 7 | from ..test_utils import assert_serializables_equal 8 | from ..traits import Enum, LengthBoundedUnicode 9 | 10 | 11 | def test_reject_unknown_enum_value(): 12 | 13 | class SomeRandomClass(object): 14 | 15 | def __init__(self, x): 16 | self.x = x 17 | 18 | def __str__(self): 19 | return "SomeRandomClass(x=%s)" % self.x 20 | 21 | with pytest.raises(TypeError) as e: 22 | Enum(values=[SomeRandomClass(3)]) 23 | 24 | assert str(e.value) == ( 25 | "Can't convert Enum value SomeRandomClass(x=3) to a primitive." 26 | ) 27 | 28 | 29 | def test_length_bounded_unicode(): 30 | 31 | class F(Serializable): 32 | u = LengthBoundedUnicode(minlen=5, maxlen=10) 33 | 34 | for i in range(5, 11): 35 | F(u=u'a' * i) 36 | 37 | with pytest.raises(tr.TraitError): 38 | F(u=u'a' * 4) 39 | 40 | with pytest.raises(tr.TraitError): 41 | F(u=u'a' * 11) 42 | 43 | 44 | @pytest.mark.skipif( 45 | sys.version_info.major < 3, 46 | reason='Path requires Python 3', 47 | ) 48 | def test_path(roundtrip_func): # pragma: no cover 49 | # defer imports because these only work in Python 3 50 | import pathlib 51 | from ..py3 import Path 52 | 53 | class S(Serializable): 54 | p = Path() 55 | 56 | s = S(p='/etc') 57 | 58 | # ensure that the object stored at ``p`` is a ``pathlib.Path`` 59 | assert isinstance(s.p, pathlib.Path) 60 | 61 | assert_serializables_equal(s, roundtrip_func(s)) 62 | -------------------------------------------------------------------------------- /straitlets/tests/test_utilities.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from straitlets.utils import merge 4 | 5 | 6 | def test_merge(): 7 | with 
pytest.raises(ValueError): 8 | merge() 9 | 10 | d1 = {'a': 'b'} 11 | assert merge(d1) == d1 12 | 13 | d2 = {'c': 'd'} 14 | assert merge(d1, d2) == {'a': 'b', 'c': 'd'} 15 | 16 | d3 = {'a': 'z', 'e': 'f'} 17 | assert merge(d1, d2, d3) == {'a': 'z', 'c': 'd', 'e': 'f'} 18 | -------------------------------------------------------------------------------- /straitlets/to_primitive.py: -------------------------------------------------------------------------------- 1 | from six import iterkeys, itervalues 2 | from six.moves import zip, map 3 | 4 | from straitlets.compat import long, unicode 5 | from straitlets.dispatch import singledispatch 6 | 7 | 8 | @singledispatch 9 | def to_primitive(obj): 10 | raise TypeError( 11 | "Don't know how to convert instances of %s to primitives." % ( 12 | type(obj).__name__ 13 | ) 14 | ) 15 | 16 | 17 | _base_handler = to_primitive.dispatch(object) 18 | 19 | 20 | def can_convert_to_primitive(type_): 21 | """ 22 | Check whether or not we have a to_primitive handler for type_. 23 | """ 24 | return to_primitive.dispatch(type_) is not _base_handler 25 | 26 | 27 | @to_primitive.register(int) 28 | @to_primitive.register(long) # Redundant in PY3, but that's fine.
29 | @to_primitive.register(float) 30 | @to_primitive.register(unicode) 31 | @to_primitive.register(bytes) 32 | @to_primitive.register(type(None)) 33 | def _atom_to_primitive(a): 34 | return a 35 | 36 | 37 | @to_primitive.register(set) 38 | @to_primitive.register(list) 39 | @to_primitive.register(tuple) 40 | @to_primitive.register(frozenset) 41 | def _sequence_to_primitive(s): 42 | return list(map(to_primitive, s)) 43 | 44 | 45 | @to_primitive.register(dict) 46 | def _dict_to_primitive(d): 47 | return dict( 48 | zip( 49 | map(to_primitive, iterkeys(d)), 50 | map(to_primitive, itervalues(d)), 51 | ) 52 | ) 53 | -------------------------------------------------------------------------------- /straitlets/traits.py: -------------------------------------------------------------------------------- 1 | """ 2 | Enhanced versions of IPython's traitlets. 3 | 4 | Adds the following additional behavior: 5 | 6 | - Strict construction/validation of config attributes. 7 | - Serialization to/from dictionaries containing only primitives. 8 | - More strict handling of default values than traitlets' built-in behavior. 9 | """ 10 | from contextlib import contextmanager 11 | 12 | import traitlets as tr 13 | 14 | from . import compat 15 | from .to_primitive import to_primitive, can_convert_to_primitive 16 | 17 | 18 | @contextmanager 19 | def cross_validation_lock(obj): 20 | """ 21 | A contextmanager for holding Traited object's cross-validators. 22 | 23 | This should be used in circumstances where you want to call _validate, but 24 | don't want to fire cross-validators. 25 | """ 26 | # TODO: Replace this with usage of public API when 27 | # https://github.com/ipython/traitlets/pull/166 lands upstream. 
28 | orig = getattr(obj, '_cross_validation_lock', False) 29 | try: 30 | obj._cross_validation_lock = True 31 | yield 32 | finally: 33 | obj._cross_validation_lock = orig 34 | 35 | 36 | class SerializableTrait(tr.TraitType): 37 | 38 | # Override IPython's default values with Undefined so that default values 39 | # must be passed explicitly to trait instances. 40 | default_value = tr.Undefined 41 | 42 | def example(self, value): 43 | return self.tag(example=value) 44 | 45 | def instance_init(self, obj): 46 | super(SerializableTrait, self).instance_init(obj) 47 | # If we were tagged with an example, make sure it's actually a valid 48 | # example. 49 | example = self._static_example_value() 50 | if example is not tr.Undefined: 51 | with cross_validation_lock(obj): 52 | self._validate(obj, example) 53 | 54 | def _static_example_value(self): 55 | return self.metadata.get('example', self.default_value) 56 | 57 | example_value = property(_static_example_value) 58 | 59 | 60 | class Integer(SerializableTrait, tr.Integer): 61 | pass 62 | 63 | 64 | class Float(SerializableTrait, tr.Float): 65 | pass 66 | 67 | 68 | class Unicode(SerializableTrait, tr.Unicode): 69 | pass 70 | 71 | 72 | class LengthBoundedUnicode(Unicode): 73 | 74 | def __init__(self, minlen, maxlen, *args, **kwargs): 75 | self.minlen = minlen 76 | self.maxlen = maxlen 77 | super(LengthBoundedUnicode, self).__init__(*args, **kwargs) 78 | 79 | def validate(self, obj, value): 80 | super_retval = super(LengthBoundedUnicode, self).validate(obj, value) 81 | length = len(value) 82 | if length < self.minlen: 83 | raise tr.TraitError("len(%r) < minlen=%d" % (value, self.minlen)) 84 | elif length > self.maxlen: 85 | raise tr.TraitError("len(%r) > maxlen=%d" % (value, self.maxlen)) 86 | return super_retval 87 | 88 | 89 | class Bool(SerializableTrait, tr.Bool): 90 | pass 91 | 92 | 93 | # Different traitlets container types use different values for `default_value`. 
94 | # Figure out what to use by inspecting the signatures of __init__. 95 | def _get_default_value_sentinel(t): 96 | # traitlets Tuple does a kwargs.pop rather than specifying the value in its 97 | # signature. 98 | if t is tr.Tuple: 99 | return tr.Undefined 100 | argspec = compat.argspec(t.__init__) 101 | for name, value in zip(reversed(argspec.args), reversed(argspec.defaults)): 102 | if name == 'default_value': 103 | return value 104 | 105 | raise TypeError( # pragma: nocover 106 | "Can't find default value sentinel for type %s" % t 107 | ) 108 | 109 | 110 | _NOTPASSED = object() 111 | _TRAITLETS_CONTAINER_TYPES = frozenset([tr.List, tr.Set, tr.Dict, tr.Tuple]) 112 | _DEFAULT_VALUE_SENTINELS = { 113 | t: _get_default_value_sentinel(t) for t in _TRAITLETS_CONTAINER_TYPES 114 | } 115 | 116 | 117 | class _ContainerMixin(object): 118 | 119 | def __init__(self, default_value=_NOTPASSED, **kwargs): 120 | # traitlets' Container base class converts default_value into args and 121 | # kwargs to pass to a factory type and sets those values to (), {} when 122 | # default is None or Undefined. They do this so that not every List 123 | # trait shares the same list object as a default value, but each 124 | # subclass mucks with the default value in slightly different ways, and 125 | # all of them interpret 'default_value not passed' as 'construct an 126 | # empty instance', which we don't think is a sane choice of default. 127 | # 128 | # Rather than trying to intercept all the different ways that traitlets 129 | # overrides default values, we just mark whether we've seen an explicit 130 | # default value in our constructor, and our make_dynamic_default 131 | # function yields Undefined if this wasn't specified. 132 | self._have_explicit_default_value = (default_value is not _NOTPASSED) 133 | if not self._have_explicit_default_value: 134 | # Different traitlets use different values in their __init__ 135 | # signatures to signify 'not passed'. 
Find the correct value to 136 | # forward by inspecting our method resolution order. 137 | for type_ in type(self).mro(): 138 | if type_ in _TRAITLETS_CONTAINER_TYPES: 139 | default_value = _DEFAULT_VALUE_SENTINELS[type_] 140 | break 141 | else: # pragma: nocover 142 | raise tr.TraitError( 143 | "_ContainerMixin applied to unknown type %s" % type(self) 144 | ) 145 | 146 | super(_ContainerMixin, self).__init__( 147 | default_value=default_value, 148 | **kwargs 149 | ) 150 | 151 | def validate(self, obj, value): 152 | # Ensure that the value is coercible to a primitive. 153 | to_primitive(value) 154 | return super(_ContainerMixin, self).validate(obj, value) 155 | 156 | def make_dynamic_default(self): 157 | if not self._have_explicit_default_value: 158 | return None 159 | return super(_ContainerMixin, self).make_dynamic_default() 160 | 161 | 162 | class Set(SerializableTrait, _ContainerMixin, tr.Set): 163 | pass 164 | 165 | 166 | class List(SerializableTrait, _ContainerMixin, tr.List): 167 | pass 168 | 169 | 170 | class Dict(SerializableTrait, _ContainerMixin, tr.Dict): 171 | pass 172 | 173 | 174 | class Tuple(SerializableTrait, _ContainerMixin, tr.Tuple): 175 | pass 176 | 177 | 178 | class Enum(SerializableTrait, tr.Enum): 179 | 180 | def __init__(self, *args, **kwargs): 181 | super(Enum, self).__init__(*args, **kwargs) 182 | for value in self.values: 183 | if not can_convert_to_primitive(type(value)): 184 | raise TypeError( 185 | "Can't convert Enum value %s to a primitive." % value 186 | ) 187 | 188 | 189 | class Instance(SerializableTrait, tr.Instance): 190 | 191 | def __init__(self, *args, **kwargs): 192 | super(Instance, self).__init__(*args, **kwargs) 193 | self._resolve_classes() 194 | if not can_convert_to_primitive(self.klass): 195 | raise TypeError( 196 | "Can't convert instances of %s to primitives." 
% ( 197 | self.klass.__name__, 198 | ) 199 | ) 200 | 201 | def validate(self, obj, value): 202 | from .serializable import Serializable 203 | if issubclass(self.klass, Serializable) and isinstance(value, dict): 204 | value = self.klass.from_dict(value) 205 | return super(Instance, self).validate(obj, value) 206 | 207 | @property 208 | def example_value(self): 209 | """ 210 | If neither an example nor a default value was given and `klass` is a 211 | Serializable subclass, fall back to its `example_instance()` method. 212 | """ 213 | from .serializable import Serializable 214 | inst = self._static_example_value() 215 | if inst is tr.Undefined and issubclass(self.klass, Serializable): 216 | return self.klass.example_instance() 217 | return inst 218 | 219 | # Override the base class. 220 | make_dynamic_default = None 221 | -------------------------------------------------------------------------------- /straitlets/utils.py: -------------------------------------------------------------------------------- 1 | 2 | def merge(*ds): 3 | """ 4 | Merge together a sequence of dictionaries. 5 | 6 | Later entries overwrite values from earlier entries.
7 | 8 | >>> merge({'a': 'b', 'c': 'd'}, {'a': 'z', 'e': 'f'}) 9 | {'a': 'z', 'c': 'd', 'e': 'f'} 10 | """ 11 | if not ds: 12 | raise ValueError("Must provide at least one dict to merge().") 13 | out = {} 14 | for d in ds: 15 | out.update(d) 16 | return out 17 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist=py{27,34,35,36,37} 3 | skip_missing_interpreters=True 4 | 5 | [testenv] 6 | commands= 7 | pip install -e .[test] 8 | py.test --cov-fail-under 100 9 | 10 | [pytest] 11 | addopts = --pep8 --cov straitlets --cov-report term-missing --cov-report html 12 | testpaths = straitlets 13 | filterwarnings = 14 | # PyYAML==3.13 15 | ignore:Using or importing the ABCs:DeprecationWarning:yaml.constructor:126 16 | --------------------------------------------------------------------------------