├── .buildkite └── sdist.yml ├── .github └── workflows │ ├── cibuildwheel.yml │ ├── publish_pypi.yml │ └── tests.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── bin └── push-tag.sh ├── pyproject.toml ├── requirements.txt ├── setup.cfg ├── setup.py └── srsly ├── __init__.py ├── _json_api.py ├── _msgpack_api.py ├── _pickle_api.py ├── _yaml_api.py ├── about.py ├── cloudpickle ├── __init__.py ├── cloudpickle.py ├── cloudpickle_fast.py └── compat.py ├── msgpack ├── __init__.py ├── _epoch.pyx ├── _msgpack_numpy.py ├── _packer.pyx ├── _unpacker.pyx ├── _version.py ├── exceptions.py ├── ext.py ├── fallback.py ├── pack.h ├── pack_template.h ├── sysdep.h ├── unpack.h ├── unpack_container_header.h ├── unpack_define.h ├── unpack_template.h └── util.py ├── ruamel_yaml ├── LICENSE ├── __init__.py ├── anchor.py ├── comments.py ├── compat.py ├── composer.py ├── configobjwalker.py ├── constructor.py ├── cyaml.py ├── dumper.py ├── emitter.py ├── error.py ├── events.py ├── loader.py ├── main.py ├── nodes.py ├── parser.py ├── py.typed ├── reader.py ├── representer.py ├── resolver.py ├── scalarbool.py ├── scalarfloat.py ├── scalarint.py ├── scalarstring.py ├── scanner.py ├── serializer.py ├── timestamp.py ├── tokens.py └── util.py ├── tests ├── __init__.py ├── cloudpickle │ ├── __init__.py │ ├── cloudpickle_file_test.py │ ├── cloudpickle_test.py │ ├── mock_local_folder │ │ ├── mod.py │ │ └── subfolder │ │ │ └── submod.py │ └── testutils.py ├── msgpack │ ├── __init__.py │ ├── test_buffer.py │ ├── test_case.py │ ├── test_except.py │ ├── test_extension.py │ ├── test_format.py │ ├── test_limits.py │ ├── test_memoryview.py │ ├── test_newspec.py │ ├── test_numpy.py │ ├── test_pack.py │ ├── test_read_size.py │ ├── test_seq.py │ ├── test_sequnpack.py │ ├── test_stricttype.py │ ├── test_subtype.py │ └── test_unpack.py ├── ruamel_yaml │ ├── __init__.py │ ├── roundtrip.py │ ├── test_a_dedent.py │ ├── test_add_xxx.py │ ├── test_anchor.py │ ├── test_api_change.py │ ├── 
test_appliance.py │ ├── test_class_register.py │ ├── test_collections.py │ ├── test_comment_manipulation.py │ ├── test_comments.py │ ├── test_contextmanager.py │ ├── test_copy.py │ ├── test_datetime.py │ ├── test_deprecation.py │ ├── test_documents.py │ ├── test_fail.py │ ├── test_float.py │ ├── test_flowsequencekey.py │ ├── test_indentation.py │ ├── test_int.py │ ├── test_issues.py │ ├── test_json_numbers.py │ ├── test_line_col.py │ ├── test_literal.py │ ├── test_none.py │ ├── test_numpy.py │ ├── test_program_config.py │ ├── test_spec_examples.py │ ├── test_string.py │ ├── test_tag.py │ ├── test_version.py │ ├── test_yamlfile.py │ ├── test_yamlobject.py │ ├── test_z_check_debug_leftovers.py │ └── test_z_data.py ├── test_json_api.py ├── test_msgpack_api.py ├── test_pickle_api.py ├── test_yaml_api.py ├── ujson │ ├── 334-reproducer.json │ ├── __init__.py │ └── test_ujson.py └── util.py ├── ujson ├── JSONtoObj.c ├── __init__.py ├── lib │ ├── dconv_wrapper.cc │ ├── ultrajson.h │ ├── ultrajsondec.c │ └── ultrajsonenc.c ├── objToJSON.c ├── py_defines.h ├── ujson.c └── version.h └── util.py /.buildkite/sdist.yml: -------------------------------------------------------------------------------- 1 | steps: 2 | - 3 | command: "./bin/build-sdist.sh" 4 | label: ":dizzy: :python:" 5 | artifact_paths: "dist/*.tar.gz" 6 | -------------------------------------------------------------------------------- /.github/workflows/cibuildwheel.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: 4 | push: 5 | tags: 6 | # ytf did they invent their own syntax that's almost regex? 
7 | # ** matches 'zero or more of any character' 8 | - 'release-v[0-9]+.[0-9]+.[0-9]+**' 9 | - 'prerelease-v[0-9]+.[0-9]+.[0-9]+**' 10 | jobs: 11 | build_wheels: 12 | name: Build wheels on ${{ matrix.os }} 13 | runs-on: ${{ matrix.os }} 14 | strategy: 15 | matrix: 16 | # macos-13 is an intel runner, macos-14 is apple silicon 17 | os: [ubuntu-latest, windows-latest, macos-13, macos-14, ubuntu-24.04-arm] 18 | 19 | steps: 20 | - uses: actions/checkout@v4 21 | - name: Build wheels 22 | uses: pypa/cibuildwheel@v2.21.3 23 | env: 24 | CIBW_SOME_OPTION: value 25 | with: 26 | package-dir: . 27 | output-dir: wheelhouse 28 | config-file: "{package}/pyproject.toml" 29 | - uses: actions/upload-artifact@v4 30 | with: 31 | name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }} 32 | path: ./wheelhouse/*.whl 33 | 34 | build_sdist: 35 | name: Build source distribution 36 | runs-on: ubuntu-latest 37 | steps: 38 | - uses: actions/checkout@v4 39 | 40 | - name: Build sdist 41 | run: pipx run build --sdist 42 | - uses: actions/upload-artifact@v4 43 | with: 44 | name: cibw-sdist 45 | path: dist/*.tar.gz 46 | create_release: 47 | needs: [build_wheels, build_sdist] 48 | runs-on: ubuntu-latest 49 | permissions: 50 | contents: write 51 | checks: write 52 | actions: read 53 | issues: read 54 | packages: write 55 | pull-requests: read 56 | repository-projects: read 57 | statuses: read 58 | steps: 59 | - name: Get the tag name and determine if it's a prerelease 60 | id: get_tag_info 61 | run: | 62 | FULL_TAG=${GITHUB_REF#refs/tags/} 63 | if [[ $FULL_TAG == release-* ]]; then 64 | TAG_NAME=${FULL_TAG#release-} 65 | IS_PRERELEASE=false 66 | elif [[ $FULL_TAG == prerelease-* ]]; then 67 | TAG_NAME=${FULL_TAG#prerelease-} 68 | IS_PRERELEASE=true 69 | else 70 | echo "Tag does not match expected patterns" >&2 71 | exit 1 72 | fi 73 | echo "FULL_TAG=$FULL_TAG" >> $GITHUB_ENV 74 | echo "TAG_NAME=$TAG_NAME" >> $GITHUB_ENV 75 | echo "IS_PRERELEASE=$IS_PRERELEASE" >> $GITHUB_ENV 76 | - uses: 
actions/download-artifact@v4 77 | with: 78 | # unpacks all CIBW artifacts into dist/ 79 | pattern: cibw-* 80 | path: dist 81 | merge-multiple: true 82 | - name: Create Draft Release 83 | id: create_release 84 | uses: softprops/action-gh-release@v2 85 | if: startsWith(github.ref, 'refs/tags/') 86 | env: 87 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 88 | with: 89 | name: ${{ env.TAG_NAME }} 90 | draft: true 91 | prerelease: ${{ env.IS_PRERELEASE }} 92 | files: "./dist/*" 93 | -------------------------------------------------------------------------------- /.github/workflows/publish_pypi.yml: -------------------------------------------------------------------------------- 1 | # The cibuildwheel action triggers on creation of a release, this 2 | # triggers on publication. 3 | # The expected workflow is to create a draft release and let the wheels 4 | # upload, and then hit 'publish', which uploads to PyPi. 5 | 6 | on: 7 | release: 8 | types: 9 | - published 10 | 11 | jobs: 12 | upload_pypi: 13 | runs-on: ubuntu-latest 14 | environment: 15 | name: pypi 16 | url: https://pypi.org/p/srsly 17 | permissions: 18 | id-token: write 19 | contents: read 20 | if: github.event_name == 'release' && github.event.action == 'published' 21 | # or, alternatively, upload to PyPI on every tag starting with 'v' (remove on: release above to use this) 22 | # if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') 23 | steps: 24 | - uses: robinraju/release-downloader@v1 25 | with: 26 | tag: ${{ github.event.release.tag_name }} 27 | fileName: '*' 28 | out-file-path: 'dist' 29 | - uses: pypa/gh-action-pypi-publish@release/v1 30 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | tags-ignore: 6 | - '**' 7 | paths-ignore: 8 | - "*.md" 9 | - ".github/cibuildwheel.yml" 10 | - 
".github/publish_pypi.yml" 11 | pull_request: 12 | types: [opened, synchronize, reopened, edited] 13 | paths-ignore: 14 | - "*.md" 15 | - ".github/cibuildwheel.yml" 16 | - ".github/publish_pypi.yml" 17 | env: 18 | MODULE_NAME: 'srsly' 19 | RUN_MYPY: 'false' 20 | 21 | jobs: 22 | tests: 23 | name: Test 24 | if: github.repository_owner == 'explosion' 25 | strategy: 26 | fail-fast: false 27 | matrix: 28 | os: [ubuntu-latest, windows-latest] 29 | python_version: ["3.9", "3.10", "3.11", "3.12"] 30 | runs-on: ${{ matrix.os }} 31 | 32 | steps: 33 | - name: Check out repo 34 | uses: actions/checkout@v3 35 | - name: Configure Python version 36 | uses: actions/setup-python@v4 37 | with: 38 | python-version: ${{ matrix.python_version }} 39 | architecture: x64 40 | 41 | - name: Build sdist 42 | run: | 43 | python -m pip install -U build pip setuptools 44 | python -m pip install -U -r requirements.txt 45 | python -m build --sdist 46 | 47 | - name: Run mypy 48 | shell: bash 49 | if: ${{ env.RUN_MYPY == 'true' }} 50 | run: | 51 | python -m mypy $MODULE_NAME 52 | 53 | - name: Delete source directory 54 | shell: bash 55 | run: | 56 | rm -rf $MODULE_NAME 57 | 58 | - name: Uninstall all packages 59 | run: | 60 | python -m pip freeze > installed.txt 61 | python -m pip uninstall -y -r installed.txt 62 | 63 | - name: Install from sdist 64 | shell: bash 65 | run: | 66 | SDIST=$(python -c "import os;print(os.listdir('./dist')[-1])" 2>&1) 67 | python -m pip install dist/$SDIST 68 | 69 | - name: Test import 70 | shell: bash 71 | run: | 72 | python -c "import $MODULE_NAME" -Werror 73 | 74 | - name: Install test requirements 75 | run: | 76 | python -m pip install -U -r requirements.txt 77 | 78 | - name: Run tests 79 | shell: bash 80 | run: | 81 | python -m pytest --pyargs $MODULE_NAME -Werror 82 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .env/ 2 | .env* 3 | 
.vscode/ 4 | cythonize.json 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | 111 | # Cython intermediate files 112 | *.cpp 113 | 114 | # Vim files 115 | *.sw* 116 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (C) 2018 ExplosionAI UG (haftungsbeschränkt) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include srsly *.h *.pyx *.pxd *.cc *.c *.cpp *.json 2 | include LICENSE 3 | include README.md 4 | -------------------------------------------------------------------------------- /bin/push-tag.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | # Insist repository is clean 6 | git diff-index --quiet HEAD 7 | 8 | git checkout $1 9 | git pull origin $1 10 | 11 | version=$(grep "__version__ = " srsly/about.py) 12 | version=${version/__version__ = } 13 | version=${version/\'/} 14 | version=${version/\'/} 15 | version=${version/\"/} 16 | version=${version/\"/} 17 | git tag "v$version" 18 | git push origin "v$version" 19 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools", 4 | "cython>=0.25", 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | 8 | [tool.cibuildwheel] 9 | build = "*" 10 | skip = "pp* cp36* cp37* cp38*" 11 | test-skip = "" 12 | free-threaded-support = false 13 | 14 | archs = ["native"] 15 | 16 | build-frontend = "default" 17 | config-settings = {} 18 | dependency-versions = "pinned" 19 | environment = {} 20 | environment-pass = [] 21 | build-verbosity = 0 22 | 23 | before-all = "" 24 | before-build = "" 25 | repair-wheel-command = "" 26 | 27 | test-command = "" 28 | before-test = "" 29 | test-requires = [] 30 | test-extras = [] 31 | 32 | container-engine = "docker" 33 | 34 | manylinux-x86_64-image = "manylinux2014" 35 | manylinux-i686-image = "manylinux2014" 36 | manylinux-aarch64-image = "manylinux2014" 37 | manylinux-ppc64le-image = "manylinux2014" 38 | manylinux-s390x-image = "manylinux2014" 39 | 
manylinux-pypy_x86_64-image = "manylinux2014" 40 | manylinux-pypy_i686-image = "manylinux2014" 41 | manylinux-pypy_aarch64-image = "manylinux2014" 42 | 43 | musllinux-x86_64-image = "musllinux_1_2" 44 | musllinux-i686-image = "musllinux_1_2" 45 | musllinux-aarch64-image = "musllinux_1_2" 46 | musllinux-ppc64le-image = "musllinux_1_2" 47 | musllinux-s390x-image = "musllinux_1_2" 48 | 49 | 50 | [tool.cibuildwheel.linux] 51 | repair-wheel-command = "auditwheel repair -w {dest_dir} {wheel}" 52 | 53 | [tool.cibuildwheel.macos] 54 | repair-wheel-command = "delocate-wheel --require-archs {delocate_archs} -w {dest_dir} -v {wheel}" 55 | 56 | [tool.cibuildwheel.windows] 57 | 58 | [tool.cibuildwheel.pyodide] 59 | 60 | 61 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | catalogue>=2.0.3,<2.1.0 2 | # Development requirements 3 | cython>=0.29.1 4 | pytest>=4.6.5 5 | pytest-timeout>=1.3.3 6 | mock>=2.0.0,<3.0.0 7 | numpy>=1.15.0 8 | psutil 9 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description = Modern high-performance serialization utilities for Python 3 | url = https://github.com/explosion/srsly 4 | author = Explosion 5 | author_email = contact@explosion.ai 6 | license = MIT 7 | long_description = file: README.md 8 | long_description_content_type = text/markdown 9 | classifiers = 10 | Development Status :: 5 - Production/Stable 11 | Environment :: Console 12 | Intended Audience :: Developers 13 | Intended Audience :: Science/Research 14 | License :: OSI Approved :: MIT License 15 | Operating System :: POSIX :: Linux 16 | Operating System :: MacOS :: MacOS X 17 | Operating System :: Microsoft :: Windows 18 | Programming Language :: Cython 19 | Programming Language :: Python :: 3 20 | 
Programming Language :: Python :: 3.9 21 | Programming Language :: Python :: 3.10 22 | Programming Language :: Python :: 3.11 23 | Programming Language :: Python :: 3.12 24 | Programming Language :: Python :: 3.13 25 | Topic :: Scientific/Engineering 26 | 27 | [options] 28 | zip_safe = true 29 | include_package_data = true 30 | python_requires = >=3.9,<3.14 31 | setup_requires = 32 | cython>=0.29.1 33 | install_requires = 34 | catalogue>=2.0.3,<2.1.0 35 | 36 | [options.entry_points] 37 | # If spaCy is installed in the same environment as srsly, it will automatically 38 | # have these readers available 39 | spacy_readers = 40 | srsly.read_json.v1 = srsly:read_json 41 | srsly.read_jsonl.v1 = srsly:read_jsonl 42 | srsly.read_yaml.v1 = srsly:read_yaml 43 | srsly.read_msgpack.v1 = srsly:read_msgpack 44 | 45 | [bdist_wheel] 46 | universal = false 47 | 48 | [sdist] 49 | formats = gztar 50 | 51 | [flake8] 52 | ignore = E203, E266, E501, E731, W503, E741 53 | max-line-length = 80 54 | select = B,C,E,F,W,T4,B9 55 | exclude = 56 | srsly/__init__.py 57 | srsly/msgpack/__init__.py 58 | srsly/cloudpickle/__init__.py 59 | 60 | [mypy] 61 | ignore_missing_imports = True 62 | 63 | [mypy-srsly.cloudpickle.*] 64 | ignore_errors=True 65 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | from setuptools.command.build_ext import build_ext 4 | from sysconfig import get_path 5 | from setuptools import Extension, setup, find_packages 6 | from pathlib import Path 7 | from Cython.Build import cythonize 8 | from Cython.Compiler import Options 9 | import contextlib 10 | import os 11 | 12 | 13 | # Preserve `__doc__` on functions and classes 14 | # http://docs.cython.org/en/latest/src/userguide/source_files_and_compilation.html#compiler-options 15 | Options.docstrings = True 16 | 17 | 18 | PACKAGE_DATA = {"": ["*.pyx", "*.pxd", 
"*.c", "*.h", "*.cpp"]} 19 | PACKAGES = find_packages() 20 | # msgpack has this whacky build where it only builds _cmsgpack which textually includes 21 | # _packer and _unpacker. I refactored this. 22 | MOD_NAMES = ["srsly.msgpack._epoch", "srsly.msgpack._packer", "srsly.msgpack._unpacker"] 23 | COMPILE_OPTIONS = { 24 | "msvc": ["/Ox", "/EHsc"], 25 | "mingw32": ["-O2", "-Wno-strict-prototypes", "-Wno-unused-function"], 26 | "other": ["-O2", "-Wno-strict-prototypes", "-Wno-unused-function"], 27 | } 28 | COMPILER_DIRECTIVES = { 29 | "language_level": -3, 30 | "embedsignature": True, 31 | "annotation_typing": False, 32 | } 33 | LINK_OPTIONS = {"msvc": [], "mingw32": [], "other": ["-lstdc++", "-lm"]} 34 | 35 | if sys.byteorder == "big": 36 | macros = [("__BIG_ENDIAN__", "1")] 37 | else: 38 | macros = [("__LITTLE_ENDIAN__", "1")] 39 | 40 | 41 | # By subclassing build_extensions we have the actual compiler that will be used 42 | # which is really known only after finalize_options 43 | # http://stackoverflow.com/questions/724664/python-distutils-how-to-get-a-compiler-that-is-going-to-be-used 44 | class build_ext_options: 45 | def build_options(self): 46 | if hasattr(self.compiler, "initialize"): 47 | self.compiler.initialize() 48 | self.compiler.platform = sys.platform[:6] 49 | for e in self.extensions: 50 | e.extra_compile_args += COMPILE_OPTIONS.get( 51 | self.compiler.compiler_type, COMPILE_OPTIONS["other"] 52 | ) 53 | e.extra_link_args += LINK_OPTIONS.get( 54 | self.compiler.compiler_type, LINK_OPTIONS["other"] 55 | ) 56 | 57 | 58 | class build_ext_subclass(build_ext, build_ext_options): 59 | def build_extensions(self): 60 | build_ext_options.build_options(self) 61 | build_ext.build_extensions(self) 62 | 63 | 64 | def clean(path): 65 | n_cleaned = 0 66 | for name in MOD_NAMES: 67 | name = name.replace(".", "/") 68 | for ext in ["so", "html", "cpp", "c"]: 69 | file_path = path / f"{name}.{ext}" 70 | if file_path.exists(): 71 | file_path.unlink() 72 | n_cleaned += 1 73 
| print(f"Cleaned {n_cleaned} files") 74 | 75 | 76 | @contextlib.contextmanager 77 | def chdir(new_dir): 78 | old_dir = os.getcwd() 79 | try: 80 | os.chdir(new_dir) 81 | sys.path.insert(0, new_dir) 82 | yield 83 | finally: 84 | del sys.path[0] 85 | os.chdir(old_dir) 86 | 87 | 88 | def setup_package(): 89 | root = Path(__file__).parent 90 | 91 | if len(sys.argv) > 1 and sys.argv[1] == "clean": 92 | return clean(root) 93 | 94 | with (root / "srsly" / "about.py").open("r") as f: 95 | about = {} 96 | exec(f.read(), about) 97 | 98 | with chdir(str(root)): 99 | include_dirs = [get_path("include"), ".", "srsly"] 100 | ext_modules = [] 101 | for name in MOD_NAMES: 102 | mod_path = name.replace(".", "/") + ".pyx" 103 | ext_modules.append( 104 | Extension( 105 | name, 106 | [mod_path], 107 | language="c++", 108 | include_dirs=include_dirs, 109 | define_macros=macros, 110 | ) 111 | ) 112 | ext_modules.append( 113 | Extension( 114 | "srsly.ujson.ujson", 115 | sources=[ 116 | "./srsly/ujson/ujson.c", 117 | "./srsly/ujson/objToJSON.c", 118 | "./srsly/ujson/JSONtoObj.c", 119 | "./srsly/ujson/lib/ultrajsonenc.c", 120 | "./srsly/ujson/lib/ultrajsondec.c", 121 | ], 122 | include_dirs=["./srsly/ujson", "./srsly/ujson/lib"], 123 | extra_compile_args=["-D_GNU_SOURCE"], 124 | ) 125 | ) 126 | print("Cythonizing sources") 127 | ext_modules = cythonize( 128 | ext_modules, compiler_directives=COMPILER_DIRECTIVES, language_level=2 129 | ) 130 | 131 | setup( 132 | name="srsly", 133 | packages=PACKAGES, 134 | version=about["__version__"], 135 | ext_modules=ext_modules, 136 | cmdclass={"build_ext": build_ext_subclass}, 137 | package_data=PACKAGE_DATA, 138 | ) 139 | 140 | 141 | if __name__ == "__main__": 142 | setup_package() 143 | -------------------------------------------------------------------------------- /srsly/__init__.py: -------------------------------------------------------------------------------- 1 | from ._json_api import read_json, read_gzip_json, write_json, write_gzip_json 2 | 
from ._json_api import read_gzip_jsonl, write_gzip_jsonl 3 | from ._json_api import read_jsonl, write_jsonl 4 | from ._json_api import json_dumps, json_loads, is_json_serializable 5 | from ._msgpack_api import read_msgpack, write_msgpack, msgpack_dumps, msgpack_loads 6 | from ._msgpack_api import msgpack_encoders, msgpack_decoders 7 | from ._pickle_api import pickle_dumps, pickle_loads 8 | from ._yaml_api import read_yaml, write_yaml, yaml_dumps, yaml_loads 9 | from ._yaml_api import is_yaml_serializable 10 | from .about import __version__ 11 | -------------------------------------------------------------------------------- /srsly/_msgpack_api.py: -------------------------------------------------------------------------------- 1 | import gc 2 | 3 | from . import msgpack 4 | from .msgpack import msgpack_encoders, msgpack_decoders # noqa: F401 5 | from .util import force_path, FilePath, JSONInputBin, JSONOutputBin 6 | 7 | 8 | def msgpack_dumps(data: JSONInputBin) -> bytes: 9 | """Serialize an object to a msgpack byte string. 10 | 11 | data: The data to serialize. 12 | RETURNS (bytes): The serialized bytes. 13 | """ 14 | return msgpack.dumps(data, use_bin_type=True) 15 | 16 | 17 | def msgpack_loads(data: bytes, use_list: bool = True) -> JSONOutputBin: 18 | """Deserialize msgpack bytes to a Python object. 19 | 20 | data (bytes): The data to deserialize. 21 | use_list (bool): Don't use tuples instead of lists. Can make 22 | deserialization slower. 23 | RETURNS: The deserialized Python object. 24 | """ 25 | # msgpack-python docs suggest disabling gc before unpacking large messages 26 | gc.disable() 27 | msg = msgpack.loads(data, raw=False, use_list=use_list) 28 | gc.enable() 29 | return msg 30 | 31 | 32 | def write_msgpack(path: FilePath, data: JSONInputBin) -> None: 33 | """Create a msgpack file and dump contents. 34 | 35 | location (FilePath): The file path. 36 | data (JSONInputBin): The data to serialize. 
37 | """ 38 | file_path = force_path(path, require_exists=False) 39 | with file_path.open("wb") as f: 40 | msgpack.dump(data, f, use_bin_type=True) 41 | 42 | 43 | def read_msgpack(path: FilePath, use_list: bool = True) -> JSONOutputBin: 44 | """Load a msgpack file. 45 | 46 | location (FilePath): The file path. 47 | use_list (bool): Don't use tuples instead of lists. Can make 48 | deserialization slower. 49 | RETURNS (JSONOutputBin): The loaded and deserialized content. 50 | """ 51 | file_path = force_path(path) 52 | with file_path.open("rb") as f: 53 | # msgpack-python docs suggest disabling gc before unpacking large messages 54 | gc.disable() 55 | msg = msgpack.load(f, raw=False, use_list=use_list) 56 | gc.enable() 57 | return msg 58 | -------------------------------------------------------------------------------- /srsly/_pickle_api.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from . import cloudpickle 4 | from .util import JSONInput, JSONOutput 5 | 6 | 7 | def pickle_dumps(data: JSONInput, protocol: Optional[int] = None) -> bytes: 8 | """Serialize a Python object with pickle. 9 | 10 | data: The object to serialize. 11 | protocol (int): Protocol to use. -1 for highest. 12 | RETURNS (bytes): The serialized object. 13 | """ 14 | return cloudpickle.dumps(data, protocol=protocol) 15 | 16 | 17 | def pickle_loads(data: bytes) -> JSONOutput: 18 | """Deserialize bytes with pickle. 19 | 20 | data (bytes): The data to deserialize. 21 | RETURNS: The deserialized Python object. 
22 | """ 23 | return cloudpickle.loads(data) 24 | -------------------------------------------------------------------------------- /srsly/_yaml_api.py: -------------------------------------------------------------------------------- 1 | from typing import Union, IO, Any 2 | from io import StringIO 3 | import sys 4 | 5 | from .ruamel_yaml import YAML 6 | from .ruamel_yaml.representer import RepresenterError 7 | from .util import force_path, FilePath, YAMLInput, YAMLOutput 8 | 9 | 10 | class CustomYaml(YAML): 11 | def __init__(self, typ="safe", pure=True): 12 | YAML.__init__(self, typ=typ, pure=pure) 13 | self.default_flow_style = False 14 | self.allow_unicode = True 15 | self.encoding = "utf-8" 16 | 17 | # https://yaml.readthedocs.io/en/latest/example.html#output-of-dump-as-a-string 18 | def dump(self, data, stream=None, **kw): 19 | inefficient = False 20 | if stream is None: 21 | inefficient = True 22 | stream = StringIO() 23 | YAML.dump(self, data, stream, **kw) 24 | if inefficient: 25 | return stream.getvalue() 26 | 27 | 28 | def yaml_dumps( 29 | data: YAMLInput, 30 | indent_mapping: int = 2, 31 | indent_sequence: int = 4, 32 | indent_offset: int = 2, 33 | sort_keys: bool = False, 34 | ) -> str: 35 | """Serialize an object to a YAML string. See the ruamel.yaml docs on 36 | indentation for more details on the expected format. 37 | https://yaml.readthedocs.io/en/latest/detail.html?highlight=indentation#indentation-of-block-sequences 38 | 39 | data: The YAML-serializable data. 40 | indent_mapping (int): Mapping indentation. 41 | indent_sequence (int): Sequence indentation. 42 | indent_offset (int): Indentation offset. 43 | sort_keys (bool): Sort dictionary keys. 44 | RETURNS (str): The serialized string. 
45 | """ 46 | yaml = CustomYaml() 47 | yaml.sort_base_mapping_type_on_output = sort_keys 48 | yaml.indent(mapping=indent_mapping, sequence=indent_sequence, offset=indent_offset) 49 | return yaml.dump(data) 50 | 51 | 52 | def yaml_loads(data: Union[str, IO]) -> YAMLOutput: 53 | """Deserialize a unicode string or a file object to a Python object. 54 | 55 | data (str / file): The data to deserialize. 56 | RETURNS: The deserialized Python object. 57 | """ 58 | yaml = CustomYaml() 59 | try: 60 | return yaml.load(data) 61 | except Exception as e: 62 | raise ValueError(f"Invalid YAML: {e}") 63 | 64 | 65 | def read_yaml(path: FilePath) -> YAMLOutput: 66 | """Load YAML from file or standard input. 67 | 68 | path (FilePath): The file path. "-" for reading from stdin. 69 | RETURNS (YAMLOutput): The loaded content. 70 | """ 71 | if path == "-": # reading from sys.stdin 72 | data = sys.stdin.read() 73 | return yaml_loads(data) 74 | file_path = force_path(path) 75 | with file_path.open("r", encoding="utf8") as f: 76 | return yaml_loads(f) 77 | 78 | 79 | def write_yaml( 80 | path: FilePath, 81 | data: YAMLInput, 82 | indent_mapping: int = 2, 83 | indent_sequence: int = 4, 84 | indent_offset: int = 2, 85 | sort_keys: bool = False, 86 | ) -> None: 87 | """Create a YAML file and dump contents or write to standard 88 | output. 89 | 90 | path (FilePath): The file path. "-" for writing to stdout. 91 | data (YAMLInput): The YAML-serializable data to output. 92 | indent_mapping (int): Mapping indentation. 93 | indent_sequence (int): Sequence indentation. 94 | indent_offset (int): Indentation offset. 95 | sort_keys (bool): Sort dictionary keys. 
96 | """ 97 | yaml_data = yaml_dumps( 98 | data, 99 | indent_mapping=indent_mapping, 100 | indent_sequence=indent_sequence, 101 | indent_offset=indent_offset, 102 | sort_keys=sort_keys, 103 | ) 104 | if path == "-": # writing to stdout 105 | print(yaml_data) 106 | else: 107 | file_path = force_path(path, require_exists=False) 108 | with file_path.open("w", encoding="utf8") as f: 109 | f.write(yaml_data) 110 | 111 | 112 | def is_yaml_serializable(obj: Any) -> bool: 113 | """Check if a Python object is YAML-serializable (strict). 114 | 115 | obj: The object to check. 116 | RETURNS (bool): Whether the object is YAML-serializable. 117 | """ 118 | try: 119 | yaml_dumps(obj) 120 | return True 121 | except RepresenterError: 122 | return False 123 | -------------------------------------------------------------------------------- /srsly/about.py: -------------------------------------------------------------------------------- 1 | __version__ = "2.5.1" 2 | -------------------------------------------------------------------------------- /srsly/cloudpickle/__init__.py: -------------------------------------------------------------------------------- 1 | from .cloudpickle import * # noqa 2 | from .cloudpickle_fast import CloudPickler, dumps, dump # noqa 3 | 4 | # Conform to the convention used by python serialization libraries, which 5 | # expose their Pickler subclass at top-level under the "Pickler" name. 
6 | Pickler = CloudPickler 7 | 8 | __version__ = '2.2.0' 9 | -------------------------------------------------------------------------------- /srsly/cloudpickle/compat.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | 4 | if sys.version_info < (3, 8): 5 | try: 6 | import pickle5 as pickle # noqa: F401 7 | from pickle5 import Pickler # noqa: F401 8 | except ImportError: 9 | import pickle # noqa: F401 10 | 11 | # Use the Python pickler for old CPython versions 12 | from pickle import _Pickler as Pickler # noqa: F401 13 | else: 14 | import pickle # noqa: F401 15 | 16 | # Pickler will be the C implementation in CPython and the Python 17 | # implementation in PyPy 18 | from pickle import Pickler # noqa: F401 19 | -------------------------------------------------------------------------------- /srsly/msgpack/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import functools 4 | import catalogue 5 | 6 | # These need to be imported before packer and unpacker 7 | from ._epoch import utc, epoch # noqa 8 | 9 | from ._version import version 10 | from .exceptions import * 11 | 12 | # In msgpack-python these are put under a _cmsgpack module that textually includes 13 | # them. I dislike this so I refactored it. 
def _chain_registered(registry, hook):
    """Wrap ``hook`` with every function registered in ``registry``.

    Each registered function is bound with ``chain=<previous hook>``, so the
    function registered last ends up outermost and the caller's own hook
    (possibly ``None``) sits at the end of the chain.
    """
    for func in registry.get_all().values():
        hook = functools.partial(func, chain=hook)
    return hook


def _install_decoders(kwargs):
    """Add the registered decoders to ``kwargs`` as ``object_hook``, in place.

    Nothing is installed when the caller passed ``object_pairs_hook``: this is
    the rule ``unpack``/``unpackb`` have always applied.
    NOTE(review): presumably the underlying unpacker does not accept both
    hooks at once - confirm against ``_unpacker.pyx``.
    """
    if "object_pairs_hook" not in kwargs:
        kwargs["object_hook"] = _chain_registered(
            msgpack_decoders, kwargs.get("object_hook")
        )


# msgpack_numpy extensions
class Packer(_Packer):
    """Packer whose ``default`` hook also runs every registered encoder."""

    def __init__(self, *args, **kwargs):
        kwargs["default"] = _chain_registered(msgpack_encoders, kwargs.get("default"))
        super(Packer, self).__init__(*args, **kwargs)


class Unpacker(_Unpacker):
    """Unpacker whose ``object_hook`` also runs every registered decoder.

    As with ``unpack``/``unpackb``, the decoders are skipped when the caller
    supplies ``object_pairs_hook``; previously this class installed
    ``object_hook`` unconditionally.
    """

    def __init__(self, *args, **kwargs):
        _install_decoders(kwargs)
        super(Unpacker, self).__init__(*args, **kwargs)


def pack(o, stream, **kwargs):
    """
    Pack an object and write it to a stream.
    """
    packer = Packer(**kwargs)
    stream.write(packer.pack(o))


def packb(o, **kwargs):
    """
    Pack an object and return the packed bytes.
    """
    return Packer(**kwargs).pack(o)


def unpack(stream, **kwargs):
    """
    Unpack a packed object from a stream.

    The whole stream is read into memory before decoding.
    """
    _install_decoders(kwargs)
    data = stream.read()
    return _unpackb(data, **kwargs)


def unpackb(packed, **kwargs):
    """
    Unpack a packed object.
    """
    _install_decoders(kwargs)
    return _unpackb(packed, **kwargs)


# alias for compatibility to simplejson/marshal/pickle.
load = unpack
loads = unpackb

dump = pack
dumps = packb
try:
    import numpy as np

    has_numpy = True
except ImportError:
    has_numpy = False

try:
    import cupy

    has_cupy = True
except ImportError:
    has_cupy = False


def encode_numpy(obj, chain=None):
    """
    Data encoder for serializing numpy data types.

    Arrays (cupy arrays are copied to the host first), numpy scalars and
    Python complex numbers are turned into plain dicts keyed by bytes.
    Anything else is handed to ``chain`` if given, otherwise returned as is.
    Complex numbers are handled even when numpy is not installed.
    """
    if has_numpy:
        if has_cupy and isinstance(obj, cupy.ndarray):
            obj = obj.get()
        if isinstance(obj, np.ndarray):
            # If the dtype is structured, store the interface description;
            # otherwise, store the corresponding array protocol type string:
            if obj.dtype.kind == "V":
                kind = b"V"
                descr = obj.dtype.descr
            else:
                kind = b""
                descr = obj.dtype.str
            return {
                b"nd": True,
                b"type": descr,
                b"kind": kind,
                b"shape": obj.shape,
                b"data": obj.data if obj.flags["C_CONTIGUOUS"] else obj.tobytes(),
            }
        # Must be tested before the plain ``complex`` branch below so that
        # numpy scalars keep being encoded with their dtype.
        if isinstance(obj, (np.bool_, np.number)):
            return {b"nd": False, b"type": obj.dtype.str, b"data": obj.data}
    if isinstance(obj, complex):
        return {b"complex": True, b"data": repr(obj)}
    return obj if chain is None else chain(obj)


def tostr(x):
    """Return ``x`` as ``str``: bytes are decoded, everything else goes through ``str()``."""
    if isinstance(x, bytes):
        return x.decode()
    return str(x)


def decode_numpy(obj, chain=None):
    """
    Decoder for deserializing numpy data types.

    Reverses ``encode_numpy``. Objects that are not recognised (or that look
    like an encoded value but lack a required key) are handed to ``chain`` if
    given, otherwise returned unchanged. Without numpy installed, array
    payloads are passed through the same way instead of raising NameError.
    """
    # Only the lookups are guarded: ``chain`` runs outside the ``try`` so a
    # KeyError raised by the chained hook is not swallowed and retried.
    try:
        if b"nd" in obj:
            if has_numpy:
                if obj[b"nd"] is True:
                    # Check if b'kind' is in obj to enable decoding of data
                    # serialized with older versions (#20):
                    if b"kind" in obj and obj[b"kind"] == b"V":
                        descr = [
                            tuple(tostr(t) if isinstance(t, bytes) else t for t in d)
                            for d in obj[b"type"]
                        ]
                    else:
                        descr = obj[b"type"]
                    return np.frombuffer(
                        obj[b"data"], dtype=np.dtype(descr)
                    ).reshape(obj[b"shape"])
                descr = obj[b"type"]
                return np.frombuffer(obj[b"data"], dtype=np.dtype(descr))[0]
        elif b"complex" in obj:
            return complex(tostr(obj[b"data"]))
    except KeyError:
        pass
    return obj if chain is None else chain(obj)
class ExtType(namedtuple("ExtType", "code data")):
    """ExtType represents ext type in msgpack.

    ``code`` must be an int in ``0..127`` and ``data`` must be ``bytes``;
    anything else raises ``TypeError`` / ``ValueError``.
    """

    def __new__(cls, code, data):
        if not isinstance(code, int):
            raise TypeError("code must be int")
        if not isinstance(data, bytes):
            raise TypeError("data must be bytes")
        if not 0 <= code <= 127:
            raise ValueError("code must be 0~127")
        return super().__new__(cls, code, data)


class Timestamp:
    """Timestamp represents the Timestamp extension type in msgpack.

    When built with Cython, msgpack uses C methods to pack and unpack `Timestamp`.
    When using pure-Python msgpack, :func:`to_bytes` and :func:`from_bytes` are used to pack and
    unpack `Timestamp`.

    This class is immutable: Do not override seconds and nanoseconds.
    """

    __slots__ = ["seconds", "nanoseconds"]

    def __init__(self, seconds, nanoseconds=0):
        """Initialize a Timestamp object.

        :param int seconds:
            Number of seconds since the UNIX epoch (00:00:00 UTC Jan 1 1970, minus leap seconds).
            May be negative.

        :param int nanoseconds:
            Number of nanoseconds to add to `seconds` to get fractional time.
            Maximum is 999_999_999. Default is 0.

        :raises TypeError: if either argument is not an ``int``.
        :raises ValueError: if ``nanoseconds`` is outside ``0..999_999_999``.

        Note: Negative times (before the UNIX epoch) are represented as neg. seconds + pos. ns.
        """
        if not isinstance(seconds, int):
            raise TypeError("seconds must be an integer")
        if not isinstance(nanoseconds, int):
            raise TypeError("nanoseconds must be an integer")
        if not (0 <= nanoseconds < 10**9):
            # 999999999 itself is valid, so the bound reported is 10**9.
            raise ValueError("nanoseconds must be a non-negative integer less than 1000000000.")
        self.seconds = seconds
        self.nanoseconds = nanoseconds

    def __repr__(self):
        """String representation of Timestamp."""
        return f"Timestamp(seconds={self.seconds}, nanoseconds={self.nanoseconds})"

    def __eq__(self, other):
        """Check for equality with another Timestamp object"""
        if type(other) is self.__class__:
            return self.seconds == other.seconds and self.nanoseconds == other.nanoseconds
        return False

    def __ne__(self, other):
        """not-equals method (see :func:`__eq__()`)"""
        return not self.__eq__(other)

    def __hash__(self):
        return hash((self.seconds, self.nanoseconds))

    @staticmethod
    def from_bytes(b):
        """Unpack bytes into a `Timestamp` object.

        Used for pure-Python msgpack unpacking.

        :param b: Payload from msgpack ext message with code -1
        :type b: bytes

        :returns: Timestamp object unpacked from msgpack ext payload
        :rtype: Timestamp
        :raises ValueError: if ``b`` is not 4, 8 or 12 bytes long.
        """
        if len(b) == 4:
            seconds = struct.unpack("!L", b)[0]
            nanoseconds = 0
        elif len(b) == 8:
            # 30 bits of nanoseconds above 34 bits of seconds.
            data64 = struct.unpack("!Q", b)[0]
            seconds = data64 & 0x00000003FFFFFFFF
            nanoseconds = data64 >> 34
        elif len(b) == 12:
            nanoseconds, seconds = struct.unpack("!Iq", b)
        else:
            raise ValueError(
                "Timestamp type can only be created from 32, 64, or 96-bit byte objects"
            )
        return Timestamp(seconds, nanoseconds)

    def to_bytes(self):
        """Pack this Timestamp object into bytes.

        Used for pure-Python msgpack packing. The shortest of the three
        layouts (32, 64 or 96 bit) that can hold the value is chosen.

        :returns data: Payload for EXT message with code -1 (timestamp type)
        :rtype: bytes
        """
        if (self.seconds >> 34) == 0:  # seconds is non-negative and fits in 34 bits
            data64 = self.nanoseconds << 34 | self.seconds
            if data64 & 0xFFFFFFFF00000000 == 0:
                # nanoseconds is zero and seconds < 2**32, so timestamp 32
                data = struct.pack("!L", data64)
            else:
                # timestamp 64
                data = struct.pack("!Q", data64)
        else:
            # timestamp 96
            data = struct.pack("!Iq", self.nanoseconds, self.seconds)
        return data

    @staticmethod
    def from_unix(unix_sec):
        """Create a Timestamp from posix timestamp in seconds.

        :param unix_sec: Posix timestamp in seconds.
        :type unix_sec: int or float
        """
        seconds = int(unix_sec // 1)
        nanoseconds = int((unix_sec % 1) * 10**9)
        if nanoseconds >= 10**9:
            # For tiny negative floats ``x % 1`` rounds up to exactly 1.0;
            # carry that into the seconds instead of failing validation.
            seconds += 1
            nanoseconds -= 10**9
        return Timestamp(seconds, nanoseconds)

    def to_unix(self):
        """Get the timestamp as a floating-point value.

        :returns: posix timestamp
        :rtype: float
        """
        return self.seconds + self.nanoseconds / 1e9

    @staticmethod
    def from_unix_nano(unix_ns):
        """Create a Timestamp from posix timestamp in nanoseconds.

        :param int unix_ns: Posix timestamp in nanoseconds.
        :rtype: Timestamp
        """
        return Timestamp(*divmod(unix_ns, 10**9))

    def to_unix_nano(self):
        """Get the timestamp as a unixtime in nanoseconds.

        :returns: posix timestamp in nanoseconds
        :rtype: int
        """
        return self.seconds * 10**9 + self.nanoseconds

    def to_datetime(self):
        """Get the timestamp as a UTC datetime.

        Sub-microsecond precision is dropped (nanoseconds are floor-divided
        by 1000).

        :rtype: `datetime.datetime`
        """
        utc = datetime.timezone.utc
        return datetime.datetime.fromtimestamp(0, utc) + datetime.timedelta(
            seconds=self.seconds, microseconds=self.nanoseconds // 1000
        )

    @staticmethod
    def from_datetime(dt):
        """Create a Timestamp from datetime with tzinfo.

        The whole-second part is taken from ``dt`` with its microseconds
        cleared, so datetimes before the epoch are floored rather than
        truncated toward zero and the (always non-negative) microseconds can
        be added back as nanoseconds.

        :rtype: Timestamp
        """
        whole_seconds = int(dt.replace(microsecond=0).timestamp())
        return Timestamp(seconds=whole_seconds, nanoseconds=dt.microsecond * 1000)
/*
 * Append `l` bytes from `data` to the packer's output buffer.
 *
 * When the bytes do not fit, the buffer is reallocated to twice the size
 * that is required after the append.
 *
 * Returns 0 on success. If the reallocation fails, PyErr_NoMemory() is
 * called and -1 is returned; `pk` is not modified in that case because its
 * fields are only written back after the copy.
 */
static inline int msgpack_pack_write(msgpack_packer* pk, const char *data, size_t l)
{
    char* buf = pk->buf;
    size_t bs = pk->buf_size;
    size_t len = pk->length;

    if (len + l > bs) {
        /* Over-allocate (2x the needed size) so later appends can reuse the slack. */
        bs = (len + l) * 2;
        buf = (char*)PyMem_Realloc(buf, bs);
        if (!buf) {
            PyErr_NoMemory();
            return -1;
        }
    }
    memcpy(buf + len, data, l);
    len += l;

    /* Commit the (possibly moved) buffer and the new sizes. */
    pk->buf = buf;
    pk->buf_size = bs;
    pk->length = len;
    return 0;
}
fixed_offset + 0xf: 23 | #else 24 | case fixed_offset + 0x0: 25 | case fixed_offset + 0x1: 26 | case fixed_offset + 0x2: 27 | case fixed_offset + 0x3: 28 | case fixed_offset + 0x4: 29 | case fixed_offset + 0x5: 30 | case fixed_offset + 0x6: 31 | case fixed_offset + 0x7: 32 | case fixed_offset + 0x8: 33 | case fixed_offset + 0x9: 34 | case fixed_offset + 0xa: 35 | case fixed_offset + 0xb: 36 | case fixed_offset + 0xc: 37 | case fixed_offset + 0xd: 38 | case fixed_offset + 0xe: 39 | case fixed_offset + 0xf: 40 | #endif 41 | ++*off; 42 | size = ((unsigned int)*p) & 0x0f; 43 | break; 44 | default: 45 | PyErr_SetString(PyExc_ValueError, "Unexpected type header on stream"); 46 | return -1; 47 | } 48 | unpack_callback_uint32(&ctx->user, size, &ctx->stack[0].obj); 49 | return 1; 50 | } 51 | 52 | -------------------------------------------------------------------------------- /srsly/msgpack/unpack_define.h: -------------------------------------------------------------------------------- 1 | /* 2 | * MessagePack unpacking routine template 3 | * 4 | * Copyright (C) 2008-2010 FURUHASHI Sadayuki 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
def ensure_bytes(string):
    """Ensure a string is returned as a bytes object, encoded as utf8.

    ``str`` input is encoded; any other value (already-encoded bytes,
    ``None``, ...) is returned unchanged. The check is made against the
    built-in ``str`` directly rather than a Python 2 ``unicode`` alias.
    """
    if isinstance(string, str):
        return string.encode("utf8")
    return string
anchor_attrib = '_yaml_anchor'


class Anchor(object):
    """Holds an anchor name plus a flag saying whether it is always dumped.

    ``attrib`` is the attribute name under which an instance is stored on the
    object it belongs to.
    """

    __slots__ = ('value', 'always_dump')
    attrib = anchor_attrib

    def __init__(self):
        # type: () -> None
        self.value, self.always_dump = None, False

    def __repr__(self):
        # type: () -> Any
        if self.always_dump:
            suffix = ', (always dump)'
        else:
            suffix = ""
        return 'Anchor({!r}{})'.format(self.value, suffix)
def CommentCheck():
    # type: () -> None
    """Placeholder used as the "no comment passed" default of ``Event``."""
    pass


class Event(object):
    """Base event carrying start/end marks and an optional comment."""

    __slots__ = ('start_mark', 'end_mark', 'comment')

    def __init__(self, start_mark=None, end_mark=None, comment=CommentCheck):
        # type: (Any, Any, Any) -> None
        self.start_mark = start_mark
        self.end_mark = end_mark
        # assert comment is not CommentCheck
        # The placeholder default is stored as None.
        self.comment = None if comment is CommentCheck else comment

    def __repr__(self):
        # type: () -> Any
        # Only the attributes that are actually set on this instance are shown.
        shown = []
        for name in ('anchor', 'tag', 'implicit', 'value', 'flow_style', 'style'):
            if hasattr(self, name):
                shown.append('%s=%r' % (name, getattr(self, name)))
        arguments = ', '.join(shown)
        if self.comment not in [None, CommentCheck]:
            arguments += ', comment={!r}'.format(self.comment)
        return '%s(%s)' % (self.__class__.__name__, arguments)


class NodeEvent(Event):
    """Event that additionally carries an anchor."""

    __slots__ = ('anchor',)

    def __init__(self, anchor, start_mark=None, end_mark=None, comment=None):
        # type: (Any, Any, Any, Any) -> None
        super(NodeEvent, self).__init__(start_mark, end_mark, comment)
        self.anchor = anchor


class CollectionStartEvent(NodeEvent):
    """Node event opening a collection (tag, implicit flag, flow style, item count)."""

    __slots__ = ('tag', 'implicit', 'flow_style', 'nr_items')

    def __init__(
        self,
        anchor,
        tag,
        implicit,
        start_mark=None,
        end_mark=None,
        flow_style=None,
        comment=None,
        nr_items=None,
    ):
        # type: (Any, Any, Any, Any, Any, Any, Any, Optional[int]) -> None
        super(CollectionStartEvent, self).__init__(anchor, start_mark, end_mark, comment)
        self.tag, self.implicit = tag, implicit
        self.flow_style, self.nr_items = flow_style, nr_items


class CollectionEndEvent(Event):
    """Event closing a collection; adds no fields of its own."""

    __slots__ = ()
class StreamStartEvent(Event):
    """Start of a stream; records the stream encoding."""

    __slots__ = ('encoding',)

    def __init__(self, start_mark=None, end_mark=None, encoding=None, comment=None):
        # type: (Any, Any, Any, Any) -> None
        super(StreamStartEvent, self).__init__(start_mark, end_mark, comment)
        self.encoding = encoding


class StreamEndEvent(Event):
    """End of a stream."""

    __slots__ = ()


class DocumentStartEvent(Event):
    """Start of a document, with its explicit flag, version and tags."""

    __slots__ = ('explicit', 'version', 'tags')

    def __init__(
        self,
        start_mark=None,
        end_mark=None,
        explicit=None,
        version=None,
        tags=None,
        comment=None,
    ):
        # type: (Any, Any, Any, Any, Any, Any) -> None
        super(DocumentStartEvent, self).__init__(start_mark, end_mark, comment)
        self.explicit, self.version, self.tags = explicit, version, tags


class DocumentEndEvent(Event):
    """End of a document, with its explicit flag."""

    __slots__ = ('explicit',)

    def __init__(self, start_mark=None, end_mark=None, explicit=None, comment=None):
        # type: (Any, Any, Any, Any) -> None
        super(DocumentEndEvent, self).__init__(start_mark, end_mark, comment)
        self.explicit = explicit


class AliasEvent(NodeEvent):
    """Node event for an alias; adds no fields of its own."""

    __slots__ = ()


class ScalarEvent(NodeEvent):
    """Node event for a scalar (tag, implicit flag, value, style)."""

    __slots__ = ('tag', 'implicit', 'value', 'style')

    def __init__(
        self,
        anchor,
        tag,
        implicit,
        value,
        start_mark=None,
        end_mark=None,
        style=None,
        comment=None,
    ):
        # type: (Any, Any, Any, Any, Any, Any, Any, Any) -> None
        super(ScalarEvent, self).__init__(anchor, start_mark, end_mark, comment)
        self.tag, self.implicit = tag, implicit
        self.value, self.style = value, style


class SequenceStartEvent(CollectionStartEvent):
    """Start of a sequence."""

    __slots__ = ()


class SequenceEndEvent(CollectionEndEvent):
    """End of a sequence."""

    __slots__ = ()


class MappingStartEvent(CollectionStartEvent):
    """Start of a mapping."""

    __slots__ = ()


class MappingEndEvent(CollectionEndEvent):
    """End of a mapping."""

    __slots__ = ()
class BaseLoader(Reader, Scanner, Parser, Composer, BaseConstructor, VersionedResolver):
    """Loader assembled from the reader, scanner, parser, composer,
    ``BaseConstructor`` and ``VersionedResolver`` components.

    Each component is initialised explicitly, in the order listed, with this
    object passed as ``loader``. ``preserve_quotes`` is accepted but not used
    here.
    """

    def __init__(self, stream, version=None, preserve_quotes=None):
        # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None
        Reader.__init__(self, stream, loader=self)
        Scanner.__init__(self, loader=self)
        Parser.__init__(self, loader=self)
        Composer.__init__(self, loader=self)
        BaseConstructor.__init__(self, loader=self)
        VersionedResolver.__init__(self, version, loader=self)


class SafeLoader(Reader, Scanner, Parser, Composer, SafeConstructor, VersionedResolver):
    """Same assembly as ``BaseLoader`` but built on ``SafeConstructor``.

    ``preserve_quotes`` is accepted but not used here.
    """

    def __init__(self, stream, version=None, preserve_quotes=None):
        # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None
        Reader.__init__(self, stream, loader=self)
        Scanner.__init__(self, loader=self)
        Parser.__init__(self, loader=self)
        Composer.__init__(self, loader=self)
        SafeConstructor.__init__(self, loader=self)
        VersionedResolver.__init__(self, version, loader=self)


class Loader(Reader, Scanner, Parser, Composer, Constructor, VersionedResolver):
    """Placeholder for the unsafe loader: instantiating it always raises ``ValueError``."""

    def __init__(self, stream, version=None, preserve_quotes=None):
        raise ValueError("Unsafe loader not implemented in this library.")


class RoundTripLoader(
    Reader,
    RoundTripScanner,
    RoundTripParser,
    Composer,
    RoundTripConstructor,
    VersionedResolver,
):
    """Loader built from the round-trip scanner, parser and constructor.

    Unlike the other loaders, ``preserve_quotes`` is forwarded (to
    ``RoundTripConstructor``).
    """

    def __init__(self, stream, version=None, preserve_quotes=None):
        # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None
        # self.reader = Reader.__init__(self, stream)
        Reader.__init__(self, stream, loader=self)
        RoundTripScanner.__init__(self, loader=self)
        RoundTripParser.__init__(self, loader=self)
        Composer.__init__(self, loader=self)
        RoundTripConstructor.__init__(
            self, preserve_quotes=preserve_quotes, loader=self
        )
        VersionedResolver.__init__(self, version, loader=self)
class Node(object):
    """Base node: tag, value, source marks, optional comment and anchor."""

    __slots__ = ('tag', 'value', 'start_mark', 'end_mark', 'comment', 'anchor')

    def __init__(self, tag, value, start_mark, end_mark, comment=None, anchor=None):
        # type: (Any, Any, Any, Any, Any, Any) -> None
        self.tag, self.value = tag, value
        self.start_mark, self.end_mark = start_mark, end_mark
        self.comment, self.anchor = comment, anchor

    def __repr__(self):
        # type: () -> str
        # The value is always shown in full via repr().
        return '{}(tag={!r}, value={})'.format(
            self.__class__.__name__, self.tag, repr(self.value)
        )

    def dump(self, indent=0):
        # type: (int) -> None
        """Write this node, and recursively its children, to ``sys.stdout``."""
        pad = ' ' * indent
        cls_name = self.__class__.__name__
        is_text = isinstance(self.value, string_types)
        if is_text:
            header = '{}{}(tag={!r}, value={!r})\n'.format(pad, cls_name, self.tag, self.value)
        else:
            header = '{}{}(tag={!r})\n'.format(pad, cls_name, self.tag)
        sys.stdout.write(header)
        if self.comment:
            sys.stdout.write(' {}comment: {})\n'.format(pad, self.comment))
        if is_text:
            # Text values have no children to descend into.
            return
        for child in self.value:
            if isinstance(child, tuple):
                for part in child:
                    part.dump(indent + 1)
            elif isinstance(child, Node):
                child.dump(indent + 1)
            else:
                sys.stdout.write('Node value type? {}\n'.format(type(child)))


class ScalarNode(Node):
    """
    styles:
      ? -> set() ? key, no value
      " -> double quoted
      ' -> single quoted
      | -> literal style
      > -> folding style
    """

    __slots__ = ('style',)
    id = 'scalar'

    def __init__(
        self, tag, value, start_mark=None, end_mark=None, style=None, comment=None, anchor=None
    ):
        # type: (Any, Any, Any, Any, Any, Any, Any) -> None
        super(ScalarNode, self).__init__(
            tag, value, start_mark, end_mark, comment=comment, anchor=anchor
        )
        self.style = style


class CollectionNode(Node):
    """Node holding other nodes; adds the flow style."""

    __slots__ = ('flow_style',)

    def __init__(
        self,
        tag,
        value,
        start_mark=None,
        end_mark=None,
        flow_style=None,
        comment=None,
        anchor=None,
    ):
        # type: (Any, Any, Any, Any, Any, Any, Any) -> None
        super(CollectionNode, self).__init__(
            tag, value, start_mark, end_mark, comment=comment, anchor=anchor
        )
        self.flow_style = flow_style


class SequenceNode(CollectionNode):
    """Collection node with id ``'sequence'``."""

    __slots__ = ()
    id = 'sequence'


class MappingNode(CollectionNode):
    """Collection node with id ``'mapping'``; ``merge`` starts out as ``None``."""

    __slots__ = ('merge',)
    id = 'mapping'

    def __init__(
        self,
        tag,
        value,
        start_mark=None,
        end_mark=None,
        flow_style=None,
        comment=None,
        anchor=None,
    ):
        # type: (Any, Any, Any, Any, Any, Any, Any) -> None
        super(MappingNode, self).__init__(
            tag, value, start_mark, end_mark, flow_style, comment, anchor
        )
        self.merge = None
class ScalarBoolean(int):
    """``int`` subclass standing in for a bool that can carry a YAML anchor."""

    def __new__(cls, *args, **kw):
        # type: (Any, Any, Any) -> Any
        anchor_name = kw.pop("anchor", None)  # type: ignore
        instance = int.__new__(cls, *args, **kw)  # type: ignore
        if anchor_name is None:
            return instance
        instance.yaml_set_anchor(anchor_name, always_dump=True)
        return instance

    @property
    def anchor(self):
        # type: () -> Any
        """Return the Anchor stored on this value, creating it on first access."""
        try:
            return getattr(self, Anchor.attrib)
        except AttributeError:
            setattr(self, Anchor.attrib, Anchor())
            return getattr(self, Anchor.attrib)

    def yaml_anchor(self, any=False):
        # type: (bool) -> Any
        """Return the anchor if one exists and is always dumped (or ``any`` is true)."""
        if hasattr(self, Anchor.attrib) and (any or self.anchor.always_dump):
            return self.anchor
        return None

    def yaml_set_anchor(self, value, always_dump=False):
        # type: (Any, bool) -> None
        """Set the anchor name and its always-dump flag."""
        self.anchor.value = value
        self.anchor.always_dump = always_dump
class ScalarFloat(float):
    """Float that carries formatting attributes and an optional YAML anchor.

    The keyword arguments ``width``, ``prec``, ``m_sign``, ``m_lead0``,
    ``exp``, ``e_width``, ``e_sign`` and ``underscore`` are stored as private
    attributes (``_width``, ``_prec``, ...); presumably they record how the
    value was written in the YAML source.  ``anchor``, when given, is set with
    ``always_dump=True``.
    """

    def __new__(cls, *args, **kw):
        # type: (Any, Any, Any) -> Any
        # strip our own keywords before handing the rest to float.__new__
        width = kw.pop("width", None)  # type: ignore
        prec = kw.pop("prec", None)  # type: ignore
        m_sign = kw.pop("m_sign", None)  # type: ignore
        m_lead0 = kw.pop("m_lead0", 0)  # type: ignore
        exp = kw.pop("exp", None)  # type: ignore
        e_width = kw.pop("e_width", None)  # type: ignore
        e_sign = kw.pop("e_sign", None)  # type: ignore
        underscore = kw.pop("underscore", None)  # type: ignore
        anchor = kw.pop("anchor", None)  # type: ignore
        v = float.__new__(cls, *args, **kw)  # type: ignore
        v._width = width
        v._prec = prec
        v._m_sign = m_sign
        v._m_lead0 = m_lead0
        v._exp = exp
        v._e_width = e_width
        v._e_sign = e_sign
        v._underscore = underscore
        if anchor is not None:
            v.yaml_set_anchor(anchor, always_dump=True)
        return v

    # The in-place operators return a plain ``float``: the formatting
    # attributes are not carried over to the result.  (Code that used to
    # follow each ``return`` was unreachable and has been removed.)

    def __iadd__(self, a):  # type: ignore
        # type: (Any) -> Any
        return float(self) + a

    def __ifloordiv__(self, a):  # type: ignore
        # type: (Any) -> Any
        return float(self) // a

    def __imul__(self, a):  # type: ignore
        # type: (Any) -> Any
        return float(self) * a

    def __ipow__(self, a):  # type: ignore
        # type: (Any) -> Any
        return float(self) ** a

    def __isub__(self, a):  # type: ignore
        # type: (Any) -> Any
        return float(self) - a

    @property
    def anchor(self):
        # type: () -> Any
        """Return the Anchor stored on this value, creating it on first use."""
        if not hasattr(self, Anchor.attrib):
            setattr(self, Anchor.attrib, Anchor())
        return getattr(self, Anchor.attrib)

    def yaml_anchor(self, any=False):
        # type: (bool) -> Any
        """Return the anchor if set and (``any`` or flagged always_dump), else None."""
        if not hasattr(self, Anchor.attrib):
            return None
        if any or self.anchor.always_dump:
            return self.anchor
        return None

    def yaml_set_anchor(self, value, always_dump=False):
        # type: (Any, bool) -> None
        """Store ``value`` and ``always_dump`` on this value's anchor."""
        self.anchor.value = value
        self.anchor.always_dump = always_dump

    def dump(self, out=sys.stdout):
        # type: (Any) -> Any
        """Write a one-line debug description of the value and its attributes to ``out``."""
        out.write(
            "ScalarFloat({}| w:{}, p:{}, s:{}, lz:{}, _:{}|{}, w:{}, s:{})\n".format(
                self,
                self._width,  # type: ignore
                self._prec,  # type: ignore
                self._m_sign,  # type: ignore
                self._m_lead0,  # type: ignore
                self._underscore,  # type: ignore
                self._exp,  # type: ignore
                self._e_width,  # type: ignore
                self._e_sign,  # type: ignore
            )
        )
class ScalarInt(no_limit_int):
    """Integer that carries ``_width`` / ``_underscore`` attributes and an optional anchor.

    The in-place arithmetic operators return a new instance of the same class
    with those two attributes copied from the left operand.
    """

    def __new__(cls, *args, **kw):
        # type: (Any, Any, Any) -> Any
        # strip our own keywords before handing the rest to the int constructor
        width = kw.pop("width", None)  # type: ignore
        underscore = kw.pop("underscore", None)  # type: ignore
        anchor_name = kw.pop("anchor", None)  # type: ignore
        instance = no_limit_int.__new__(cls, *args, **kw)  # type: ignore
        instance._width = width
        instance._underscore = underscore
        if anchor_name is None:
            return instance
        instance.yaml_set_anchor(anchor_name, always_dump=True)
        return instance

    def _rewrap(self, value):
        # type: (Any) -> Any
        """Wrap ``value`` in this instance's class, copying width and underscore."""
        result = type(self)(value)
        result._width = self._width  # type: ignore
        marks = self._underscore  # type: ignore
        # slice-copy so the result does not share the original object
        result._underscore = None if marks is None else marks[:]  # type: ignore
        return result

    def __iadd__(self, a):  # type: ignore
        # type: (Any) -> Any
        return self._rewrap(self + a)

    def __ifloordiv__(self, a):  # type: ignore
        # type: (Any) -> Any
        return self._rewrap(self // a)

    def __imul__(self, a):  # type: ignore
        # type: (Any) -> Any
        return self._rewrap(self * a)

    def __ipow__(self, a):  # type: ignore
        # type: (Any) -> Any
        return self._rewrap(self ** a)

    def __isub__(self, a):  # type: ignore
        # type: (Any) -> Any
        return self._rewrap(self - a)

    @property
    def anchor(self):
        # type: () -> Any
        """Return the Anchor stored on this value, creating it on first use."""
        attr_name = Anchor.attrib
        if not hasattr(self, attr_name):
            setattr(self, attr_name, Anchor())
        return getattr(self, attr_name)

    def yaml_anchor(self, any=False):
        # type: (bool) -> Any
        """Return the anchor if set and (``any`` or flagged always_dump), else None."""
        if hasattr(self, Anchor.attrib):
            current = self.anchor
            if any or current.always_dump:
                return current
        return None

    def yaml_set_anchor(self, value, always_dump=False):
        # type: (Any, bool) -> None
        """Store ``value`` and ``always_dump`` on this value's anchor."""
        target = self.anchor
        target.value = value
        target.always_dump = always_dump
class ScalarString(text_type):
    """Text value that can carry a YAML anchor; base class of the styled scalar strings."""

    __slots__ = Anchor.attrib

    def __new__(cls, *args, **kw):
        # type: (Any, Any) -> Any
        # "anchor" is consumed here; everything else goes to the text constructor
        anchor_name = kw.pop("anchor", None)  # type: ignore
        instance = text_type.__new__(cls, *args, **kw)  # type: ignore
        if anchor_name is None:
            return instance
        instance.yaml_set_anchor(anchor_name, always_dump=True)
        return instance

    def replace(self, old, new, maxreplace=-1):
        # type: (Any, Any, int) -> Any
        """Like the text type's ``replace`` but the result keeps this instance's class."""
        replaced = text_type.replace(self, old, new, maxreplace)
        return type(self)(replaced)

    @property
    def anchor(self):
        # type: () -> Any
        """Return the Anchor stored on this value, creating it on first use."""
        attr_name = Anchor.attrib
        if not hasattr(self, attr_name):
            setattr(self, attr_name, Anchor())
        return getattr(self, attr_name)

    def yaml_anchor(self, any=False):
        # type: (bool) -> Any
        """Return the anchor if set and (``any`` or flagged always_dump), else None."""
        if hasattr(self, Anchor.attrib):
            current = self.anchor
            if any or current.always_dump:
                return current
        return None

    def yaml_set_anchor(self, value, always_dump=False):
        # type: (Any, bool) -> None
        """Store ``value`` and ``always_dump`` on this value's anchor."""
        target = self.anchor
        target.value = value
        target.always_dump = always_dump
def walk_tree(base, map=None):
    # type: (Any, Any) -> None
    """
    the routine here walks over a simple yaml tree (recursing in
    dict values and list items) and converts strings that
    have multiple lines to literal scalars

    You can also provide an explicit (ordered) mapping for multiple transforms
    (first of which is executed):
        map = .compat.ordereddict
        map['\\n'] = preserve_literal
        map[':'] = SingleQuotedScalarString
        walk_tree(data, map=map)

    The same mapping is applied at every nesting level.  `base` is modified
    in place; nothing is returned.
    """
    from .compat import string_types
    from .compat import MutableMapping, MutableSequence  # type: ignore

    if map is None:
        map = {"\n": preserve_literal}

    if isinstance(base, MutableMapping):
        for k in base:
            v = base[k]  # type: Text
            if isinstance(v, string_types):
                # apply only the first transform whose trigger occurs in v
                for ch in map:
                    if ch in v:
                        base[k] = map[ch](v)
                        break
            else:
                # pass the mapping down, otherwise nested values would
                # silently fall back to the default transform
                walk_tree(v, map=map)
    elif isinstance(base, MutableSequence):
        for idx, elem in enumerate(base):
            if isinstance(elem, string_types):
                for ch in map:
                    if ch in elem:  # type: ignore
                        base[idx] = map[ch](elem)
                        break
            else:
                walk_tree(elem, map=map)
class TimeStamp(datetime.datetime):
    """``datetime.datetime`` subclass with an extra ``_yaml`` dict attached.

    ``_yaml`` starts as ``dict(t=False, tz=None, delta=0)``; presumably it
    records how the timestamp was written in the YAML source (to be confirmed
    against the constructor/representer that fill it in).
    """

    def __init__(self, *args, **kw):
        # type: (Any, Any) -> None
        # the date/time fields are handled by __new__; only the extra
        # per-instance dict is set up here
        self._yaml = dict(t=False, tz=None, delta=0)  # type: Dict[Any, Any]

    def __new__(cls, *args, **kw):  # datetime is immutable
        # type: (Any, Any) -> Any
        return datetime.datetime.__new__(cls, *args, **kw)  # type: ignore

    def __deepcopy__(self, memo):
        # type: (Any) -> Any
        """Return an equal TimeStamp with an independent copy of ``_yaml``.

        Microseconds and tzinfo are carried over as well, so that the copy
        compares equal to the original.
        """
        ts = TimeStamp(
            self.year,
            self.month,
            self.day,
            self.hour,
            self.minute,
            self.second,
            self.microsecond,
            copy.deepcopy(self.tzinfo, memo),
        )
        ts._yaml = copy.deepcopy(self._yaml, memo)
        return ts
def load_yaml_guess_indent(stream, **kw):
    # type: (StreamTextType, Any) -> Any
    """guess the indent and block sequence indent of yaml stream/string

    returns round_trip_loaded stream, indent level, block sequence indent
    - block sequence indent is the number of spaces before a dash relative to previous indent
    - if there are no block sequences, indent is taken from nested mappings, block sequence
      indent is unset (None) in that case

    `stream` may be text, bytes (decoded as UTF-8) or an object with a
    `read()` method; extra keyword arguments are passed to round_trip_load.
    """
    from .main import round_trip_load

    # load a yaml file guess the indentation, if you use TABs ...
    def leading_spaces(l):
        # type: (Any) -> int
        idx = 0
        while idx < len(l) and l[idx] == ' ':
            idx += 1
        return idx

    if isinstance(stream, text_type):
        yaml_str = stream  # type: Any
    elif isinstance(stream, binary_type):
        # most likely, but the Reader checks BOM for this
        yaml_str = stream.decode('utf-8')
    else:
        yaml_str = stream.read()
    map_indent = None
    indent = None  # default if not found for some reason
    block_seq_indent = None
    prev_line_key_only = None
    key_indent = 0
    for line in yaml_str.splitlines():
        rline = line.rstrip()
        lline = rline.lstrip()
        if lline.startswith('- '):
            l_s = leading_spaces(line)
            block_seq_indent = l_s - key_indent
            idx = l_s + 1
            # terminates: the stripped line starts with '- ' and cannot end
            # in a space, so a non-space character follows the dash
            while line[idx] == ' ':
                idx += 1
            if line[idx] == '#':  # comment after -
                continue
            indent = idx - key_indent
            break
        if map_indent is None and prev_line_key_only is not None and rline:
            idx = 0
            # bounded scan: a line consisting only of dashes/spaces (a bare
            # '-' or a '---' document marker) used to raise IndexError here
            while idx < len(rline) and rline[idx] in ' -':
                idx += 1
            # such a line has no content to measure an indent from
            if idx < len(rline) and idx > prev_line_key_only:
                map_indent = idx - prev_line_key_only
        if rline.endswith(':'):
            key_indent = leading_spaces(line)
            prev_line_key_only = key_indent
            continue
        prev_line_key_only = None
    if indent is None and map_indent is not None:
        indent = map_indent
    return round_trip_load(yaml_str, **kw), indent, block_seq_indent
| for c in cfg.final_comment: 139 | if c.strip(): 140 | yield c 141 | 142 | 143 | def _walk_section(s, level=0): 144 | # type: (Any, int) -> Any 145 | from configobj import Section 146 | 147 | assert isinstance(s, Section) 148 | indent = u' ' * level 149 | for name in s.scalars: 150 | for c in s.comments[name]: 151 | yield indent + c.strip() 152 | x = s[name] 153 | if u'\n' in x: 154 | i = indent + u' ' 155 | x = u'|\n' + i + x.strip().replace(u'\n', u'\n' + i) 156 | elif ':' in x: 157 | x = u"'" + x.replace(u"'", u"''") + u"'" 158 | line = u'{0}{1}: {2}'.format(indent, name, x) 159 | c = s.inline_comments[name] 160 | if c: 161 | line += u' ' + c 162 | yield line 163 | for name in s.sections: 164 | for c in s.comments[name]: 165 | yield indent + c.strip() 166 | line = u'{0}{1}:'.format(indent, name) 167 | c = s.inline_comments[name] 168 | if c: 169 | line += u' ' + c 170 | yield line 171 | for val in _walk_section(s[name], level=level + 1): 172 | yield val 173 | 174 | 175 | # def config_obj_2_rt_yaml(cfg): 176 | # from .comments import CommentedMap, CommentedSeq 177 | # from configobj import ConfigObj 178 | # assert isinstance(cfg, ConfigObj) 179 | # #for c in cfg.initial_comment: 180 | # # if c.strip(): 181 | # # pass 182 | # cm = CommentedMap() 183 | # for name in s.sections: 184 | # cm[name] = d = CommentedMap() 185 | # 186 | # 187 | # #for c in cfg.final_comment: 188 | # # if c.strip(): 189 | # # yield c 190 | # return cm 191 | -------------------------------------------------------------------------------- /srsly/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explosion/srsly/6c0fe81e2cd30093aa3bb12b9f4df67e4da1ff59/srsly/tests/__init__.py -------------------------------------------------------------------------------- /srsly/tests/cloudpickle/__init__.py: -------------------------------------------------------------------------------- 
class CloudPickleFileTests(unittest.TestCase):
    """In Cloudpickle, expected behaviour when pickling an opened file
    is to send its contents over the wire and seek to the same position."""

    def setUp(self):
        # every test works on one scratch file inside a fresh temp directory
        self.tmpdir = tempfile.mkdtemp()
        self.tmpfilepath = os.path.join(self.tmpdir, 'testfile')
        self.teststring = 'Hello world!'

    def tearDown(self):
        # remove the whole temp directory, including any leftover file
        shutil.rmtree(self.tmpdir)

    def test_empty_file(self):
        # Empty file
        open(self.tmpfilepath, 'w').close()
        with open(self.tmpfilepath, 'r') as f:
            # the unpickled object reads back as an empty string
            self.assertEqual('', pickle.loads(cloudpickle.dumps(f)).read())
        os.remove(self.tmpfilepath)

    def test_closed_file(self):
        # Write & close
        with open(self.tmpfilepath, 'w') as f:
            f.write(self.teststring)
        # f is closed at this point; dumping it must raise PicklingError
        with pytest.raises(pickle.PicklingError) as excinfo:
            cloudpickle.dumps(f)
        assert "Cannot pickle closed files" in str(excinfo.value)
        os.remove(self.tmpfilepath)

    def test_r_mode(self):
        # Write & close
        with open(self.tmpfilepath, 'w') as f:
            f.write(self.teststring)
        # Open for reading
        with open(self.tmpfilepath, 'r') as f:
            new_f = pickle.loads(cloudpickle.dumps(f))
            self.assertEqual(self.teststring, new_f.read())
        os.remove(self.tmpfilepath)

    def test_w_mode(self):
        # a file opened with mode 'w' must be rejected with PicklingError
        with open(self.tmpfilepath, 'w') as f:
            f.write(self.teststring)
            f.seek(0)
            self.assertRaises(pickle.PicklingError,
                              lambda: cloudpickle.dumps(f))
        os.remove(self.tmpfilepath)

    def test_plus_mode(self):
        # Write, then seek to 0
        with open(self.tmpfilepath, 'w+') as f:
            f.write(self.teststring)
            f.seek(0)
            new_f = pickle.loads(cloudpickle.dumps(f))
            self.assertEqual(self.teststring, new_f.read())
        os.remove(self.tmpfilepath)

    def test_seek(self):
        # Write, then seek to arbitrary position
        with open(self.tmpfilepath, 'w+') as f:
            f.write(self.teststring)
            f.seek(4)
            unpickled = pickle.loads(cloudpickle.dumps(f))
            # unpickled StringIO is at position 4
            self.assertEqual(4, unpickled.tell())
            self.assertEqual(self.teststring[4:], unpickled.read())
            # but unpickled StringIO also contained the start
            unpickled.seek(0)
            self.assertEqual(self.teststring, unpickled.read())
        os.remove(self.tmpfilepath)

    @pytest.mark.skip(reason="Requires pytest -s to pass")
    def test_pickling_special_file_handles(self):
        # Warning: if you want to run your tests with nose, add -s option
        for out in sys.stdout, sys.stderr:  # Regression test for SPARK-3415
            # stdout/stderr must round-trip to an object equal to the original
            self.assertEqual(out, pickle.loads(cloudpickle.dumps(out)))
        # stdin, by contrast, must be refused
        self.assertRaises(pickle.PicklingError,
                          lambda: cloudpickle.dumps(sys.stdin))
Constructs defined in this file and usually pickled by 5 | reference should instead flagged to cloudpickle for pickling by value: this is 6 | done using the register_pickle_by_value api exposed by cloudpickle. 7 | """ 8 | import typing 9 | 10 | 11 | def local_function(): 12 | return "hello from a function importable locally!" 13 | 14 | 15 | class LocalClass: 16 | def method(self): 17 | return "hello from a class importable locally" 18 | 19 | 20 | LocalT = typing.TypeVar("LocalT") 21 | -------------------------------------------------------------------------------- /srsly/tests/cloudpickle/mock_local_folder/subfolder/submod.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | 4 | def local_submod_function(): 5 | return "hello from a file located in a locally-importable subfolder!" 6 | 7 | 8 | class LocalSubmodClass: 9 | def method(self): 10 | return "hello from a class located in a locally-importable subfolder!" 11 | 12 | 13 | LocalSubmodT = typing.TypeVar("LocalSubmodT") 14 | -------------------------------------------------------------------------------- /srsly/tests/msgpack/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explosion/srsly/6c0fe81e2cd30093aa3bb12b9f4df67e4da1ff59/srsly/tests/msgpack/__init__.py -------------------------------------------------------------------------------- /srsly/tests/msgpack/test_buffer.py: -------------------------------------------------------------------------------- 1 | from srsly.msgpack import packb, unpackb 2 | 3 | 4 | def test_unpack_buffer(): 5 | from array import array 6 | 7 | buf = array("b") 8 | buf.frombytes(packb((b"foo", b"bar"))) 9 | obj = unpackb(buf, use_list=1) 10 | assert [b"foo", b"bar"] == obj 11 | 12 | 13 | def test_unpack_bytearray(): 14 | buf = bytearray(packb(("foo", "bar"))) 15 | obj = unpackb(buf, use_list=1) 16 | assert [b"foo", b"bar"] == obj 17 | expected_type = 
bytes 18 | assert all(type(s) == expected_type for s in obj) 19 | 20 | 21 | def test_unpack_memoryview(): 22 | buf = bytearray(packb(("foo", "bar"))) 23 | view = memoryview(buf) 24 | obj = unpackb(view, use_list=1) 25 | assert [b"foo", b"bar"] == obj 26 | expected_type = bytes 27 | assert all(type(s) == expected_type for s in obj) 28 | -------------------------------------------------------------------------------- /srsly/tests/msgpack/test_case.py: -------------------------------------------------------------------------------- 1 | from srsly.msgpack import packb, unpackb 2 | 3 | 4 | def check(length, obj): 5 | v = packb(obj) 6 | assert len(v) == length, "%r length should be %r but get %r" % (obj, length, len(v)) 7 | assert unpackb(v, use_list=0) == obj 8 | 9 | 10 | def test_1(): 11 | for o in [ 12 | None, 13 | True, 14 | False, 15 | 0, 16 | 1, 17 | (1 << 6), 18 | (1 << 7) - 1, 19 | -1, 20 | -((1 << 5) - 1), 21 | -(1 << 5), 22 | ]: 23 | check(1, o) 24 | 25 | 26 | def test_2(): 27 | for o in [1 << 7, (1 << 8) - 1, -((1 << 5) + 1), -(1 << 7)]: 28 | check(2, o) 29 | 30 | 31 | def test_3(): 32 | for o in [1 << 8, (1 << 16) - 1, -((1 << 7) + 1), -(1 << 15)]: 33 | check(3, o) 34 | 35 | 36 | def test_5(): 37 | for o in [1 << 16, (1 << 32) - 1, -((1 << 15) + 1), -(1 << 31)]: 38 | check(5, o) 39 | 40 | 41 | def test_9(): 42 | for o in [ 43 | 1 << 32, 44 | (1 << 64) - 1, 45 | -((1 << 31) + 1), 46 | -(1 << 63), 47 | 1.0, 48 | 0.1, 49 | -0.1, 50 | -1.0, 51 | ]: 52 | check(9, o) 53 | 54 | 55 | def check_raw(overhead, num): 56 | check(num + overhead, b" " * num) 57 | 58 | 59 | def test_fixraw(): 60 | check_raw(1, 0) 61 | check_raw(1, (1 << 5) - 1) 62 | 63 | 64 | def test_raw16(): 65 | check_raw(3, 1 << 5) 66 | check_raw(3, (1 << 16) - 1) 67 | 68 | 69 | def test_raw32(): 70 | check_raw(5, 1 << 16) 71 | 72 | 73 | def check_array(overhead, num): 74 | check(num + overhead, (None,) * num) 75 | 76 | 77 | def test_fixarray(): 78 | check_array(1, 0) 79 | check_array(1, (1 << 4) - 1) 80 
| 81 | 82 | def test_array16(): 83 | check_array(3, 1 << 4) 84 | check_array(3, (1 << 16) - 1) 85 | 86 | 87 | def test_array32(): 88 | check_array(5, (1 << 16)) 89 | 90 | 91 | def match(obj, buf): 92 | assert packb(obj) == buf 93 | assert unpackb(buf, use_list=0) == obj 94 | 95 | 96 | def test_match(): 97 | cases = [ 98 | (None, b"\xc0"), 99 | (False, b"\xc2"), 100 | (True, b"\xc3"), 101 | (0, b"\x00"), 102 | (127, b"\x7f"), 103 | (128, b"\xcc\x80"), 104 | (256, b"\xcd\x01\x00"), 105 | (-1, b"\xff"), 106 | (-33, b"\xd0\xdf"), 107 | (-129, b"\xd1\xff\x7f"), 108 | ({1: 1}, b"\x81\x01\x01"), 109 | (1.0, b"\xcb\x3f\xf0\x00\x00\x00\x00\x00\x00"), 110 | ((), b"\x90"), 111 | ( 112 | tuple(range(15)), 113 | b"\x9f\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e", 114 | ), 115 | ( 116 | tuple(range(16)), 117 | b"\xdc\x00\x10\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", 118 | ), 119 | ({}, b"\x80"), 120 | ( 121 | dict([(x, x) for x in range(15)]), 122 | b"\x8f\x00\x00\x01\x01\x02\x02\x03\x03\x04\x04\x05\x05\x06\x06\x07\x07\x08\x08\t\t\n\n\x0b\x0b\x0c\x0c\r\r\x0e\x0e", 123 | ), 124 | ( 125 | dict([(x, x) for x in range(16)]), 126 | b"\xde\x00\x10\x00\x00\x01\x01\x02\x02\x03\x03\x04\x04\x05\x05\x06\x06\x07\x07\x08\x08\t\t\n\n\x0b\x0b\x0c\x0c\r\r\x0e\x0e\x0f\x0f", 127 | ), 128 | ] 129 | 130 | for v, p in cases: 131 | match(v, p) 132 | 133 | 134 | def test_unicode(): 135 | assert unpackb(packb("foobar"), use_list=1) == b"foobar" 136 | -------------------------------------------------------------------------------- /srsly/tests/msgpack/test_except.py: -------------------------------------------------------------------------------- 1 | from pytest import raises 2 | import datetime 3 | from srsly.msgpack import packb, unpackb, Unpacker, FormatError, StackError, OutOfData 4 | 5 | 6 | class DummyException(Exception): 7 | pass 8 | 9 | 10 | def test_raise_on_find_unsupported_value(): 11 | with raises(TypeError): 12 | packb(datetime.datetime.now()) 13 | 
14 | 15 | def test_raise_from_object_hook(): 16 | def hook(obj): 17 | raise DummyException 18 | 19 | raises(DummyException, unpackb, packb({}), object_hook=hook) 20 | raises(DummyException, unpackb, packb({"fizz": "buzz"}), object_hook=hook) 21 | raises(DummyException, unpackb, packb({"fizz": "buzz"}), object_pairs_hook=hook) 22 | raises(DummyException, unpackb, packb({"fizz": {"buzz": "spam"}}), object_hook=hook) 23 | raises( 24 | DummyException, 25 | unpackb, 26 | packb({"fizz": {"buzz": "spam"}}), 27 | object_pairs_hook=hook, 28 | ) 29 | 30 | 31 | def test_invalidvalue(): 32 | incomplete = b"\xd9\x97#DL_" # raw8 - length=0x97 33 | with raises(ValueError): 34 | unpackb(incomplete) 35 | 36 | with raises(OutOfData): 37 | unpacker = Unpacker() 38 | unpacker.feed(incomplete) 39 | unpacker.unpack() 40 | 41 | with raises(FormatError): 42 | unpackb(b"\xc1") # (undefined tag) 43 | 44 | with raises(FormatError): 45 | unpackb(b"\x91\xc1") # fixarray(len=1) [ (undefined tag) ] 46 | 47 | with raises(StackError): 48 | unpackb(b"\x91" * 3000) # nested fixarray(len=1) 49 | 50 | 51 | def test_strict_map_key(): 52 | valid = {u"unicode": 1, b"bytes": 2} 53 | packed = packb(valid, use_bin_type=True) 54 | assert valid == unpackb(packed, raw=False, strict_map_key=True) 55 | 56 | invalid = {42: 1} 57 | packed = packb(invalid, use_bin_type=True) 58 | with raises(ValueError): 59 | unpackb(packed, raw=False, strict_map_key=True) 60 | -------------------------------------------------------------------------------- /srsly/tests/msgpack/test_extension.py: -------------------------------------------------------------------------------- 1 | import array 2 | from srsly import msgpack 3 | from srsly.msgpack.ext import ExtType 4 | 5 | 6 | def test_pack_ext_type(): 7 | def p(s): 8 | packer = msgpack.Packer() 9 | packer.pack_ext_type(0x42, s) 10 | return packer.bytes() 11 | 12 | assert p(b"A") == b"\xd4\x42A" # fixext 1 13 | assert p(b"AB") == b"\xd5\x42AB" # fixext 2 14 | assert p(b"ABCD") == 
b"\xd6\x42ABCD" # fixext 4 15 | assert p(b"ABCDEFGH") == b"\xd7\x42ABCDEFGH" # fixext 8 16 | assert p(b"A" * 16) == b"\xd8\x42" + b"A" * 16 # fixext 16 17 | assert p(b"ABC") == b"\xc7\x03\x42ABC" # ext 8 18 | assert p(b"A" * 0x0123) == b"\xc8\x01\x23\x42" + b"A" * 0x0123 # ext 16 19 | assert ( 20 | p(b"A" * 0x00012345) == b"\xc9\x00\x01\x23\x45\x42" + b"A" * 0x00012345 21 | ) # ext 32 22 | 23 | 24 | def test_unpack_ext_type(): 25 | def check(b, expected): 26 | assert msgpack.unpackb(b) == expected 27 | 28 | check(b"\xd4\x42A", ExtType(0x42, b"A")) # fixext 1 29 | check(b"\xd5\x42AB", ExtType(0x42, b"AB")) # fixext 2 30 | check(b"\xd6\x42ABCD", ExtType(0x42, b"ABCD")) # fixext 4 31 | check(b"\xd7\x42ABCDEFGH", ExtType(0x42, b"ABCDEFGH")) # fixext 8 32 | check(b"\xd8\x42" + b"A" * 16, ExtType(0x42, b"A" * 16)) # fixext 16 33 | check(b"\xc7\x03\x42ABC", ExtType(0x42, b"ABC")) # ext 8 34 | check(b"\xc8\x01\x23\x42" + b"A" * 0x0123, ExtType(0x42, b"A" * 0x0123)) # ext 16 35 | check( 36 | b"\xc9\x00\x01\x23\x45\x42" + b"A" * 0x00012345, 37 | ExtType(0x42, b"A" * 0x00012345), 38 | ) # ext 32 39 | 40 | 41 | def test_extension_type(): 42 | def default(obj): 43 | print("default called", obj) 44 | if isinstance(obj, array.array): 45 | typecode = 123 # application specific typecode 46 | data = obj.tobytes() 47 | return ExtType(typecode, data) 48 | raise TypeError("Unknown type object %r" % (obj,)) 49 | 50 | def ext_hook(code, data): 51 | print("ext_hook called", code, data) 52 | assert code == 123 53 | obj = array.array("d") 54 | obj.frombytes(data) 55 | return obj 56 | 57 | obj = [42, b"hello", array.array("d", [1.1, 2.2, 3.3])] 58 | s = msgpack.packb(obj, default=default) 59 | obj2 = msgpack.unpackb(s, ext_hook=ext_hook) 60 | assert obj == obj2 61 | 62 | 63 | def test_overriding_hooks(): 64 | def default(obj): 65 | if isinstance(obj, int): 66 | return {"__type__": "long", "__data__": str(obj)} 67 | else: 68 | return obj 69 | 70 | obj = {"testval": 
int(1823746192837461928374619)} 71 | refobj = {"testval": default(obj["testval"])} 72 | refout = msgpack.packb(refobj) 73 | assert isinstance(refout, (str, bytes)) 74 | testout = msgpack.packb(obj, default=default) 75 | 76 | assert refout == testout 77 | -------------------------------------------------------------------------------- /srsly/tests/msgpack/test_format.py: -------------------------------------------------------------------------------- 1 | from srsly.msgpack import unpackb 2 | 3 | 4 | def check(src, should, use_list=0): 5 | assert unpackb(src, use_list=use_list) == should 6 | 7 | 8 | def testSimpleValue(): 9 | check(b"\x93\xc0\xc2\xc3", (None, False, True)) 10 | 11 | 12 | def testFixnum(): 13 | check(b"\x92\x93\x00\x40\x7f\x93\xe0\xf0\xff", ((0, 64, 127), (-32, -16, -1))) 14 | 15 | 16 | def testFixArray(): 17 | check(b"\x92\x90\x91\x91\xc0", ((), ((None,),))) 18 | 19 | 20 | def testFixRaw(): 21 | check(b"\x94\xa0\xa1a\xa2bc\xa3def", (b"", b"a", b"bc", b"def")) 22 | 23 | 24 | def testFixMap(): 25 | check( 26 | b"\x82\xc2\x81\xc0\xc0\xc3\x81\xc0\x80", {False: {None: None}, True: {None: {}}} 27 | ) 28 | 29 | 30 | def testUnsignedInt(): 31 | check( 32 | b"\x99\xcc\x00\xcc\x80\xcc\xff\xcd\x00\x00\xcd\x80\x00" 33 | b"\xcd\xff\xff\xce\x00\x00\x00\x00\xce\x80\x00\x00\x00" 34 | b"\xce\xff\xff\xff\xff", 35 | (0, 128, 255, 0, 32768, 65535, 0, 2147483648, 4294967295), 36 | ) 37 | 38 | 39 | def testSignedInt(): 40 | check( 41 | b"\x99\xd0\x00\xd0\x80\xd0\xff\xd1\x00\x00\xd1\x80\x00" 42 | b"\xd1\xff\xff\xd2\x00\x00\x00\x00\xd2\x80\x00\x00\x00" 43 | b"\xd2\xff\xff\xff\xff", 44 | (0, -128, -1, 0, -32768, -1, 0, -2147483648, -1), 45 | ) 46 | 47 | 48 | def testRaw(): 49 | check( 50 | b"\x96\xda\x00\x00\xda\x00\x01a\xda\x00\x02ab\xdb\x00\x00" 51 | b"\x00\x00\xdb\x00\x00\x00\x01a\xdb\x00\x00\x00\x02ab", 52 | (b"", b"a", b"ab", b"", b"a", b"ab"), 53 | ) 54 | 55 | 56 | def testArray(): 57 | check( 58 | b"\x96\xdc\x00\x00\xdc\x00\x01\xc0\xdc\x00\x02\xc2\xc3\xdd\x00" 59 | 
b"\x00\x00\x00\xdd\x00\x00\x00\x01\xc0\xdd\x00\x00\x00\x02" 60 | b"\xc2\xc3", 61 | ((), (None,), (False, True), (), (None,), (False, True)), 62 | ) 63 | 64 | 65 | def testMap(): 66 | check( 67 | b"\x96" 68 | b"\xde\x00\x00" 69 | b"\xde\x00\x01\xc0\xc2" 70 | b"\xde\x00\x02\xc0\xc2\xc3\xc2" 71 | b"\xdf\x00\x00\x00\x00" 72 | b"\xdf\x00\x00\x00\x01\xc0\xc2" 73 | b"\xdf\x00\x00\x00\x02\xc0\xc2\xc3\xc2", 74 | ( 75 | {}, 76 | {None: False}, 77 | {True: False, None: False}, 78 | {}, 79 | {None: False}, 80 | {True: False, None: False}, 81 | ), 82 | ) 83 | -------------------------------------------------------------------------------- /srsly/tests/msgpack/test_limits.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from srsly.msgpack import packb, unpackb, Packer, Unpacker, ExtType 3 | from srsly.msgpack import PackOverflowError, PackValueError, UnpackValueError 4 | 5 | 6 | def test_integer(): 7 | x = -(2 ** 63) 8 | assert unpackb(packb(x)) == x 9 | with pytest.raises(PackOverflowError): 10 | packb(x - 1) 11 | 12 | x = 2 ** 64 - 1 13 | assert unpackb(packb(x)) == x 14 | with pytest.raises(PackOverflowError): 15 | packb(x + 1) 16 | 17 | 18 | def test_array_header(): 19 | packer = Packer() 20 | packer.pack_array_header(2 ** 32 - 1) 21 | with pytest.raises(PackValueError): 22 | packer.pack_array_header(2 ** 32) 23 | 24 | 25 | def test_map_header(): 26 | packer = Packer() 27 | packer.pack_map_header(2 ** 32 - 1) 28 | with pytest.raises(PackValueError): 29 | packer.pack_array_header(2 ** 32) 30 | 31 | 32 | def test_max_str_len(): 33 | d = "x" * 3 34 | packed = packb(d) 35 | 36 | unpacker = Unpacker(max_str_len=3, raw=False) 37 | unpacker.feed(packed) 38 | assert unpacker.unpack() == d 39 | 40 | unpacker = Unpacker(max_str_len=2, raw=False) 41 | with pytest.raises(UnpackValueError): 42 | unpacker.feed(packed) 43 | unpacker.unpack() 44 | 45 | 46 | def test_max_bin_len(): 47 | d = b"x" * 3 48 | packed = packb(d, 
use_bin_type=True) 49 | 50 | unpacker = Unpacker(max_bin_len=3) 51 | unpacker.feed(packed) 52 | assert unpacker.unpack() == d 53 | 54 | unpacker = Unpacker(max_bin_len=2) 55 | with pytest.raises(UnpackValueError): 56 | unpacker.feed(packed) 57 | unpacker.unpack() 58 | 59 | 60 | def test_max_array_len(): 61 | d = [1, 2, 3] 62 | packed = packb(d) 63 | 64 | unpacker = Unpacker(max_array_len=3) 65 | unpacker.feed(packed) 66 | assert unpacker.unpack() == d 67 | 68 | unpacker = Unpacker(max_array_len=2) 69 | with pytest.raises(UnpackValueError): 70 | unpacker.feed(packed) 71 | unpacker.unpack() 72 | 73 | 74 | def test_max_map_len(): 75 | d = {1: 2, 3: 4, 5: 6} 76 | packed = packb(d) 77 | 78 | unpacker = Unpacker(max_map_len=3) 79 | unpacker.feed(packed) 80 | assert unpacker.unpack() == d 81 | 82 | unpacker = Unpacker(max_map_len=2) 83 | with pytest.raises(UnpackValueError): 84 | unpacker.feed(packed) 85 | unpacker.unpack() 86 | 87 | 88 | def test_max_ext_len(): 89 | d = ExtType(42, b"abc") 90 | packed = packb(d) 91 | 92 | unpacker = Unpacker(max_ext_len=3) 93 | unpacker.feed(packed) 94 | assert unpacker.unpack() == d 95 | 96 | unpacker = Unpacker(max_ext_len=2) 97 | with pytest.raises(UnpackValueError): 98 | unpacker.feed(packed) 99 | unpacker.unpack() 100 | 101 | 102 | # PyPy fails following tests because of constant folding? 
# https://bugs.pypy.org/issue1721
# @pytest.mark.skipif(True, reason="Requires very large memory.")
# def test_binary():
#     x = b'x' * (2**32 - 1)
#     assert unpackb(packb(x)) == x
#     del x
#     x = b'x' * (2**32)
#     with pytest.raises(ValueError):
#         packb(x)
#
#
# @pytest.mark.skipif(True, reason="Requires very large memory.")
# def test_string():
#     x = 'x' * (2**32 - 1)
#     assert unpackb(packb(x)) == x
#     x += 'y'
#     with pytest.raises(ValueError):
#         packb(x)
#
#
# @pytest.mark.skipif(True, reason="Requires very large memory.")
# def test_array():
#     x = [0] * (2**32 - 1)
#     assert unpackb(packb(x)) == x
#     x.append(0)
#     with pytest.raises(ValueError):
#         packb(x)

# --------------------------------------------------------------------------------
# /srsly/tests/msgpack/test_memoryview.py:
# --------------------------------------------------------------------------------
from array import array
from srsly.msgpack import packb, unpackb


make_memoryview = memoryview


def make_array(f, data):
    """Build an ``array`` of typecode ``f`` from the raw bytes in ``data``."""
    result = array(f)
    result.frombytes(data)
    return result


def get_data(a):
    """Return the raw bytes of array ``a``."""
    raw = a.tobytes()
    return raw


def _runtest(format, nbytes, expected_header, expected_prefix, use_bin_type):
    """Pack a memoryview over an ``nbytes``-byte array and verify the encoding.

    The array has typecode ``format`` and is filled with the value 255. The
    packed output must start with ``expected_header``, followed by
    ``expected_prefix`` (possibly empty), followed by the array's raw bytes,
    and unpacking must reproduce the array.
    """
    # Source array: as many 255-valued items as fit into ``nbytes``.
    source = array(format)
    item_count = nbytes // source.itemsize
    source.fromlist([255] * item_count)
    source_bytes = get_data(source)

    # Round trip through msgpack starting from a memoryview of the array.
    packed = packb(make_memoryview(source), use_bin_type=use_bin_type)
    rebuilt = make_array(format, unpackb(packed))

    payload_start = 1 + len(expected_prefix)

    # right amount of data
    assert len(source_bytes) == nbytes
    # header byte
    assert packed[:1] == expected_header
    # length prefix, if any
    assert packed[1:payload_start] == expected_prefix
    # payload
    assert packed[payload_start:] == source_bytes
    # array unpacked correctly
    assert source == rebuilt


def test_fixstr_from_byte():
    """Byte arrays of 1 and 31 bytes pack with headers 0xa1 / 0xbf and no prefix."""
    for nbytes, header in ((1, b"\xa1"), (31, b"\xbf")):
        _runtest("B", nbytes, header, b"", False)


def test_fixstr_from_float():
    """Float arrays of 4 and 28 bytes pack with headers 0xa4 / 0xbc and no prefix."""
    for nbytes, header in ((4, b"\xa4"), (28, b"\xbc")):
        _runtest("f", nbytes, header, b"", False)


def test_str16_from_byte():
    """Byte arrays of 2**8 and 2**16 - 1 bytes pack with header 0xda."""
    for nbytes, prefix in ((2 ** 8, b"\x01\x00"), (2 ** 16 - 1, b"\xff\xff")):
        _runtest("B", nbytes, b"\xda", prefix, False)


def test_str16_from_float():
    """Float arrays of 2**8 and 2**16 - 4 bytes pack with header 0xda."""
    for nbytes, prefix in ((2 ** 8, b"\x01\x00"), (2 ** 16 - 4, b"\xff\xfc")):
        _runtest("f", nbytes, b"\xda", prefix, False)


def test_str32_from_byte():
    """A 2**16-byte byte array packs with header 0xdb."""
    _runtest("B", 2 ** 16, b"\xdb", b"\x00\x01\x00\x00", False)


def test_str32_from_float():
    """A 2**16-byte float array packs with header 0xdb."""
    _runtest("f", 2 ** 16, b"\xdb", b"\x00\x01\x00\x00", False)


def test_bin8_from_byte():
    """With ``use_bin_type``, byte arrays of 1 and 2**8 - 1 bytes use header 0xc4."""
    for nbytes, prefix in ((1, b"\x01"), (2 ** 8 - 1, b"\xff")):
        _runtest("B", nbytes, b"\xc4", prefix, True)


def test_bin8_from_float():
    """With ``use_bin_type``, float arrays of 4 and 2**8 - 4 bytes use header 0xc4."""
    for nbytes, prefix in ((4, b"\x04"), (2 ** 8 - 4, b"\xfc")):
        _runtest("f", nbytes, b"\xc4", prefix, True)


def test_bin16_from_byte():
    """With ``use_bin_type``, byte arrays of 2**8 and 2**16 - 1 bytes use header 0xc5."""
    for nbytes, prefix in ((2 ** 8, b"\x01\x00"), (2 ** 16 - 1, b"\xff\xff")):
        _runtest("B", nbytes, b"\xc5", prefix, True)


def test_bin16_from_float():
    """With ``use_bin_type``, float arrays of 2**8 and 2**16 - 4 bytes use header 0xc5."""
    for nbytes, prefix in ((2 ** 8, b"\x01\x00"), (2 ** 16 - 4, b"\xff\xfc")):
        _runtest("f", nbytes, b"\xc5", prefix, True)


def test_bin32_from_byte():
    """With ``use_bin_type``, a 2**16-byte byte array uses header 0xc6."""
    _runtest("B", 2 ** 16, b"\xc6", b"\x00\x01\x00\x00", True)


def test_bin32_from_float():
    """With ``use_bin_type``, a 2**16-byte float array uses header 0xc6."""
    _runtest("f", 2 ** 16, b"\xc6", b"\x00\x01\x00\x00", True)


# --------------------------------------------------------------------------------
# /srsly/tests/msgpack/test_newspec.py:
# --------------------------------------------------------------------------------
from srsly.msgpack import packb, unpackb, ExtType


def test_str8():
    """Strings of 32 and 255 chars pack with header 0xd9 and a one-byte length."""
    header = b"\xd9"
    data = b"x" * 32
    b = packb(data.decode(), use_bin_type=True)
    assert len(b) == len(data) + 2
    assert b[0:2] == header + b"\x20"
    assert b[2:] == data
    assert unpackb(b) == data

    data = b"x" * 255
    b = packb(data.decode(), use_bin_type=True)
    assert len(b) == len(data) + 2
    assert b[0:2] == header + b"\xff"
    assert b[2:] == data
    assert unpackb(b) == data


def test_bin8():
    """Bytes of length 0 and 255 pack with header 0xc4 and a one-byte length."""
    header = b"\xc4"
    data = b""
    b = packb(data, use_bin_type=True)
    assert len(b) == len(data) + 2
    assert b[0:2] == header + b"\x00"
    assert b[2:] == data
    assert unpackb(b) == data

    data = b"x" * 255
    b = packb(data, use_bin_type=True)
    assert len(b) == len(data) + 2
    assert b[0:2] == header + b"\xff"
    assert b[2:] == data
    assert unpackb(b) == data


def test_bin16():
    """Bytes of length 256 and 65535 pack with header 0xc5 and a two-byte length."""
    header = b"\xc5"
    data = b"x" * 256
    b = packb(data, use_bin_type=True)
    assert len(b) == len(data) + 3
    assert b[0:1] == header
    assert b[1:3] == b"\x01\x00"
    assert b[3:] == data
    assert unpackb(b) == data

    data = b"x" * 65535
    b = packb(data, use_bin_type=True)
    assert len(b) == len(data) + 3
    assert b[0:1] == header
    assert b[1:3] == b"\xff\xff"
    assert b[3:] == data
    assert unpackb(b) == data


def test_bin32():
    """Bytes of length 65536 pack with header 0xc6 and a four-byte length."""
    header = b"\xc6"
    data = b"x" * 65536
    b = packb(data, use_bin_type=True)
    assert len(b) == len(data) + 5
    assert b[0:1] == header
    assert b[1:5] == b"\x00\x01\x00\x00"
    assert b[5:] == data
    assert unpackb(b) == data


def test_ext():
    """Each ``ExtType`` payload size packs to, and unpacks from, the listed bytes."""

    def check(ext, packed):
        assert packb(ext) == packed
        assert unpackb(packed) == ext

    check(ExtType(0x42, b"Z"), b"\xd4\x42Z")  # fixext 1
    check(ExtType(0x42, b"ZZ"), b"\xd5\x42ZZ")  # fixext 2
    check(ExtType(0x42, b"Z" * 4), b"\xd6\x42" + b"Z" * 4)  # fixext 4
    check(ExtType(0x42, b"Z" * 8), b"\xd7\x42" + b"Z" * 8)  # fixext 8
    check(ExtType(0x42, b"Z" * 16), b"\xd8\x42" + b"Z" * 16)  # fixext 16
    # ext 8
    check(ExtType(0x42, b""), b"\xc7\x00\x42")
    check(ExtType(0x42, b"Z" * 255), b"\xc7\xff\x42" + b"Z" * 255)
    # ext 16
    check(ExtType(0x42, b"Z" * 256), b"\xc8\x01\x00\x42" + b"Z" * 256)
    check(ExtType(0x42, b"Z" * 0xFFFF), b"\xc8\xff\xff\x42" + b"Z" * 0xFFFF)
    # ext 32
    check(ExtType(0x42, b"Z" * 0x10000), b"\xc9\x00\x01\x00\x00\x42" + b"Z" * 0x10000)
    # needs large memory
    # check(ExtType(0x42, b'Z'*0xffffffff),
    #              b'\xc9\xff\xff\xff\xff\x42' + b'Z'*0xffffffff)


# --------------------------------------------------------------------------------
# /srsly/tests/msgpack/test_pack.py:
# --------------------------------------------------------------------------------
import struct
import pytest
from collections import OrderedDict
from io import BytesIO
from srsly.msgpack import packb, unpackb, Unpacker, Packer


def check(data, use_list=False):
    """Assert that ``data`` survives a ``packb`` / ``unpackb`` round trip."""
    unpacked = unpackb(packb(data), use_list=use_list)
    assert unpacked == data


def testPack():
    """Round-trip a table of ints, a float, bytes, singletons and containers."""
    test_data = [
        0,
        1,
        127,
        128,
        255,
        256,
        65535,
        65536,
        4294967295,
        4294967296,
        -1,
        -32,
        -33,
        -128,
        -129,
        -32768,
        -32769,
        -4294967296,
        -4294967297,
        1.0,
        b"",
        b"a",
        b"a" * 31,
        b"a" * 32,
        None,
        True,
        False,
        (),
        ((),),
        ((), None),
        {None: 0},
        (1 << 23),
    ]
    for td in test_data:
        check(td)


def testPackUnicode():
    """Round-trip str values through ``packb`` and through ``Packer`` / ``Unpacker``."""
    test_data = ["", "abcd", ["defgh"], "Русский текст"]
    for td in test_data:
        unpacked = unpackb(packb(td), use_list=1, raw=False)
        assert unpacked == td
        packer = Packer()
        data = packer.pack(td)
        unpacked = Unpacker(BytesIO(data), raw=False, use_list=1).unpack()
        assert unpacked == td


def testPackUTF32():  # deprecated
    """Round-trip str values using the ``encoding="utf-32"`` option."""
    unpacked = unpackb(packb("", encoding="utf-32"), use_list=1, encoding="utf-32")
    assert unpacked == ""
    unpacked = unpackb(packb("abcd", encoding="utf-32"), use_list=1, encoding="utf-32")
    assert unpacked == "abcd"
    unpacked = unpackb(
        packb(["defgh"], encoding="utf-32"), use_list=1, encoding="utf-32"
    )
    assert unpacked == ["defgh"]
    try:
        packb("Русский текст", encoding="utf-32")
    except LookupError as e:
        # A LookupError here is reported as an expected failure.
        pytest.xfail(str(e))


def testPackBytes():
    """Round-trip bytes values, including bytes nested in a tuple."""
    test_data = [b"", b"abcd", (b"defgh",)]
    for td in test_data:
        check(td)


def testPackByteArrays():
    """Round-trip bytearray values, including one nested in a tuple."""
    test_data = [bytearray(b""), bytearray(b"abcd"), (bytearray(b"defgh"),)]
    for td in test_data:
        check(td)


def testIgnoreUnicodeErrors():  # deprecated
    """With ``unicode_errors="ignore"`` the invalid byte is dropped on unpack."""
    unpacked = unpackb(
        packb(b"abc\xeddef"), encoding="utf-8", unicode_errors="ignore", use_list=1
    )
    assert unpacked == "abcdef"


def testStrictUnicodeUnpack():
    """Unpacking invalid UTF-8 with ``raw=False`` raises ``UnicodeDecodeError``."""
    with pytest.raises(UnicodeDecodeError):
        unpackb(packb(b"abc\xeddef"), raw=False, use_list=1)


def testStrictUnicodePack():  # deprecated
    """Packing non-ASCII text as ASCII in strict mode raises ``UnicodeEncodeError``."""
    with pytest.raises(UnicodeEncodeError):
        packb("abc\xeddef", encoding="ascii", unicode_errors="strict")


def testIgnoreErrorsPack():  # deprecated
    """With ``unicode_errors="ignore"`` unencodable characters are dropped on pack."""
    unpacked = unpackb(
        packb("abcФФФdef", encoding="ascii", unicode_errors="ignore"),
        raw=False,
        use_list=1,
    )
    assert unpacked == "abcdef"


def testDecodeBinary():
    """With ``encoding=None`` packed bytes unpack as bytes."""
    unpacked = unpackb(packb(b"abc"), encoding=None, use_list=1)
    assert unpacked == b"abc"


def testPackFloat():
    """``use_single_float`` selects the 0xca (float32) vs 0xcb (float64) encoding."""
    assert packb(1.0, use_single_float=True) == b"\xca" + struct.pack(">f", 1.0)
    assert packb(1.0, use_single_float=False) == b"\xcb" + struct.pack(">d", 1.0)


def testArraySize(sizes=(0, 5, 50, 1000)):
    """Arrays written as header + items must unpack as lists of the same size.

    ``sizes`` is an immutable default so it cannot leak state between calls.
    """
    bio = BytesIO()
    packer = Packer()
    for size in sizes:
        bio.write(packer.pack_array_header(size))
        for i in range(size):
            bio.write(packer.pack(i))

    bio.seek(0)
    unpacker = Unpacker(bio, use_list=1)
    for size in sizes:
        assert unpacker.unpack() == list(range(size))


def test_manualreset(sizes=(0, 5, 50, 1000)):
    """With ``autoreset=False`` the packer accumulates output until ``reset()``."""
    packer = Packer(autoreset=False)
    for size in sizes:
        packer.pack_array_header(size)
        for i in range(size):
            packer.pack(i)

    bio = BytesIO(packer.bytes())
    unpacker = Unpacker(bio, use_list=1)
    for size in sizes:
        assert unpacker.unpack() == list(range(size))

    packer.reset()
    assert packer.bytes() == b""


def testMapSize(sizes=(0, 5, 50, 1000)):
    """Maps written as header + key/value items must unpack as equal dicts."""
    bio = BytesIO()
    packer = Packer()
    for size in sizes:
        bio.write(packer.pack_map_header(size))
        for i in range(size):
            bio.write(packer.pack(i))  # key
            bio.write(packer.pack(i * 2))  # value

    bio.seek(0)
    unpacker = Unpacker(bio)
    for size in sizes:
        assert unpacker.unpack() == {i: i * 2 for i in range(size)}


def test_odict():
    """An ``OrderedDict`` unpacks to an equal dict, or to its pairs via a hook."""
    seq = [(b"one", 1), (b"two", 2), (b"three", 3), (b"four", 4)]
    od = OrderedDict(seq)
    assert unpackb(packb(od), use_list=1) == dict(seq)

    def pair_hook(pairs):
        return list(pairs)

    assert unpackb(packb(od), object_pairs_hook=pair_hook, use_list=1) == seq


def test_pairlist():
    """``pack_map_pairs`` output unpacks to the same pair list via ``object_pairs_hook``."""
    pairlist = [(b"a", 1), (2, b"b"), (b"foo", b"bar")]
    packer = Packer()
    packed = packer.pack_map_pairs(pairlist)
    unpacked = unpackb(packed, object_pairs_hook=list)
    assert pairlist == unpacked


def test_get_buffer():
    """``Packer.getbuffer()`` exposes the same bytes that ``packb`` produces."""
    packer = Packer(autoreset=0, use_bin_type=True)
    packer.pack([1, 2])
    strm = BytesIO()
    strm.write(packer.getbuffer())
    written = strm.getvalue()

    expected = packb([1, 2], use_bin_type=True)
    assert written == expected


# --------------------------------------------------------------------------------
# /srsly/tests/msgpack/test_read_size.py:
# --------------------------------------------------------------------------------
import pytest
from srsly.msgpack import packb, Unpacker, OutOfData


UnexpectedTypeException = ValueError


def test_read_array_header():
    """``read_array_header`` returns the length; items then unpack one by one."""
    unpacker = Unpacker()
    unpacker.feed(packb(["a", "b", "c"]))
    assert unpacker.read_array_header() == 3
    assert unpacker.unpack() == b"a"
    assert unpacker.unpack() == b"b"
    assert unpacker.unpack() == b"c"
    with pytest.raises(OutOfData):
        unpacker.unpack()


def test_read_map_header():
    """``read_map_header`` returns the size; key and value then unpack in turn."""
    unpacker = Unpacker()
    unpacker.feed(packb({"a": "A"}))
    assert unpacker.read_map_header() == 1
    assert unpacker.unpack() == b"a"
    assert unpacker.unpack() == b"A"
    with pytest.raises(OutOfData):
        unpacker.unpack()


def test_incorrect_type_array():
    """``read_array_header`` on a packed int raises ``ValueError``."""
    unpacker = Unpacker()
    unpacker.feed(packb(1))
    with pytest.raises(UnexpectedTypeException):
        unpacker.read_array_header()


def test_incorrect_type_map():
    """``read_map_header`` on a packed int raises ``ValueError``."""
    unpacker = Unpacker()
    unpacker.feed(packb(1))
    with pytest.raises(UnexpectedTypeException):
        unpacker.read_map_header()


def test_correct_type_nested_array():
    """``read_array_header`` on a map raises ``ValueError``.

    Despite the test name, the top-level object is a map (whose value is an
    array), so an exception is expected.
    """
    unpacker = Unpacker()
    unpacker.feed(packb({"a": ["b", "c", "d"]}))
    with pytest.raises(UnexpectedTypeException):
        unpacker.read_array_header()


def test_incorrect_type_nested_map():
    """``read_map_header`` on an array that contains a map raises ``ValueError``."""
    unpacker = Unpacker()
    unpacker.feed(packb([{"a": "b"}]))
    with pytest.raises(UnexpectedTypeException):
        unpacker.read_map_header()


# --------------------------------------------------------------------------------
# /srsly/tests/msgpack/test_seq.py:
# --------------------------------------------------------------------------------
import io
from srsly import msgpack


binarydata = bytes(bytearray(range(256)))


def gen_binary_data(idx):
    """Return the first ``idx % 300`` bytes of ``binarydata``."""
    return binarydata[: idx % 300]


def test_exceeding_unpacker_read_size():
    """Iterating an ``Unpacker`` with a small ``read_size`` yields every packed item."""
    dumpf = io.BytesIO()

    packer = msgpack.Packer()

    NUMBER_OF_STRINGS = 6
    read_size = 16
    # 5 ok for read_size=16, while 6 glibc detected *** python: double free or corruption (fasttop):
    # 20 ok for read_size=256, while 25 segfaults / glibc detected *** python: double free or corruption (!prev)
    # 40 ok for read_size=1024, while 50 introduces errors
    # 7000 ok for read_size=1024*1024, while 8000 leads to glibc detected *** python: double free or corruption (!prev):

    for idx in range(NUMBER_OF_STRINGS):
        data = gen_binary_data(idx)
        dumpf.write(packer.pack(data))

    f = io.BytesIO(dumpf.getvalue())
    dumpf.close()

    unpacker = msgpack.Unpacker(f, read_size=read_size, use_list=1)

    read_count = 0
    for idx, o in enumerate(unpacker):
        # Exact-type check on purpose: a bytes subclass would not be accepted.
        assert type(o) is bytes
        assert o == gen_binary_data(idx)
        read_count += 1

    assert read_count == NUMBER_OF_STRINGS


# --------------------------------------------------------------------------------
# /srsly/tests/msgpack/test_sequnpack.py:
# --------------------------------------------------------------------------------
import io
import pytest
from srsly.msgpack import Unpacker, BufferFull
from srsly.msgpack import pack
from srsly.msgpack.exceptions import OutOfData


def test_partialdata():
    """Iteration stops until the final byte of a 5-byte fixstr has been fed."""
    unpacker = Unpacker()
    # Header (0xa5 = fixstr of length 5) plus the first four payload bytes:
    # after each feed the object is still incomplete.
    for chunk in (b"\xa5", b"h", b"a", b"l", b"l"):
        unpacker.feed(chunk)
        with pytest.raises(StopIteration):
            next(iter(unpacker))
    unpacker.feed(b"o")
    assert next(iter(unpacker)) == b"hallo"


def test_foobar():
    """Each byte of ``b"foobar"`` unpacks as its ordinal, by ``unpack()`` and iteration."""
    unpacker = Unpacker(read_size=3, use_list=1)
    unpacker.feed(b"foobar")
    assert unpacker.unpack() == ord(b"f")
    assert unpacker.unpack() == ord(b"o")
    assert unpacker.unpack() == ord(b"o")
    assert unpacker.unpack() == ord(b"b")
    assert unpacker.unpack() == ord(b"a")
    assert unpacker.unpack() == ord(b"r")
    with pytest.raises(OutOfData):
        unpacker.unpack()

    unpacker.feed(b"foo")
    unpacker.feed(b"bar")

    k = 0
    for o, e in zip(unpacker, "foobarbaz"):
        assert o == ord(e)
        k += 1
    assert k == len(b"foobar")


def test_foobar_skip():
    """``skip()`` discards one object between ``unpack()`` calls."""
    unpacker = Unpacker(read_size=3, use_list=1)
    unpacker.feed(b"foobar")
    assert unpacker.unpack() == ord(b"f")
    unpacker.skip()
    assert unpacker.unpack() == ord(b"o")
    unpacker.skip()
    assert unpacker.unpack() == ord(b"a")
    unpacker.skip()
    with pytest.raises(OutOfData):
        unpacker.unpack()


def test_maxbuffersize():
    """``max_buffer_size`` bounds ``read_size`` and the data accepted by ``feed``."""
    with pytest.raises(ValueError):
        Unpacker(read_size=5, max_buffer_size=3)
    unpacker = Unpacker(read_size=3, max_buffer_size=3, use_list=1)
    unpacker.feed(b"fo")
    with pytest.raises(BufferFull):
        unpacker.feed(b"ob")
    unpacker.feed(b"o")
    assert ord("f") == next(unpacker)
    unpacker.feed(b"b")
    assert ord("o") == next(unpacker)
    assert ord("o") == next(unpacker)
    assert ord("b") == next(unpacker)


def test_readbytes():
    """``read_bytes`` returns raw bytes between ``unpack()`` calls, fed or streamed."""
    unpacker = Unpacker(read_size=3)
    unpacker.feed(b"foobar")
    assert unpacker.unpack() == ord(b"f")
    assert unpacker.read_bytes(3) == b"oob"
    assert unpacker.unpack() == ord(b"a")
    assert unpacker.unpack() == ord(b"r")

    # Test buffer refill
    unpacker = Unpacker(io.BytesIO(b"foobar"), read_size=3)
    assert unpacker.unpack() == ord(b"f")
    assert unpacker.read_bytes(3) == b"oob"
    assert unpacker.unpack() == ord(b"a")
    assert unpacker.unpack() == ord(b"r")


def test_issue124():
    """Iteration can resume after exhaustion once more data has been fed."""
    unpacker = Unpacker()
    unpacker.feed(b"\xa1?\xa1!")
    assert tuple(unpacker) == (b"?", b"!")
    assert tuple(unpacker) == ()
    unpacker.feed(b"\xa1?\xa1")
    assert tuple(unpacker) == (b"?",)
    assert tuple(unpacker) == ()
    unpacker.feed(b"!")
    assert tuple(unpacker) == (b"!",)
    assert tuple(unpacker) == ()


def test_unpack_tell():
    """``Unpacker.tell()`` matches the stream offset recorded after each ``pack``."""
    stream = io.BytesIO()
    messages = [2 ** i - 1 for i in range(65)]
    messages += [-(2 ** i) for i in range(1, 64)]
    messages += [
        b"hello",
        b"hello" * 1000,
        list(range(20)),
        {i: bytes(i) * i for i in range(10)},
        {i: bytes(i) * i for i in range(32)},
    ]
    offsets = []
    for m in messages:
        pack(m, stream)
        offsets.append(stream.tell())
    stream.seek(0)
    unpacker = Unpacker(stream)
    for m, o in zip(messages, offsets):
        m2 = next(unpacker)
        assert m == m2
        assert o == unpacker.tell()


# --------------------------------------------------------------------------------
# /srsly/tests/msgpack/test_stricttype.py:
# --------------------------------------------------------------------------------
from collections import namedtuple
from srsly.msgpack import packb, unpackb, ExtType


def test_namedtuple():
    """With ``strict_types`` a namedtuple is routed through ``default``."""
    T = namedtuple("T", "foo bar")

    def default(o):
        if isinstance(o, T):
            return dict(o._asdict())
        raise TypeError("Unsupported type %s" % (type(o),))

    packed = packb(T(1, 42), strict_types=True, use_bin_type=True, default=default)
    unpacked = unpackb(packed, raw=False)
    assert unpacked == {"foo": 1, "bar": 42}


def test_tuple():
    """With ``strict_types`` tuples round-trip via a tagged-dict ``default`` / hook."""
    t = ("one", 2, b"three", (4,))

    def default(o):
        if isinstance(o, tuple):
            return {"__type__": "tuple", "value": list(o)}
        raise TypeError("Unsupported type %s" % (type(o),))

    def convert(o):
        if o.get("__type__") == "tuple":
            return tuple(o["value"])
        return o

    data = packb(t, strict_types=True, use_bin_type=True, default=default)
    unpacked = unpackb(data, raw=False, object_hook=convert)

    assert unpacked == t


def test_tuple_ext():
    """With ``strict_types`` tuples round-trip via an ``ExtType`` ``default`` / hook."""
    t = ("one", 2, b"three", (4,))

    MSGPACK_EXT_TYPE_TUPLE = 0

    def default(o):
        if isinstance(o, tuple):
            # Convert to list and pack
            payload = packb(
                list(o), strict_types=True, use_bin_type=True, default=default
            )
            return ExtType(MSGPACK_EXT_TYPE_TUPLE, payload)
        raise TypeError(repr(o))

    def convert(code, payload):
        if code == MSGPACK_EXT_TYPE_TUPLE:
            # Unpack and convert to tuple
            return tuple(unpackb(payload, raw=False, ext_hook=convert))
        raise ValueError("Unknown Ext code {}".format(code))

    data = packb(t, strict_types=True, use_bin_type=True, default=default)
    unpacked = unpackb(data, raw=False, ext_hook=convert)

    assert unpacked == t


# --------------------------------------------------------------------------------
# /srsly/tests/msgpack/test_subtype.py:
# --------------------------------------------------------------------------------
from collections import namedtuple
from srsly.msgpack import packb 3 | 4 | 5 | class MyList(list): 6 | pass 7 | 8 | 9 | class MyDict(dict): 10 | pass 11 | 12 | 13 | class MyTuple(tuple): 14 | pass 15 | 16 | 17 | MyNamedTuple = namedtuple("MyNamedTuple", "x y") 18 | 19 | 20 | def test_types(): 21 | assert packb(MyDict()) == packb(dict()) 22 | assert packb(MyList()) == packb(list()) 23 | assert packb(MyNamedTuple(1, 2)) == packb((1, 2)) 24 | -------------------------------------------------------------------------------- /srsly/tests/msgpack/test_unpack.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | import sys 3 | import pytest 4 | from srsly.msgpack import Unpacker, packb, OutOfData, ExtType 5 | 6 | 7 | def test_unpack_array_header_from_file(): 8 | f = BytesIO(packb([1, 2, 3, 4])) 9 | unpacker = Unpacker(f) 10 | assert unpacker.read_array_header() == 4 11 | assert unpacker.unpack() == 1 12 | assert unpacker.unpack() == 2 13 | assert unpacker.unpack() == 3 14 | assert unpacker.unpack() == 4 15 | with pytest.raises(OutOfData): 16 | unpacker.unpack() 17 | 18 | 19 | @pytest.mark.skipif( 20 | "not hasattr(sys, 'getrefcount') == True", 21 | reason="sys.getrefcount() is needed to pass this test", 22 | ) 23 | def test_unpacker_hook_refcnt(): 24 | result = [] 25 | 26 | def hook(x): 27 | result.append(x) 28 | return x 29 | 30 | basecnt = sys.getrefcount(hook) 31 | 32 | up = Unpacker(object_hook=hook, list_hook=hook) 33 | 34 | assert sys.getrefcount(hook) >= basecnt + 2 35 | 36 | up.feed(packb([{}])) 37 | up.feed(packb([{}])) 38 | assert up.unpack() == [{}] 39 | assert up.unpack() == [{}] 40 | assert result == [{}, [{}], {}, [{}]] 41 | 42 | del up 43 | 44 | assert sys.getrefcount(hook) == basecnt 45 | 46 | 47 | def test_unpacker_ext_hook(): 48 | class MyUnpacker(Unpacker): 49 | def __init__(self): 50 | super(MyUnpacker, self).__init__(ext_hook=self._hook, raw=False) 51 | 52 | def _hook(self, code, data): 53 | if code == 1: 54 | return 
int(data) 55 | else: 56 | return ExtType(code, data) 57 | 58 | unpacker = MyUnpacker() 59 | unpacker.feed(packb({"a": 1})) 60 | assert unpacker.unpack() == {"a": 1} 61 | unpacker.feed(packb({"a": ExtType(1, b"123")})) 62 | assert unpacker.unpack() == {"a": 123} 63 | unpacker.feed(packb({"a": ExtType(2, b"321")})) 64 | assert unpacker.unpack() == {"a": ExtType(2, b"321")} 65 | 66 | 67 | if __name__ == "__main__": 68 | test_unpack_array_header_from_file() 69 | test_unpacker_hook_refcnt() 70 | test_unpacker_ext_hook() 71 | -------------------------------------------------------------------------------- /srsly/tests/ruamel_yaml/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explosion/srsly/6c0fe81e2cd30093aa3bb12b9f4df67e4da1ff59/srsly/tests/ruamel_yaml/__init__.py -------------------------------------------------------------------------------- /srsly/tests/ruamel_yaml/test_a_dedent.py: -------------------------------------------------------------------------------- 1 | from .roundtrip import dedent 2 | 3 | 4 | class TestDedent: 5 | def test_start_newline(self): 6 | # fmt: off 7 | x = dedent(""" 8 | 123 9 | 456 10 | """) 11 | # fmt: on 12 | assert x == "123\n 456\n" 13 | 14 | def test_start_space_newline(self): 15 | # special construct to prevent stripping of following whitespace 16 | # fmt: off 17 | x = dedent(" " """ 18 | 123 19 | """) 20 | # fmt: on 21 | assert x == "123\n" 22 | 23 | def test_start_no_newline(self): 24 | # special construct to prevent stripping of following whitespace 25 | x = dedent( 26 | """\ 27 | 123 28 | 456 29 | """ 30 | ) 31 | assert x == "123\n 456\n" 32 | 33 | def test_preserve_no_newline_at_end(self): 34 | x = dedent( 35 | """ 36 | 123""" 37 | ) 38 | assert x == "123" 39 | 40 | def test_preserve_no_newline_at_all(self): 41 | x = dedent( 42 | """\ 43 | 123""" 44 | ) 45 | assert x == "123" 46 | 47 | def test_multiple_dedent(self): 48 | x = dedent( 49 | dedent( 50
| """ 51 | 123 52 | """ 53 | ) 54 | ) 55 | assert x == "123\n" 56 | -------------------------------------------------------------------------------- /srsly/tests/ruamel_yaml/test_add_xxx.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import re 4 | import pytest # NOQA 5 | 6 | from .roundtrip import dedent 7 | 8 | 9 | # from PyYAML docs 10 | class Dice(tuple): 11 | def __new__(cls, a, b): 12 | return tuple.__new__(cls, [a, b]) 13 | 14 | def __repr__(self): 15 | return "Dice(%s,%s)" % self 16 | 17 | 18 | def dice_constructor(loader, node): 19 | value = loader.construct_scalar(node) 20 | a, b = map(int, value.split("d")) 21 | return Dice(a, b) 22 | 23 | 24 | def dice_representer(dumper, data): 25 | return dumper.represent_scalar(u"!dice", u"{}d{}".format(*data)) 26 | 27 | 28 | def test_dice_constructor(): 29 | import srsly.ruamel_yaml # NOQA 30 | 31 | srsly.ruamel_yaml.add_constructor(u"!dice", dice_constructor) 32 | with pytest.raises(ValueError): 33 | data = srsly.ruamel_yaml.load( 34 | "initial hit points: !dice 8d4", Loader=srsly.ruamel_yaml.Loader 35 | ) 36 | assert str(data) == "{'initial hit points': Dice(8,4)}" 37 | 38 | 39 | def test_dice_constructor_with_loader(): 40 | import srsly.ruamel_yaml # NOQA 41 | 42 | with pytest.raises(ValueError): 43 | srsly.ruamel_yaml.add_constructor( 44 | u"!dice", dice_constructor, Loader=srsly.ruamel_yaml.Loader 45 | ) 46 | data = srsly.ruamel_yaml.load( 47 | "initial hit points: !dice 8d4", Loader=srsly.ruamel_yaml.Loader 48 | ) 49 | assert str(data) == "{'initial hit points': Dice(8,4)}" 50 | 51 | 52 | def test_dice_representer(): 53 | import srsly.ruamel_yaml # NOQA 54 | 55 | srsly.ruamel_yaml.add_representer(Dice, dice_representer) 56 | # srsly.ruamel_yaml 0.15.8+ no longer forces quotes tagged scalars 57 | assert ( 58 | srsly.ruamel_yaml.dump(dict(gold=Dice(10, 6)), default_flow_style=False) 59 | == "gold: !dice 10d6\n" 60 | ) 61 | 62 | 63 | def 
test_dice_implicit_resolver(): 64 | import srsly.ruamel_yaml # NOQA 65 | 66 | pattern = re.compile(r"^\d+d\d+$") 67 | with pytest.raises(ValueError): 68 | srsly.ruamel_yaml.add_implicit_resolver(u"!dice", pattern) 69 | assert ( 70 | srsly.ruamel_yaml.dump(dict(treasure=Dice(10, 20)), default_flow_style=False) 71 | == "treasure: 10d20\n" 72 | ) 73 | assert srsly.ruamel_yaml.load( 74 | "damage: 5d10", Loader=srsly.ruamel_yaml.Loader 75 | ) == dict(damage=Dice(5, 10)) 76 | 77 | 78 | class Obj1(dict): 79 | def __init__(self, suffix): 80 | self._suffix = suffix 81 | self._node = None 82 | 83 | def add_node(self, n): 84 | self._node = n 85 | 86 | def __repr__(self): 87 | return "Obj1(%s->%s)" % (self._suffix, self.items()) 88 | 89 | def dump(self): 90 | return repr(self._node) 91 | 92 | 93 | class YAMLObj1(object): 94 | yaml_tag = u"!obj:" 95 | 96 | @classmethod 97 | def from_yaml(cls, loader, suffix, node): 98 | import srsly.ruamel_yaml # NOQA 99 | 100 | obj1 = Obj1(suffix) 101 | if isinstance(node, srsly.ruamel_yaml.MappingNode): 102 | obj1.add_node(loader.construct_mapping(node)) 103 | else: 104 | raise NotImplementedError 105 | return obj1 106 | 107 | @classmethod 108 | def to_yaml(cls, dumper, data): 109 | return dumper.represent_scalar(cls.yaml_tag + data._suffix, data.dump()) 110 | 111 | 112 | def test_yaml_obj(): 113 | import srsly.ruamel_yaml # NOQA 114 | 115 | srsly.ruamel_yaml.add_representer(Obj1, YAMLObj1.to_yaml) 116 | srsly.ruamel_yaml.add_multi_constructor(YAMLObj1.yaml_tag, YAMLObj1.from_yaml) 117 | with pytest.raises(ValueError): 118 | x = srsly.ruamel_yaml.load("!obj:x.2\na: 1", Loader=srsly.ruamel_yaml.Loader) 119 | print(x) 120 | assert srsly.ruamel_yaml.dump(x) == """!obj:x.2 "{'a': 1}"\n""" 121 | 122 | 123 | def test_yaml_obj_with_loader_and_dumper(): 124 | import srsly.ruamel_yaml # NOQA 125 | 126 | srsly.ruamel_yaml.add_representer( 127 | Obj1, YAMLObj1.to_yaml, Dumper=srsly.ruamel_yaml.Dumper 128 | ) 129 | 
srsly.ruamel_yaml.add_multi_constructor( 130 | YAMLObj1.yaml_tag, YAMLObj1.from_yaml, Loader=srsly.ruamel_yaml.Loader 131 | ) 132 | with pytest.raises(ValueError): 133 | x = srsly.ruamel_yaml.load("!obj:x.2\na: 1", Loader=srsly.ruamel_yaml.Loader) 134 | # x = srsly.ruamel_yaml.load('!obj:x.2\na: 1') 135 | print(x) 136 | assert srsly.ruamel_yaml.dump(x) == """!obj:x.2 "{'a': 1}"\n""" 137 | 138 | 139 | # ToDo use nullege to search add_multi_representer and add_path_resolver 140 | # and add some test code 141 | 142 | # Issue 127 reported by Tommy Wang 143 | 144 | 145 | def test_issue_127(): 146 | import srsly.ruamel_yaml # NOQA 147 | 148 | class Ref(srsly.ruamel_yaml.YAMLObject): 149 | yaml_constructor = srsly.ruamel_yaml.RoundTripConstructor 150 | yaml_representer = srsly.ruamel_yaml.RoundTripRepresenter 151 | yaml_tag = u"!Ref" 152 | 153 | def __init__(self, logical_id): 154 | self.logical_id = logical_id 155 | 156 | @classmethod 157 | def from_yaml(cls, loader, node): 158 | return cls(loader.construct_scalar(node)) 159 | 160 | @classmethod 161 | def to_yaml(cls, dumper, data): 162 | if isinstance(data.logical_id, srsly.ruamel_yaml.scalarstring.ScalarString): 163 | style = data.logical_id.style # srsly.ruamel_yaml>0.15.8 164 | else: 165 | style = None 166 | return dumper.represent_scalar(cls.yaml_tag, data.logical_id, style=style) 167 | 168 | document = dedent( 169 | """\ 170 | AList: 171 | - !Ref One 172 | - !Ref 'Two' 173 | - !Ref 174 | Two and a half 175 | BList: [!Ref Three, !Ref "Four"] 176 | CList: 177 | - Five Six 178 | - 'Seven Eight' 179 | """ 180 | ) 181 | data = srsly.ruamel_yaml.round_trip_load(document, preserve_quotes=True) 182 | assert srsly.ruamel_yaml.round_trip_dump( 183 | data, indent=4, block_seq_indent=2 184 | ) == document.replace("\n Two and", " Two and") 185 | -------------------------------------------------------------------------------- /srsly/tests/ruamel_yaml/test_class_register.py: 
-------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | testing of YAML.register_class and @yaml_object 5 | """ 6 | 7 | from .roundtrip import YAML 8 | 9 | 10 | class User0(object): 11 | def __init__(self, name, age): 12 | self.name = name 13 | self.age = age 14 | 15 | 16 | class User1(object): 17 | yaml_tag = u"!user" 18 | 19 | def __init__(self, name, age): 20 | self.name = name 21 | self.age = age 22 | 23 | @classmethod 24 | def to_yaml(cls, representer, node): 25 | return representer.represent_scalar( 26 | cls.yaml_tag, u"{.name}-{.age}".format(node, node) 27 | ) 28 | 29 | @classmethod 30 | def from_yaml(cls, constructor, node): 31 | return cls(*node.value.split("-")) 32 | 33 | 34 | class TestRegisterClass(object): 35 | def test_register_0_rt(self): 36 | yaml = YAML() 37 | yaml.register_class(User0) 38 | ys = """ 39 | - !User0 40 | name: Anthon 41 | age: 18 42 | """ 43 | d = yaml.load(ys) 44 | yaml.dump(d, compare=ys, unordered_lines=True) 45 | 46 | def test_register_0_safe(self): 47 | # default_flow_style = None 48 | yaml = YAML(typ="safe") 49 | yaml.register_class(User0) 50 | ys = """ 51 | - !User0 {age: 18, name: Anthon} 52 | """ 53 | d = yaml.load(ys) 54 | yaml.dump(d, compare=ys) 55 | 56 | def test_register_0_unsafe(self): 57 | # default_flow_style = None 58 | yaml = YAML(typ="unsafe") 59 | yaml.register_class(User0) 60 | ys = """ 61 | - !User0 {age: 18, name: Anthon} 62 | """ 63 | d = yaml.load(ys) 64 | yaml.dump(d, compare=ys) 65 | 66 | def test_register_1_rt(self): 67 | yaml = YAML() 68 | yaml.register_class(User1) 69 | ys = """ 70 | - !user Anthon-18 71 | """ 72 | d = yaml.load(ys) 73 | yaml.dump(d, compare=ys) 74 | 75 | def test_register_1_safe(self): 76 | yaml = YAML(typ="safe") 77 | yaml.register_class(User1) 78 | ys = """ 79 | [!user Anthon-18] 80 | """ 81 | d = yaml.load(ys) 82 | yaml.dump(d, compare=ys) 83 | 84 | def test_register_1_unsafe(self): 85 | yaml = YAML(typ="unsafe") 86 | 
yaml.register_class(User1) 87 | ys = """ 88 | [!user Anthon-18] 89 | """ 90 | d = yaml.load(ys) 91 | yaml.dump(d, compare=ys) 92 | 93 | 94 | class TestDecorator(object): 95 | def test_decorator_implicit(self): 96 | from srsly.ruamel_yaml import yaml_object 97 | 98 | yml = YAML() 99 | 100 | @yaml_object(yml) 101 | class User2(object): 102 | def __init__(self, name, age): 103 | self.name = name 104 | self.age = age 105 | 106 | ys = """ 107 | - !User2 108 | name: Anthon 109 | age: 18 110 | """ 111 | d = yml.load(ys) 112 | yml.dump(d, compare=ys, unordered_lines=True) 113 | 114 | def test_decorator_explicit(self): 115 | from srsly.ruamel_yaml import yaml_object 116 | 117 | yml = YAML() 118 | 119 | @yaml_object(yml) 120 | class User3(object): 121 | yaml_tag = u"!USER" 122 | 123 | def __init__(self, name, age): 124 | self.name = name 125 | self.age = age 126 | 127 | @classmethod 128 | def to_yaml(cls, representer, node): 129 | return representer.represent_scalar( 130 | cls.yaml_tag, u"{.name}-{.age}".format(node, node) 131 | ) 132 | 133 | @classmethod 134 | def from_yaml(cls, constructor, node): 135 | return cls(*node.value.split("-")) 136 | 137 | ys = """ 138 | - !USER Anthon-18 139 | """ 140 | d = yml.load(ys) 141 | yml.dump(d, compare=ys) 142 | -------------------------------------------------------------------------------- /srsly/tests/ruamel_yaml/test_collections.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | collections.OrderedDict is a new class not supported by PyYAML (issue 83 by Frazer McLean) 5 | 6 | This is now so integrated in Python that it can be mapped to !!omap 7 | 8 | """ 9 | 10 | import pytest # NOQA 11 | 12 | 13 | from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA 14 | 15 | 16 | class TestOrderedDict: 17 | def test_ordereddict(self): 18 | from collections import OrderedDict 19 | import srsly.ruamel_yaml # NOQA 20 | 21 | assert 
srsly.ruamel_yaml.dump(OrderedDict()) == "!!omap []\n" 22 | -------------------------------------------------------------------------------- /srsly/tests/ruamel_yaml/test_contextmanager.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import print_function 4 | 5 | """ 6 | testing of anchors and the aliases referring to them 7 | """ 8 | 9 | import sys 10 | import pytest 11 | 12 | 13 | single_doc = """\ 14 | - a: 1 15 | - b: 16 | - 2 17 | - 3 18 | """ 19 | 20 | single_data = [dict(a=1), dict(b=[2, 3])] 21 | 22 | multi_doc = """\ 23 | --- 24 | - abc 25 | - xyz 26 | --- 27 | - a: 1 28 | - b: 29 | - 2 30 | - 3 31 | """ 32 | 33 | multi_doc_data = [["abc", "xyz"], single_data] 34 | 35 | 36 | def get_yaml(): 37 | from srsly.ruamel_yaml import YAML 38 | 39 | return YAML() 40 | 41 | 42 | class TestOldStyle: 43 | def test_single_load(self): 44 | d = get_yaml().load(single_doc) 45 | print(d) 46 | print(type(d[0])) 47 | assert d == single_data 48 | 49 | def test_single_load_no_arg(self): 50 | with pytest.raises(TypeError): 51 | assert get_yaml().load() == single_data 52 | 53 | def test_multi_load(self): 54 | data = list(get_yaml().load_all(multi_doc)) 55 | assert data == multi_doc_data 56 | 57 | def test_single_dump(self, capsys): 58 | get_yaml().dump(single_data, sys.stdout) 59 | out, err = capsys.readouterr() 60 | assert out == single_doc 61 | 62 | def test_multi_dump(self, capsys): 63 | yaml = get_yaml() 64 | yaml.explicit_start = True 65 | yaml.dump_all(multi_doc_data, sys.stdout) 66 | out, err = capsys.readouterr() 67 | assert out == multi_doc 68 | 69 | 70 | class TestContextManager: 71 | def test_single_dump(self, capsys): 72 | from srsly.ruamel_yaml import YAML 73 | 74 | with YAML(output=sys.stdout) as yaml: 75 | yaml.dump(single_data) 76 | out, err = capsys.readouterr() 77 | print(err) 78 | assert out == single_doc 79 | 80 | def test_multi_dump(self, capsys): 81 | from srsly.ruamel_yaml 
import YAML 82 | 83 | with YAML(output=sys.stdout) as yaml: 84 | yaml.explicit_start = True 85 | yaml.dump(multi_doc_data[0]) 86 | yaml.dump(multi_doc_data[1]) 87 | 88 | out, err = capsys.readouterr() 89 | print(err) 90 | assert out == multi_doc 91 | 92 | # input is not as simple with a context manager 93 | # you need to indicate what you expect hence load and load_all 94 | 95 | # @pytest.mark.xfail(strict=True) 96 | # def test_single_load(self): 97 | # from srsly.ruamel_yaml import YAML 98 | # with YAML(input=single_doc) as yaml: 99 | # assert yaml.load() == single_data 100 | # 101 | # @pytest.mark.xfail(strict=True) 102 | # def test_multi_load(self): 103 | # from srsly.ruamel_yaml import YAML 104 | # with YAML(input=multi_doc) as yaml: 105 | # for idx, data in enumerate(yaml.load()): 106 | # assert data == multi_doc_data[0] 107 | 108 | def test_roundtrip(self, capsys): 109 | from srsly.ruamel_yaml import YAML 110 | 111 | with YAML(output=sys.stdout) as yaml: 112 | yaml.explicit_start = True 113 | for data in yaml.load_all(multi_doc): 114 | yaml.dump(data) 115 | 116 | out, err = capsys.readouterr() 117 | print(err) 118 | assert out == multi_doc 119 | -------------------------------------------------------------------------------- /srsly/tests/ruamel_yaml/test_copy.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Testing copy and deepcopy, instigated by Issue 84 (Peter Amstutz) 5 | """ 6 | 7 | import copy 8 | 9 | import pytest # NOQA 10 | 11 | from .roundtrip import dedent, round_trip_load, round_trip_dump 12 | 13 | 14 | class TestDeepCopy: 15 | def test_preserve_flow_style_simple(self): 16 | x = dedent( 17 | """\ 18 | {foo: bar, baz: quux} 19 | """ 20 | ) 21 | data = round_trip_load(x) 22 | data_copy = copy.deepcopy(data) 23 | y = round_trip_dump(data_copy) 24 | print("x [{}]".format(x)) 25 | print("y [{}]".format(y)) 26 | assert y == x 27 | assert data.fa.flow_style() == 
data_copy.fa.flow_style() 28 | 29 | def test_deepcopy_flow_style_nested_dict(self): 30 | x = dedent( 31 | """\ 32 | a: {foo: bar, baz: quux} 33 | """ 34 | ) 35 | data = round_trip_load(x) 36 | assert data["a"].fa.flow_style() is True 37 | data_copy = copy.deepcopy(data) 38 | assert data_copy["a"].fa.flow_style() is True 39 | data_copy["a"].fa.set_block_style() 40 | assert data["a"].fa.flow_style() != data_copy["a"].fa.flow_style() 41 | assert data["a"].fa._flow_style is True 42 | assert data_copy["a"].fa._flow_style is False 43 | y = round_trip_dump(data_copy) 44 | 45 | print("x [{}]".format(x)) 46 | print("y [{}]".format(y)) 47 | assert y == dedent( 48 | """\ 49 | a: 50 | foo: bar 51 | baz: quux 52 | """ 53 | ) 54 | 55 | def test_deepcopy_flow_style_nested_list(self): 56 | x = dedent( 57 | """\ 58 | a: [1, 2, 3] 59 | """ 60 | ) 61 | data = round_trip_load(x) 62 | assert data["a"].fa.flow_style() is True 63 | data_copy = copy.deepcopy(data) 64 | assert data_copy["a"].fa.flow_style() is True 65 | data_copy["a"].fa.set_block_style() 66 | assert data["a"].fa.flow_style() != data_copy["a"].fa.flow_style() 67 | assert data["a"].fa._flow_style is True 68 | assert data_copy["a"].fa._flow_style is False 69 | y = round_trip_dump(data_copy) 70 | 71 | print("x [{}]".format(x)) 72 | print("y [{}]".format(y)) 73 | assert y == dedent( 74 | """\ 75 | a: 76 | - 1 77 | - 2 78 | - 3 79 | """ 80 | ) 81 | 82 | 83 | class TestCopy: 84 | def test_copy_flow_style_nested_dict(self): 85 | x = dedent( 86 | """\ 87 | a: {foo: bar, baz: quux} 88 | """ 89 | ) 90 | data = round_trip_load(x) 91 | assert data["a"].fa.flow_style() is True 92 | data_copy = copy.copy(data) 93 | assert data_copy["a"].fa.flow_style() is True 94 | data_copy["a"].fa.set_block_style() 95 | assert data["a"].fa.flow_style() == data_copy["a"].fa.flow_style() 96 | assert data["a"].fa._flow_style is False 97 | assert data_copy["a"].fa._flow_style is False 98 | y = round_trip_dump(data_copy) 99 | z = round_trip_dump(data) 100 
| assert y == z 101 | 102 | assert y == dedent( 103 | """\ 104 | a: 105 | foo: bar 106 | baz: quux 107 | """ 108 | ) 109 | 110 | def test_copy_flow_style_nested_list(self): 111 | x = dedent( 112 | """\ 113 | a: [1, 2, 3] 114 | """ 115 | ) 116 | data = round_trip_load(x) 117 | assert data["a"].fa.flow_style() is True 118 | data_copy = copy.copy(data) 119 | assert data_copy["a"].fa.flow_style() is True 120 | data_copy["a"].fa.set_block_style() 121 | assert data["a"].fa.flow_style() == data_copy["a"].fa.flow_style() 122 | assert data["a"].fa._flow_style is False 123 | assert data_copy["a"].fa._flow_style is False 124 | y = round_trip_dump(data_copy) 125 | 126 | print("x [{}]".format(x)) 127 | print("y [{}]".format(y)) 128 | assert y == dedent( 129 | """\ 130 | a: 131 | - 1 132 | - 2 133 | - 3 134 | """ 135 | ) 136 | -------------------------------------------------------------------------------- /srsly/tests/ruamel_yaml/test_datetime.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | http://yaml.org/type/timestamp.html specifies the regexp to use 5 | for datetime.date and datetime.datetime construction. Date is simple 6 | but datetime can have 'T' or 't' as well as 'Z' or a timezone offset (in 7 | hours and minutes). 
This information was originally used to create 8 | a UTC datetime and then discarded 9 | 10 | examples from the above: 11 | 12 | canonical: 2001-12-15T02:59:43.1Z 13 | valid iso8601: 2001-12-14t21:59:43.10-05:00 14 | space separated: 2001-12-14 21:59:43.10 -5 15 | no time zone (Z): 2001-12-15 2:59:43.10 16 | date (00:00:00Z): 2002-12-14 17 | 18 | Please note that a fraction can only be included if not equal to 0 19 | 20 | """ 21 | 22 | import copy 23 | import pytest # NOQA 24 | 25 | from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA 26 | 27 | 28 | class TestDateTime: 29 | def test_date_only(self): 30 | inp = """ 31 | - 2011-10-02 32 | """ 33 | exp = """ 34 | - 2011-10-02 35 | """ 36 | round_trip(inp, exp) 37 | 38 | def test_zero_fraction(self): 39 | inp = """ 40 | - 2011-10-02 16:45:00.0 41 | """ 42 | exp = """ 43 | - 2011-10-02 16:45:00 44 | """ 45 | round_trip(inp, exp) 46 | 47 | def test_long_fraction(self): 48 | inp = """ 49 | - 2011-10-02 16:45:00.1234 # expand with zeros 50 | - 2011-10-02 16:45:00.123456 51 | - 2011-10-02 16:45:00.12345612 # round to microseconds 52 | - 2011-10-02 16:45:00.1234565 # round up 53 | - 2011-10-02 16:45:00.12345678 # round up 54 | """ 55 | exp = """ 56 | - 2011-10-02 16:45:00.123400 # expand with zeros 57 | - 2011-10-02 16:45:00.123456 58 | - 2011-10-02 16:45:00.123456 # round to microseconds 59 | - 2011-10-02 16:45:00.123457 # round up 60 | - 2011-10-02 16:45:00.123457 # round up 61 | """ 62 | round_trip(inp, exp) 63 | 64 | def test_canonical(self): 65 | inp = """ 66 | - 2011-10-02T16:45:00.1Z 67 | """ 68 | exp = """ 69 | - 2011-10-02T16:45:00.100000Z 70 | """ 71 | round_trip(inp, exp) 72 | 73 | def test_spaced_timezone(self): 74 | inp = """ 75 | - 2011-10-02T11:45:00 -5 76 | """ 77 | exp = """ 78 | - 2011-10-02T11:45:00-5 79 | """ 80 | round_trip(inp, exp) 81 | 82 | def test_normal_timezone(self): 83 | round_trip( 84 | """ 85 | - 2011-10-02T11:45:00-5 86 | - 2011-10-02 11:45:00-5 87 | - 
2011-10-02T11:45:00-05:00 88 | - 2011-10-02 11:45:00-05:00 89 | """ 90 | ) 91 | 92 | def test_no_timezone(self): 93 | inp = """ 94 | - 2011-10-02 6:45:00 95 | """ 96 | exp = """ 97 | - 2011-10-02 06:45:00 98 | """ 99 | round_trip(inp, exp) 100 | 101 | def test_explicit_T(self): 102 | inp = """ 103 | - 2011-10-02T16:45:00 104 | """ 105 | exp = """ 106 | - 2011-10-02T16:45:00 107 | """ 108 | round_trip(inp, exp) 109 | 110 | def test_explicit_t(self): # to upper 111 | inp = """ 112 | - 2011-10-02t16:45:00 113 | """ 114 | exp = """ 115 | - 2011-10-02T16:45:00 116 | """ 117 | round_trip(inp, exp) 118 | 119 | def test_no_T_multi_space(self): 120 | inp = """ 121 | - 2011-10-02 16:45:00 122 | """ 123 | exp = """ 124 | - 2011-10-02 16:45:00 125 | """ 126 | round_trip(inp, exp) 127 | 128 | def test_iso(self): 129 | round_trip( 130 | """ 131 | - 2011-10-02T15:45:00+01:00 132 | """ 133 | ) 134 | 135 | def test_zero_tz(self): 136 | round_trip( 137 | """ 138 | - 2011-10-02T15:45:00+0 139 | """ 140 | ) 141 | 142 | def test_issue_45(self): 143 | round_trip( 144 | """ 145 | dt: 2016-08-19T22:45:47Z 146 | """ 147 | ) 148 | 149 | def test_deepcopy_datestring(self): 150 | # reported by Quuxplusone, http://stackoverflow.com/a/41577841/1307905 151 | x = dedent( 152 | """\ 153 | foo: 2016-10-12T12:34:56 154 | """ 155 | ) 156 | data = copy.deepcopy(round_trip_load(x)) 157 | assert round_trip_dump(data) == x 158 | -------------------------------------------------------------------------------- /srsly/tests/ruamel_yaml/test_deprecation.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import print_function 4 | 5 | import sys 6 | import pytest # NOQA 7 | 8 | 9 | @pytest.mark.skipif(sys.version_info < (3, 7) or sys.version_info >= (3, 9), 10 | reason='collections not available?') 11 | def test_collections_deprecation(): 12 | with pytest.warns(DeprecationWarning): 13 | from collections import Hashable # NOQA 14 | 
-------------------------------------------------------------------------------- /srsly/tests/ruamel_yaml/test_documents.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import pytest # NOQA 4 | 5 | from .roundtrip import round_trip, round_trip_load_all 6 | 7 | 8 | class TestDocument: 9 | def test_single_doc_begin_end(self): 10 | inp = """\ 11 | --- 12 | - a 13 | - b 14 | ... 15 | """ 16 | round_trip(inp, explicit_start=True, explicit_end=True) 17 | 18 | def test_multi_doc_begin_end(self): 19 | from srsly.ruamel_yaml import dump_all, RoundTripDumper 20 | 21 | inp = """\ 22 | --- 23 | - a 24 | ... 25 | --- 26 | - b 27 | ... 28 | """ 29 | docs = list(round_trip_load_all(inp)) 30 | assert docs == [["a"], ["b"]] 31 | out = dump_all( 32 | docs, Dumper=RoundTripDumper, explicit_start=True, explicit_end=True 33 | ) 34 | assert out == "---\n- a\n...\n---\n- b\n...\n" 35 | 36 | def test_multi_doc_no_start(self): 37 | inp = """\ 38 | - a 39 | ... 40 | --- 41 | - b 42 | ... 43 | """ 44 | docs = list(round_trip_load_all(inp)) 45 | assert docs == [["a"], ["b"]] 46 | 47 | def test_multi_doc_no_end(self): 48 | inp = """\ 49 | - a 50 | --- 51 | - b 52 | """ 53 | docs = list(round_trip_load_all(inp)) 54 | assert docs == [["a"], ["b"]] 55 | 56 | def test_multi_doc_ends_only(self): 57 | # this is ok in 1.2 58 | inp = """\ 59 | - a 60 | ... 61 | - b 62 | ... 63 | """ 64 | docs = list(round_trip_load_all(inp, version=(1, 2))) 65 | assert docs == [["a"], ["b"]] 66 | 67 | def test_multi_doc_ends_only_1_1(self): 68 | from srsly.ruamel_yaml import parser 69 | 70 | # this is not ok in 1.1 71 | with pytest.raises(parser.ParserError): 72 | inp = """\ 73 | - a 74 | ... 75 | - b 76 | ... 
77 | """ 78 | docs = list(round_trip_load_all(inp, version=(1, 1))) 79 | assert docs == [["a"], ["b"]] # not True, but not reached 80 | -------------------------------------------------------------------------------- /srsly/tests/ruamel_yaml/test_float.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import print_function, absolute_import, division, unicode_literals 4 | 5 | import pytest # NOQA 6 | 7 | from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA 8 | 9 | # http://yaml.org/type/int.html is where underscores in integers are defined 10 | 11 | 12 | class TestFloat: 13 | def test_round_trip_non_exp(self): 14 | data = round_trip( 15 | """\ 16 | - 1.0 17 | - 1.00 18 | - 23.100 19 | - -1.0 20 | - -1.00 21 | - -23.100 22 | - 42. 23 | - -42. 24 | - +42. 25 | - .5 26 | - +.5 27 | - -.5 28 | """ 29 | ) 30 | print(data) 31 | assert 0.999 < data[0] < 1.001 32 | assert 0.999 < data[1] < 1.001 33 | assert 23.099 < data[2] < 23.101 34 | assert 0.999 < -data[3] < 1.001 35 | assert 0.999 < -data[4] < 1.001 36 | assert 23.099 < -data[5] < 23.101 37 | assert 41.999 < data[6] < 42.001 38 | assert 41.999 < -data[7] < 42.001 39 | assert 41.999 < data[8] < 42.001 40 | assert 0.49 < data[9] < 0.51 41 | assert 0.49 < data[10] < 0.51 42 | assert -0.51 < data[11] < -0.49 43 | 44 | def test_round_trip_zeros_0(self): 45 | data = round_trip( 46 | """\ 47 | - 0. 48 | - +0. 49 | - -0. 
50 | - 0.0 51 | - +0.0 52 | - -0.0 53 | - 0.00 54 | - +0.00 55 | - -0.00 56 | """ 57 | ) 58 | print(data) 59 | for d in data: 60 | assert -0.00001 < d < 0.00001 61 | 62 | def Xtest_round_trip_non_exp_trailing_dot(self): 63 | data = round_trip( 64 | """\ 65 | """ 66 | ) 67 | print(data) 68 | 69 | def test_yaml_1_1_no_dot(self): 70 | from srsly.ruamel_yaml.error import MantissaNoDotYAML1_1Warning 71 | 72 | with pytest.warns(MantissaNoDotYAML1_1Warning): 73 | round_trip_load( 74 | """\ 75 | %YAML 1.1 76 | --- 77 | - 1e6 78 | """ 79 | ) 80 | 81 | 82 | class TestCalculations(object): 83 | def test_mul_00(self): 84 | # issue 149 reported by jan.brezina@tul.cz 85 | d = round_trip_load( 86 | """\ 87 | - 0.1 88 | """ 89 | ) 90 | d[0] *= -1 91 | x = round_trip_dump(d) 92 | assert x == "- -0.1\n" 93 | -------------------------------------------------------------------------------- /srsly/tests/ruamel_yaml/test_flowsequencekey.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | test flow style sequences as keys roundtrip 5 | 6 | """ 7 | 8 | # import pytest 9 | 10 | from .roundtrip import round_trip # , dedent, round_trip_load, round_trip_dump 11 | 12 | 13 | class TestFlowStyleSequenceKey: 14 | def test_so_39595807(self): 15 | inp = """\ 16 | %YAML 1.2 17 | --- 18 | [2, 3, 4]: 19 | a: 20 | - 1 21 | - 2 22 | b: Hello World! 23 | c: 'Voilà!' 
24 | """ 25 | round_trip(inp, preserve_quotes=True, explicit_start=True, version=(1, 2)) 26 | -------------------------------------------------------------------------------- /srsly/tests/ruamel_yaml/test_int.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import print_function, absolute_import, division, unicode_literals 4 | 5 | import pytest # NOQA 6 | 7 | from .roundtrip import dedent, round_trip_load, round_trip_dump 8 | 9 | # http://yaml.org/type/int.html is where underscores in integers are defined 10 | 11 | 12 | class TestBinHexOct: 13 | def test_calculate(self): 14 | # make sure type, leading zero(s) and underscore are preserved 15 | s = dedent( 16 | """\ 17 | - 42 18 | - 0b101010 19 | - 0x_2a 20 | - 0x2A 21 | - 0o00_52 22 | """ 23 | ) 24 | d = round_trip_load(s) 25 | for idx, elem in enumerate(d): 26 | elem -= 21 27 | d[idx] = elem 28 | for idx, elem in enumerate(d): 29 | elem *= 2 30 | d[idx] = elem 31 | for idx, elem in enumerate(d): 32 | t = elem 33 | elem **= 2 34 | elem //= t 35 | d[idx] = elem 36 | assert round_trip_dump(d) == s 37 | -------------------------------------------------------------------------------- /srsly/tests/ruamel_yaml/test_json_numbers.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import print_function 4 | 5 | import pytest # NOQA 6 | 7 | import json 8 | 9 | 10 | def load(s, typ=float): 11 | import srsly.ruamel_yaml 12 | 13 | x = '{"low": %s }' % (s) 14 | print("input: [%s]" % (s), repr(x)) 15 | # just to check it is loadable json 16 | res = json.loads(x) 17 | assert isinstance(res["low"], typ) 18 | ret_val = srsly.ruamel_yaml.load(x, srsly.ruamel_yaml.RoundTripLoader) 19 | print(ret_val) 20 | return ret_val["low"] 21 | 22 | 23 | class TestJSONNumbers: 24 | # based on http://stackoverflow.com/a/30462009/1307905 25 | # yaml number regex: 
http://yaml.org/spec/1.2/spec.html#id2804092 26 | # 27 | # -? [1-9] ( \. [0-9]* [1-9] )? ( e [-+] [1-9] [0-9]* )? 28 | # 29 | # which is not a superset of the JSON numbers 30 | def test_json_number_float(self): 31 | for x in ( 32 | y.split("#")[0].strip() 33 | for y in """ 34 | 1.0 # should fail on YAML spec on 1-9 allowed as single digit 35 | -1.0 36 | 1e-06 37 | 3.1e-5 38 | 3.1e+5 39 | 3.1e5 # should fail on YAML spec: no +- after e 40 | """.splitlines() 41 | ): 42 | if not x: 43 | continue 44 | res = load(x) 45 | assert isinstance(res, float) 46 | 47 | def test_json_number_int(self): 48 | for x in ( 49 | y.split("#")[0].strip() 50 | for y in """ 51 | 42 52 | """.splitlines() 53 | ): 54 | if not x: 55 | continue 56 | res = load(x, int) 57 | assert isinstance(res, int) 58 | -------------------------------------------------------------------------------- /srsly/tests/ruamel_yaml/test_line_col.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import pytest # NOQA 4 | 5 | from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA 6 | 7 | 8 | def load(s): 9 | return round_trip_load(dedent(s)) 10 | 11 | 12 | class TestLineCol: 13 | def test_item_00(self): 14 | data = load( 15 | """ 16 | - a 17 | - e 18 | - [b, d] 19 | - c 20 | """ 21 | ) 22 | assert data[2].lc.line == 2 23 | assert data[2].lc.col == 2 24 | 25 | def test_item_01(self): 26 | data = load( 27 | """ 28 | - a 29 | - e 30 | - {x: 3} 31 | - c 32 | """ 33 | ) 34 | assert data[2].lc.line == 2 35 | assert data[2].lc.col == 2 36 | 37 | def test_item_02(self): 38 | data = load( 39 | """ 40 | - a 41 | - e 42 | - !!set {x, y} 43 | - c 44 | """ 45 | ) 46 | assert data[2].lc.line == 2 47 | assert data[2].lc.col == 2 48 | 49 | def test_item_03(self): 50 | data = load( 51 | """ 52 | - a 53 | - e 54 | - !!omap 55 | - x: 1 56 | - y: 3 57 | - c 58 | """ 59 | ) 60 | assert data[2].lc.line == 2 61 | assert data[2].lc.col == 2 62 | 63 | def 
test_item_04(self): 64 | data = load( 65 | """ 66 | # testing line and column based on SO 67 | # http://stackoverflow.com/questions/13319067/ 68 | - key1: item 1 69 | key2: item 2 70 | - key3: another item 1 71 | key4: another item 2 72 | """ 73 | ) 74 | assert data[0].lc.line == 2 75 | assert data[0].lc.col == 2 76 | assert data[1].lc.line == 4 77 | assert data[1].lc.col == 2 78 | 79 | def test_pos_mapping(self): 80 | data = load( 81 | """ 82 | a: 1 83 | b: 2 84 | c: 3 85 | # comment 86 | klm: 42 87 | d: 4 88 | """ 89 | ) 90 | assert data.lc.key("klm") == (4, 0) 91 | assert data.lc.value("klm") == (4, 5) 92 | 93 | def test_pos_sequence(self): 94 | data = load( 95 | """ 96 | - a 97 | - b 98 | - c 99 | # next one! 100 | - klm 101 | - d 102 | """ 103 | ) 104 | assert data.lc.item(3) == (4, 2) 105 | -------------------------------------------------------------------------------- /srsly/tests/ruamel_yaml/test_none.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | 4 | import pytest # NOQA 5 | 6 | 7 | class TestNone: 8 | def test_dump00(self): 9 | import srsly.ruamel_yaml # NOQA 10 | 11 | data = None 12 | s = srsly.ruamel_yaml.round_trip_dump(data) 13 | assert s == "null\n...\n" 14 | d = srsly.ruamel_yaml.round_trip_load(s) 15 | assert d == data 16 | 17 | def test_dump01(self): 18 | import srsly.ruamel_yaml # NOQA 19 | 20 | data = None 21 | s = srsly.ruamel_yaml.round_trip_dump(data, explicit_end=True) 22 | assert s == "null\n...\n" 23 | d = srsly.ruamel_yaml.round_trip_load(s) 24 | assert d == data 25 | 26 | def test_dump02(self): 27 | import srsly.ruamel_yaml # NOQA 28 | 29 | data = None 30 | s = srsly.ruamel_yaml.round_trip_dump(data, explicit_end=False) 31 | assert s == "null\n...\n" 32 | d = srsly.ruamel_yaml.round_trip_load(s) 33 | assert d == data 34 | 35 | def test_dump03(self): 36 | import srsly.ruamel_yaml # NOQA 37 | 38 | data = None 39 | s = srsly.ruamel_yaml.round_trip_dump(data, 
explicit_start=True) 40 | assert s == "---\n...\n" 41 | d = srsly.ruamel_yaml.round_trip_load(s) 42 | assert d == data 43 | 44 | def test_dump04(self): 45 | import srsly.ruamel_yaml # NOQA 46 | 47 | data = None 48 | s = srsly.ruamel_yaml.round_trip_dump( 49 | data, explicit_start=True, explicit_end=False 50 | ) 51 | assert s == "---\n...\n" 52 | d = srsly.ruamel_yaml.round_trip_load(s) 53 | assert d == data 54 | -------------------------------------------------------------------------------- /srsly/tests/ruamel_yaml/test_numpy.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import print_function, absolute_import, division, unicode_literals 4 | 5 | try: 6 | import numpy 7 | except: # NOQA 8 | numpy = None 9 | 10 | 11 | def Xtest_numpy(): 12 | import srsly.ruamel_yaml 13 | 14 | if numpy is None: 15 | return 16 | data = numpy.arange(10) 17 | print("data", type(data), data) 18 | 19 | yaml_str = srsly.ruamel_yaml.dump(data) 20 | datb = srsly.ruamel_yaml.load(yaml_str) 21 | print("datb", type(datb), datb) 22 | 23 | print("\nYAML", yaml_str) 24 | assert data == datb 25 | -------------------------------------------------------------------------------- /srsly/tests/ruamel_yaml/test_program_config.py: -------------------------------------------------------------------------------- 1 | import pytest # NOQA 2 | 3 | # import srsly.ruamel_yaml 4 | from .roundtrip import round_trip 5 | 6 | 7 | class TestProgramConfig: 8 | def test_application_arguments(self): 9 | # application configur 10 | round_trip( 11 | """ 12 | args: 13 | username: anthon 14 | passwd: secret 15 | fullname: Anthon van der Neut 16 | tmux: 17 | session-name: test 18 | loop: 19 | wait: 10 20 | """ 21 | ) 22 | 23 | def test_single(self): 24 | # application configuration 25 | round_trip( 26 | """ 27 | # default arguments for the program 28 | args: # needed to prevent comment wrapping 29 | # this should be your username 30 | username: 
anthon 31 | passwd: secret # this is plaintext don't reuse \ 32 | # important/system passwords 33 | fullname: Anthon van der Neut 34 | tmux: 35 | session-name: test # make sure this doesn't clash with 36 | # other sessions 37 | loop: # looping related defaults 38 | # experiment with the following 39 | wait: 10 40 | # no more argument info to pass 41 | """ 42 | ) 43 | 44 | def test_multi(self): 45 | # application configuration 46 | round_trip( 47 | """ 48 | # default arguments for the program 49 | args: # needed to prevent comment wrapping 50 | # this should be your username 51 | username: anthon 52 | passwd: secret # this is plaintext don't reuse 53 | # important/system passwords 54 | fullname: Anthon van der Neut 55 | tmux: 56 | session-name: test # make sure this doesn't clash with 57 | # other sessions 58 | loop: # looping related defaults 59 | # experiment with the following 60 | wait: 10 61 | # no more argument info to pass 62 | """ 63 | ) 64 | -------------------------------------------------------------------------------- /srsly/tests/ruamel_yaml/test_string.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import print_function 4 | 5 | """ 6 | various test cases for string scalars in YAML files 7 | '|' for preserved newlines 8 | '>' for folded (newlines become spaces) 9 | 10 | and the chomping modifiers: 11 | '-' for stripping: final line break and any trailing empty lines are excluded 12 | '+' for keeping: final line break and empty lines are preserved 13 | '' for clipping: final line break preserved, empty lines at end not 14 | included in content (no modifier) 15 | 16 | """ 17 | 18 | import pytest 19 | import platform 20 | 21 | # from srsly.ruamel_yaml.compat import ordereddict 22 | from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA 23 | 24 | 25 | class TestLiteralScalarString: 26 | def test_basic_string(self): 27 | round_trip( 28 | """ 29 | a: 
abcdefg 30 | """ 31 | ) 32 | 33 | def test_quoted_integer_string(self): 34 | round_trip( 35 | """ 36 | a: '12345' 37 | """ 38 | ) 39 | 40 | @pytest.mark.skipif( 41 | platform.python_implementation() == "Jython", 42 | reason="Jython throws RepresenterError", 43 | ) 44 | def test_preserve_string(self): 45 | inp = """ 46 | a: | 47 | abc 48 | def 49 | """ 50 | round_trip(inp, intermediate=dict(a="abc\ndef\n")) 51 | 52 | @pytest.mark.skipif( 53 | platform.python_implementation() == "Jython", 54 | reason="Jython throws RepresenterError", 55 | ) 56 | def test_preserve_string_strip(self): 57 | s = """ 58 | a: |- 59 | abc 60 | def 61 | 62 | """ 63 | round_trip(s, intermediate=dict(a="abc\ndef")) 64 | 65 | @pytest.mark.skipif( 66 | platform.python_implementation() == "Jython", 67 | reason="Jython throws RepresenterError", 68 | ) 69 | def test_preserve_string_keep(self): 70 | # with pytest.raises(AssertionError) as excinfo: 71 | inp = """ 72 | a: |+ 73 | ghi 74 | jkl 75 | 76 | 77 | b: x 78 | """ 79 | round_trip(inp, intermediate=dict(a="ghi\njkl\n\n\n", b="x")) 80 | 81 | @pytest.mark.skipif( 82 | platform.python_implementation() == "Jython", 83 | reason="Jython throws RepresenterError", 84 | ) 85 | def test_preserve_string_keep_at_end(self): 86 | # at EOF you have to specify the ... to get proper "closure" 87 | # of the multiline scalar 88 | inp = """ 89 | a: |+ 90 | ghi 91 | jkl 92 | 93 | ... 
94 | """ 95 | round_trip(inp, intermediate=dict(a="ghi\njkl\n\n")) 96 | 97 | def test_fold_string(self): 98 | inp = """ 99 | a: > 100 | abc 101 | def 102 | 103 | """ 104 | round_trip(inp) 105 | 106 | def test_fold_string_strip(self): 107 | inp = """ 108 | a: >- 109 | abc 110 | def 111 | 112 | """ 113 | round_trip(inp) 114 | 115 | def test_fold_string_keep(self): 116 | with pytest.raises(AssertionError) as excinfo: # NOQA 117 | inp = """ 118 | a: >+ 119 | abc 120 | def 121 | 122 | """ 123 | round_trip(inp, intermediate=dict(a="abc def\n\n")) 124 | 125 | 126 | class TestQuotedScalarString: 127 | def test_single_quoted_string(self): 128 | inp = """ 129 | a: 'abc' 130 | """ 131 | round_trip(inp, preserve_quotes=True) 132 | 133 | def test_double_quoted_string(self): 134 | inp = """ 135 | a: "abc" 136 | """ 137 | round_trip(inp, preserve_quotes=True) 138 | 139 | def test_non_preserved_double_quoted_string(self): 140 | inp = """ 141 | a: "abc" 142 | """ 143 | exp = """ 144 | a: abc 145 | """ 146 | round_trip(inp, outp=exp) 147 | 148 | 149 | class TestReplace: 150 | """inspired by issue 110 from sandres23""" 151 | 152 | def test_replace_preserved_scalar_string(self): 153 | import srsly 154 | 155 | s = dedent( 156 | """\ 157 | foo: | 158 | foo 159 | foo 160 | bar 161 | foo 162 | """ 163 | ) 164 | data = round_trip_load(s, preserve_quotes=True) 165 | so = data["foo"].replace("foo", "bar", 2) 166 | assert isinstance(so, srsly.ruamel_yaml.scalarstring.LiteralScalarString) 167 | assert so == dedent( 168 | """ 169 | bar 170 | bar 171 | bar 172 | foo 173 | """ 174 | ) 175 | 176 | def test_replace_double_quoted_scalar_string(self): 177 | import srsly 178 | 179 | s = dedent( 180 | """\ 181 | foo: "foo foo bar foo" 182 | """ 183 | ) 184 | data = round_trip_load(s, preserve_quotes=True) 185 | so = data["foo"].replace("foo", "bar", 2) 186 | assert isinstance(so, srsly.ruamel_yaml.scalarstring.DoubleQuotedScalarString) 187 | assert so == "bar bar bar foo" 188 | 189 | 190 | class 
TestWalkTree: 191 | def test_basic(self): 192 | from srsly.ruamel_yaml.comments import CommentedMap 193 | from srsly.ruamel_yaml.scalarstring import walk_tree 194 | 195 | data = CommentedMap() 196 | data[1] = "a" 197 | data[2] = "with\nnewline\n" 198 | walk_tree(data) 199 | exp = """\ 200 | 1: a 201 | 2: | 202 | with 203 | newline 204 | """ 205 | assert round_trip_dump(data) == dedent(exp) 206 | 207 | def test_map(self): 208 | from srsly.ruamel_yaml.compat import ordereddict 209 | from srsly.ruamel_yaml.comments import CommentedMap 210 | from srsly.ruamel_yaml.scalarstring import walk_tree, preserve_literal 211 | from srsly.ruamel_yaml.scalarstring import DoubleQuotedScalarString as dq 212 | from srsly.ruamel_yaml.scalarstring import SingleQuotedScalarString as sq 213 | 214 | data = CommentedMap() 215 | data[1] = "a" 216 | data[2] = "with\nnew : line\n" 217 | data[3] = "${abc}" 218 | data[4] = "almost:mapping" 219 | m = ordereddict([("\n", preserve_literal), ("${", sq), (":", dq)]) 220 | walk_tree(data, map=m) 221 | exp = """\ 222 | 1: a 223 | 2: | 224 | with 225 | new : line 226 | 3: '${abc}' 227 | 4: "almost:mapping" 228 | """ 229 | assert round_trip_dump(data) == dedent(exp) 230 | -------------------------------------------------------------------------------- /srsly/tests/ruamel_yaml/test_tag.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import pytest # NOQA 4 | 5 | from .roundtrip import round_trip, round_trip_load, YAML 6 | 7 | 8 | def register_xxx(**kw): 9 | import srsly.ruamel_yaml as yaml 10 | 11 | class XXX(yaml.comments.CommentedMap): 12 | @staticmethod 13 | def yaml_dump(dumper, data): 14 | return dumper.represent_mapping(u"!xxx", data) 15 | 16 | @classmethod 17 | def yaml_load(cls, constructor, node): 18 | data = cls() 19 | yield data 20 | constructor.construct_mapping(node, data) 21 | 22 | yaml.add_constructor(u"!xxx", XXX.yaml_load, constructor=yaml.RoundTripConstructor) 23 | 
yaml.add_representer(XXX, XXX.yaml_dump, representer=yaml.RoundTripRepresenter) 24 | 25 | 26 | class TestIndentFailures: 27 | def test_tag(self): 28 | round_trip( 29 | """\ 30 | !!python/object:__main__.Developer 31 | name: Anthon 32 | location: Germany 33 | language: python 34 | """ 35 | ) 36 | 37 | def test_full_tag(self): 38 | round_trip( 39 | """\ 40 | !!tag:yaml.org,2002:python/object:__main__.Developer 41 | name: Anthon 42 | location: Germany 43 | language: python 44 | """ 45 | ) 46 | 47 | def test_standard_tag(self): 48 | round_trip( 49 | """\ 50 | !!tag:yaml.org,2002:python/object:map 51 | name: Anthon 52 | location: Germany 53 | language: python 54 | """ 55 | ) 56 | 57 | def test_Y1(self): 58 | round_trip( 59 | """\ 60 | !yyy 61 | name: Anthon 62 | location: Germany 63 | language: python 64 | """ 65 | ) 66 | 67 | def test_Y2(self): 68 | round_trip( 69 | """\ 70 | !!yyy 71 | name: Anthon 72 | location: Germany 73 | language: python 74 | """ 75 | ) 76 | 77 | 78 | class TestRoundTripCustom: 79 | def test_X1(self): 80 | register_xxx() 81 | round_trip( 82 | """\ 83 | !xxx 84 | name: Anthon 85 | location: Germany 86 | language: python 87 | """ 88 | ) 89 | 90 | @pytest.mark.xfail(strict=True) 91 | def test_X_pre_tag_comment(self): 92 | register_xxx() 93 | round_trip( 94 | """\ 95 | - 96 | # hello 97 | !xxx 98 | name: Anthon 99 | location: Germany 100 | language: python 101 | """ 102 | ) 103 | 104 | @pytest.mark.xfail(strict=True) 105 | def test_X_post_tag_comment(self): 106 | register_xxx() 107 | round_trip( 108 | """\ 109 | - !xxx 110 | # hello 111 | name: Anthon 112 | location: Germany 113 | language: python 114 | """ 115 | ) 116 | 117 | def test_scalar_00(self): 118 | # https://stackoverflow.com/a/45967047/1307905 119 | round_trip( 120 | """\ 121 | Outputs: 122 | Vpc: 123 | Value: !Ref: vpc # first tag 124 | Export: 125 | Name: !Sub "${AWS::StackName}-Vpc" # second tag 126 | """ 127 | ) 128 | 129 | 130 | class TestIssue201: 131 | def 
test_encoded_unicode_tag(self): 132 | round_trip_load( 133 | """ 134 | s: !!python/%75nicode 'abc' 135 | """ 136 | ) 137 | 138 | 139 | class TestImplicitTaggedNodes: 140 | def test_scalar(self): 141 | round_trip( 142 | """\ 143 | - !Scalar abcdefg 144 | """ 145 | ) 146 | 147 | def test_mapping(self): 148 | round_trip( 149 | """\ 150 | - !Mapping {a: 1, b: 2} 151 | """ 152 | ) 153 | 154 | def test_sequence(self): 155 | yaml = YAML() 156 | yaml.brace_single_entry_mapping_in_flow_sequence = True 157 | yaml.mapping_value_align = True 158 | yaml.round_trip( 159 | """ 160 | - !Sequence [a, {b: 1}, {c: {d: 3}}] 161 | """ 162 | ) 163 | 164 | def test_sequence2(self): 165 | yaml = YAML() 166 | yaml.mapping_value_align = True 167 | yaml.round_trip( 168 | """ 169 | - !Sequence [a, b: 1, c: {d: 3}] 170 | """ 171 | ) 172 | -------------------------------------------------------------------------------- /srsly/tests/ruamel_yaml/test_version.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import pytest # NOQA 4 | 5 | from .roundtrip import dedent, round_trip, round_trip_load 6 | 7 | 8 | def load(s, version=None): 9 | import srsly.ruamel_yaml # NOQA 10 | 11 | return srsly.ruamel_yaml.round_trip_load(dedent(s), version) 12 | 13 | 14 | class TestVersions: 15 | def test_explicit_1_2(self): 16 | r = load( 17 | """\ 18 | %YAML 1.2 19 | --- 20 | - 12:34:56 21 | - 012 22 | - 012345678 23 | - 0o12 24 | - on 25 | - off 26 | - yes 27 | - no 28 | - true 29 | """ 30 | ) 31 | assert r[0] == "12:34:56" 32 | assert r[1] == 12 33 | assert r[2] == 12345678 34 | assert r[3] == 10 35 | assert r[4] == "on" 36 | assert r[5] == "off" 37 | assert r[6] == "yes" 38 | assert r[7] == "no" 39 | assert r[8] is True 40 | 41 | def test_explicit_1_1(self): 42 | r = load( 43 | """\ 44 | %YAML 1.1 45 | --- 46 | - 12:34:56 47 | - 012 48 | - 012345678 49 | - 0o12 50 | - on 51 | - off 52 | - yes 53 | - no 54 | - true 55 | """ 56 | ) 57 | assert r[0] == 
45296 58 | assert r[1] == 10 59 | assert r[2] == "012345678" 60 | assert r[3] == "0o12" 61 | assert r[4] is True 62 | assert r[5] is False 63 | assert r[6] is True 64 | assert r[7] is False 65 | assert r[8] is True 66 | 67 | def test_implicit_1_2(self): 68 | r = load( 69 | """\ 70 | - 12:34:56 71 | - 12:34:56.78 72 | - 012 73 | - 012345678 74 | - 0o12 75 | - on 76 | - off 77 | - yes 78 | - no 79 | - true 80 | """ 81 | ) 82 | assert r[0] == "12:34:56" 83 | assert r[1] == "12:34:56.78" 84 | assert r[2] == 12 85 | assert r[3] == 12345678 86 | assert r[4] == 10 87 | assert r[5] == "on" 88 | assert r[6] == "off" 89 | assert r[7] == "yes" 90 | assert r[8] == "no" 91 | assert r[9] is True 92 | 93 | def test_load_version_1_1(self): 94 | inp = """\ 95 | - 12:34:56 96 | - 12:34:56.78 97 | - 012 98 | - 012345678 99 | - 0o12 100 | - on 101 | - off 102 | - yes 103 | - no 104 | - true 105 | """ 106 | r = load(inp, version="1.1") 107 | assert r[0] == 45296 108 | assert r[1] == 45296.78 109 | assert r[2] == 10 110 | assert r[3] == "012345678" 111 | assert r[4] == "0o12" 112 | assert r[5] is True 113 | assert r[6] is False 114 | assert r[7] is True 115 | assert r[8] is False 116 | assert r[9] is True 117 | 118 | 119 | class TestIssue62: 120 | # bitbucket issue 62, issue_62 121 | def test_00(self): 122 | import srsly.ruamel_yaml # NOQA 123 | 124 | s = dedent( 125 | """\ 126 | {}# Outside flow collection: 127 | - ::vector 128 | - ": - ()" 129 | - Up, up, and away! 
130 | - -123 131 | - http://example.com/foo#bar 132 | # Inside flow collection: 133 | - [::vector, ": - ()", "Down, down and away!", -456, http://example.com/foo#bar] 134 | """ 135 | ) 136 | with pytest.raises(srsly.ruamel_yaml.parser.ParserError): 137 | round_trip(s.format("%YAML 1.1\n---\n"), preserve_quotes=True) 138 | round_trip(s.format(""), preserve_quotes=True) 139 | 140 | def test_00_single_comment(self): 141 | import srsly.ruamel_yaml # NOQA 142 | 143 | s = dedent( 144 | """\ 145 | {}# Outside flow collection: 146 | - ::vector 147 | - ": - ()" 148 | - Up, up, and away! 149 | - -123 150 | - http://example.com/foo#bar 151 | - [::vector, ": - ()", "Down, down and away!", -456, http://example.com/foo#bar] 152 | """ 153 | ) 154 | with pytest.raises(srsly.ruamel_yaml.parser.ParserError): 155 | round_trip(s.format("%YAML 1.1\n---\n"), preserve_quotes=True) 156 | round_trip(s.format(""), preserve_quotes=True) 157 | # round_trip(s.format('%YAML 1.2\n---\n'), preserve_quotes=True, version=(1, 2)) 158 | 159 | def test_01(self): 160 | import srsly.ruamel_yaml # NOQA 161 | 162 | s = dedent( 163 | """\ 164 | {}[random plain value that contains a ? character] 165 | """ 166 | ) 167 | with pytest.raises(srsly.ruamel_yaml.parser.ParserError): 168 | round_trip(s.format("%YAML 1.1\n---\n"), preserve_quotes=True) 169 | round_trip(s.format(""), preserve_quotes=True) 170 | # note the flow seq on the --- line! 
171 | round_trip(s.format("%YAML 1.2\n--- "), preserve_quotes=True, version="1.2") 172 | 173 | def test_so_45681626(self): 174 | # was not properly parsing 175 | round_trip_load('{"in":{},"out":{}}') 176 | -------------------------------------------------------------------------------- /srsly/tests/ruamel_yaml/test_yamlobject.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import print_function 4 | 5 | import sys 6 | import pytest # NOQA 7 | 8 | from .roundtrip import save_and_run # NOQA 9 | 10 | 11 | def test_monster(tmpdir): 12 | program_src = u'''\ 13 | import srsly.ruamel_yaml 14 | from textwrap import dedent 15 | 16 | class Monster(srsly.ruamel_yaml.YAMLObject): 17 | yaml_tag = u'!Monster' 18 | 19 | def __init__(self, name, hp, ac, attacks): 20 | self.name = name 21 | self.hp = hp 22 | self.ac = ac 23 | self.attacks = attacks 24 | 25 | def __repr__(self): 26 | return "%s(name=%r, hp=%r, ac=%r, attacks=%r)" % ( 27 | self.__class__.__name__, self.name, self.hp, self.ac, self.attacks) 28 | 29 | data = srsly.ruamel_yaml.load(dedent("""\\ 30 | --- !Monster 31 | name: Cave spider 32 | hp: [2,6] # 2d6 33 | ac: 16 34 | attacks: [BITE, HURT] 35 | """), Loader=srsly.ruamel_yaml.Loader) 36 | # normal dump, keys will be sorted 37 | assert srsly.ruamel_yaml.dump(data) == dedent("""\\ 38 | !Monster 39 | ac: 16 40 | attacks: [BITE, HURT] 41 | hp: [2, 6] 42 | name: Cave spider 43 | """) 44 | ''' 45 | assert save_and_run(program_src, tmpdir) == 1 46 | 47 | 48 | @pytest.mark.skipif(sys.version_info < (3, 0), reason="no __qualname__") 49 | def test_qualified_name00(tmpdir): 50 | """issue 214""" 51 | program_src = u"""\ 52 | from srsly.ruamel_yaml import YAML 53 | from srsly.ruamel_yaml.compat import StringIO 54 | 55 | class A: 56 | def f(self): 57 | pass 58 | 59 | yaml = YAML(typ='unsafe', pure=True) 60 | yaml.explicit_end = True 61 | buf = StringIO() 62 | yaml.dump(A.f, buf) 63 | res = 
buf.getvalue() 64 | print('res', repr(res)) 65 | assert res == "!!python/name:__main__.A.f ''\\n...\\n" 66 | x = yaml.load(res) 67 | assert x == A.f 68 | """ 69 | assert save_and_run(program_src, tmpdir) == 1 70 | 71 | 72 | @pytest.mark.skipif(sys.version_info < (3, 0), reason="no __qualname__") 73 | def test_qualified_name01(tmpdir): 74 | """issue 214""" 75 | from srsly.ruamel_yaml import YAML 76 | import srsly.ruamel_yaml.comments 77 | from srsly.ruamel_yaml.compat import StringIO 78 | 79 | with pytest.raises(ValueError): 80 | yaml = YAML(typ="unsafe", pure=True) 81 | yaml.explicit_end = True 82 | buf = StringIO() 83 | yaml.dump(srsly.ruamel_yaml.comments.CommentedBase.yaml_anchor, buf) 84 | res = buf.getvalue() 85 | assert ( 86 | res 87 | == "!!python/name:srsly.ruamel_yaml.comments.CommentedBase.yaml_anchor ''\n...\n" 88 | ) 89 | x = yaml.load(res) 90 | assert x == srsly.ruamel_yaml.comments.CommentedBase.yaml_anchor 91 | -------------------------------------------------------------------------------- /srsly/tests/ruamel_yaml/test_z_check_debug_leftovers.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import sys 4 | import pytest # NOQA 5 | 6 | from .roundtrip import round_trip_load, round_trip_dump, dedent 7 | 8 | 9 | class TestLeftOverDebug: 10 | # idea here is to capture round_trip_output via pytest stdout capture 11 | # if there is are any leftover debug statements they should show up 12 | def test_00(self, capsys): 13 | s = dedent( 14 | """ 15 | a: 1 16 | b: [] 17 | c: [a, 1] 18 | d: {f: 3.14, g: 42} 19 | """ 20 | ) 21 | d = round_trip_load(s) 22 | round_trip_dump(d, sys.stdout) 23 | out, err = capsys.readouterr() 24 | assert out == s 25 | 26 | def test_01(self, capsys): 27 | s = dedent( 28 | """ 29 | - 1 30 | - [] 31 | - [a, 1] 32 | - {f: 3.14, g: 42} 33 | - - 123 34 | """ 35 | ) 36 | d = round_trip_load(s) 37 | round_trip_dump(d, sys.stdout) 38 | out, err = capsys.readouterr() 39 | assert 
out == s 40 | -------------------------------------------------------------------------------- /srsly/tests/test_msgpack_api.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pathlib import Path 3 | import datetime 4 | from mock import patch 5 | import numpy 6 | 7 | from .._msgpack_api import read_msgpack, write_msgpack 8 | from .._msgpack_api import msgpack_loads, msgpack_dumps 9 | from .._msgpack_api import msgpack_encoders, msgpack_decoders 10 | from .util import make_tempdir 11 | 12 | 13 | def test_msgpack_dumps(): 14 | data = {"hello": "world", "test": 123} 15 | expected = [b"\x82\xa5hello\xa5world\xa4test{", b"\x82\xa4test{\xa5hello\xa5world"] 16 | msg = msgpack_dumps(data) 17 | assert msg in expected 18 | 19 | 20 | def test_msgpack_loads(): 21 | msg = b"\x82\xa5hello\xa5world\xa4test{" 22 | data = msgpack_loads(msg) 23 | assert len(data) == 2 24 | assert data["hello"] == "world" 25 | assert data["test"] == 123 26 | 27 | 28 | def test_read_msgpack_file(): 29 | file_contents = b"\x81\xa5hello\xa5world" 30 | with make_tempdir({"tmp.msg": file_contents}, mode="wb") as temp_dir: 31 | file_path = temp_dir / "tmp.msg" 32 | assert file_path.exists() 33 | data = read_msgpack(file_path) 34 | assert len(data) == 1 35 | assert data["hello"] == "world" 36 | 37 | 38 | def test_read_msgpack_file_invalid(): 39 | file_contents = b"\xa5hello\xa5world" 40 | with make_tempdir({"tmp.msg": file_contents}, mode="wb") as temp_dir: 41 | file_path = temp_dir / "tmp.msg" 42 | assert file_path.exists() 43 | with pytest.raises(ValueError): 44 | read_msgpack(file_path) 45 | 46 | 47 | def test_write_msgpack_file(): 48 | data = {"hello": "world", "test": 123} 49 | expected = [b"\x82\xa5hello\xa5world\xa4test{", b"\x82\xa4test{\xa5hello\xa5world"] 50 | with make_tempdir(mode="wb") as temp_dir: 51 | file_path = temp_dir / "tmp.msg" 52 | write_msgpack(file_path, data) 53 | with Path(file_path).open("rb") as f: 54 | assert f.read() 
in expected 55 | 56 | 57 | @patch("srsly.msgpack._msgpack_numpy.np", None) 58 | @patch("srsly.msgpack._msgpack_numpy.has_numpy", False) 59 | def test_msgpack_without_numpy(): 60 | """Test that msgpack works without numpy and raises correct errors (e.g. 61 | when serializing datetime objects, the error should be msgpack's TypeError, 62 | not a "'np' is not defined error").""" 63 | with pytest.raises(TypeError): 64 | msgpack_loads(msgpack_dumps(datetime.datetime.now())) 65 | 66 | 67 | def test_msgpack_custom_encoder_decoder(): 68 | class CustomObject: 69 | def __init__(self, value): 70 | self.value = value 71 | 72 | def serialize_obj(obj, chain=None): 73 | if isinstance(obj, CustomObject): 74 | return {"__custom__": obj.value} 75 | return obj if chain is None else chain(obj) 76 | 77 | def deserialize_obj(obj, chain=None): 78 | if "__custom__" in obj: 79 | return CustomObject(obj["__custom__"]) 80 | return obj if chain is None else chain(obj) 81 | 82 | data = {"a": 123, "b": CustomObject({"foo": "bar"})} 83 | with pytest.raises(TypeError): 84 | msgpack_dumps(data) 85 | 86 | # Register custom encoders/decoders to handle CustomObject 87 | msgpack_encoders.register("custom_object", func=serialize_obj) 88 | msgpack_decoders.register("custom_object", func=deserialize_obj) 89 | bytes_data = msgpack_dumps(data) 90 | new_data = msgpack_loads(bytes_data) 91 | assert new_data["a"] == 123 92 | assert isinstance(new_data["b"], CustomObject) 93 | assert new_data["b"].value == {"foo": "bar"} 94 | # Test that it also works with combinations of encoders/decoders (e.g. 
numpy) 95 | data = {"a": numpy.zeros((1, 2, 3)), "b": CustomObject({"foo": "bar"})} 96 | bytes_data = msgpack_dumps(data) 97 | new_data = msgpack_loads(bytes_data) 98 | assert isinstance(new_data["a"], numpy.ndarray) 99 | assert isinstance(new_data["b"], CustomObject) 100 | assert new_data["b"].value == {"foo": "bar"} 101 | -------------------------------------------------------------------------------- /srsly/tests/test_pickle_api.py: -------------------------------------------------------------------------------- 1 | from .._pickle_api import pickle_dumps, pickle_loads 2 | 3 | 4 | def test_pickle_dumps(): 5 | data = {"hello": "world", "test": 123} 6 | expected = [ 7 | b"\x80\x04\x95\x1e\x00\x00\x00\x00\x00\x00\x00}\x94(\x8c\x05hello\x94\x8c\x05world\x94\x8c\x04test\x94K{u.", 8 | b"\x80\x04\x95\x1e\x00\x00\x00\x00\x00\x00\x00}\x94(\x8c\x04test\x94K{\x8c\x05hello\x94\x8c\x05world\x94u.", 9 | b"\x80\x02}q\x00(X\x04\x00\x00\x00testq\x01K{X\x05\x00\x00\x00helloq\x02X\x05\x00\x00\x00worldq\x03u.", 10 | b"\x80\x05\x95\x1e\x00\x00\x00\x00\x00\x00\x00}\x94(\x8c\x05hello\x94\x8c\x05world\x94\x8c\x04test\x94K{u.", 11 | ] 12 | msg = pickle_dumps(data) 13 | assert msg in expected 14 | 15 | 16 | def test_pickle_loads(): 17 | msg = pickle_dumps({"hello": "world", "test": 123}) 18 | data = pickle_loads(msg) 19 | assert len(data) == 2 20 | assert data["hello"] == "world" 21 | assert data["test"] == 123 22 | -------------------------------------------------------------------------------- /srsly/tests/test_yaml_api.py: -------------------------------------------------------------------------------- 1 | from io import StringIO 2 | from pathlib import Path 3 | import pytest 4 | 5 | from .._yaml_api import yaml_dumps, yaml_loads, read_yaml, write_yaml 6 | from .._yaml_api import is_yaml_serializable 7 | from ..ruamel_yaml.comments import CommentedMap 8 | from .util import make_tempdir 9 | 10 | 11 | def test_yaml_dumps(): 12 | data = {"a": [1, "hello"], "b": {"foo": "bar", "baz": 
[10.5, 120]}} 13 | result = yaml_dumps(data) 14 | expected = "a:\n - 1\n - hello\nb:\n foo: bar\n baz:\n - 10.5\n - 120\n" 15 | assert result == expected 16 | 17 | 18 | def test_yaml_dumps_indent(): 19 | data = {"a": [1, "hello"], "b": {"foo": "bar", "baz": [10.5, 120]}} 20 | result = yaml_dumps(data, indent_mapping=2, indent_sequence=2, indent_offset=0) 21 | expected = "a:\n- 1\n- hello\nb:\n foo: bar\n baz:\n - 10.5\n - 120\n" 22 | assert result == expected 23 | 24 | 25 | def test_yaml_loads(): 26 | data = "a:\n- 1\n- hello\nb:\n foo: bar\n baz:\n - 10.5\n - 120\n" 27 | result = yaml_loads(data) 28 | # Check that correct loader is used and result is regular dict, not the 29 | # custom ruamel.yaml "ordereddict" class 30 | assert not isinstance(result, CommentedMap) 31 | assert result == {"a": [1, "hello"], "b": {"foo": "bar", "baz": [10.5, 120]}} 32 | 33 | 34 | def test_read_yaml_file(): 35 | file_contents = "a:\n- 1\n- hello\nb:\n foo: bar\n baz:\n - 10.5\n - 120\n" 36 | with make_tempdir({"tmp.yaml": file_contents}) as temp_dir: 37 | file_path = temp_dir / "tmp.yaml" 38 | assert file_path.exists() 39 | data = read_yaml(file_path) 40 | assert len(data) == 2 41 | assert data["a"] == [1, "hello"] 42 | 43 | 44 | def test_read_yaml_file_invalid(): 45 | file_contents = "a: - 1\n- hello\nb:\n foo: bar\n baz:\n - 10.5\n - 120\n" 46 | with make_tempdir({"tmp.yaml": file_contents}) as temp_dir: 47 | file_path = temp_dir / "tmp.yaml" 48 | assert file_path.exists() 49 | with pytest.raises(ValueError): 50 | read_yaml(file_path) 51 | 52 | 53 | def test_read_yaml_stdin(monkeypatch): 54 | input_data = "a:\n - 1\n - hello\nb:\n foo: bar\n baz:\n - 10.5\n - 120\n" 55 | monkeypatch.setattr("sys.stdin", StringIO(input_data)) 56 | data = read_yaml("-") 57 | assert len(data) == 2 58 | assert data["a"] == [1, "hello"] 59 | 60 | 61 | def test_write_yaml_file(): 62 | data = {"hello": "world", "test": [123, 456]} 63 | expected = "hello: world\ntest:\n - 123\n - 456\n" 64 | with 
make_tempdir() as temp_dir: 65 | file_path = temp_dir / "tmp.yaml" 66 | write_yaml(file_path, data) 67 | with Path(file_path).open("r", encoding="utf8") as f: 68 | assert f.read() == expected 69 | 70 | 71 | def test_write_yaml_stdout(capsys): 72 | data = {"hello": "world", "test": [123, 456]} 73 | expected = "hello: world\ntest:\n - 123\n - 456\n\n" 74 | write_yaml("-", data) 75 | captured = capsys.readouterr() 76 | assert captured.out == expected 77 | 78 | 79 | @pytest.mark.parametrize( 80 | "obj,expected", 81 | [ 82 | (["a", "b", 1, 2], True), 83 | ({"a": "b", "c": 123}, True), 84 | ("hello", True), 85 | (lambda x: x, False), 86 | ({"a": lambda x: x}, False), 87 | ], 88 | ) 89 | def test_is_yaml_serializable(obj, expected): 90 | assert is_yaml_serializable(obj) == expected 91 | # Check again to be sure it's consistent 92 | assert is_yaml_serializable(obj) == expected 93 | -------------------------------------------------------------------------------- /srsly/tests/ujson/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explosion/srsly/6c0fe81e2cd30093aa3bb12b9f4df67e4da1ff59/srsly/tests/ujson/__init__.py -------------------------------------------------------------------------------- /srsly/tests/util.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | from pathlib import Path 3 | from contextlib import contextmanager 4 | import shutil 5 | 6 | 7 | @contextmanager 8 | def make_tempdir(files={}, mode="w"): 9 | temp_dir_str = tempfile.mkdtemp() 10 | temp_dir = Path(temp_dir_str) 11 | for name, content in files.items(): 12 | path = temp_dir / name 13 | with path.open(mode) as file_: 14 | file_.write(content) 15 | yield temp_dir 16 | shutil.rmtree(temp_dir_str) 17 | -------------------------------------------------------------------------------- /srsly/ujson/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .ujson import decode, encode, dump, dumps, load, loads # noqa: F401 2 | -------------------------------------------------------------------------------- /srsly/ujson/lib/dconv_wrapper.cc: -------------------------------------------------------------------------------- 1 | #include "double-conversion.h" 2 | 3 | namespace double_conversion 4 | { 5 | static StringToDoubleConverter* s2d_instance = NULL; 6 | static DoubleToStringConverter* d2s_instance = NULL; 7 | 8 | extern "C" 9 | { 10 | void dconv_d2s_init(int flags, 11 | const char* infinity_symbol, 12 | const char* nan_symbol, 13 | char exponent_character, 14 | int decimal_in_shortest_low, 15 | int decimal_in_shortest_high, 16 | int max_leading_padding_zeroes_in_precision_mode, 17 | int max_trailing_padding_zeroes_in_precision_mode) 18 | { 19 | d2s_instance = new DoubleToStringConverter(flags, infinity_symbol, nan_symbol, 20 | exponent_character, decimal_in_shortest_low, 21 | decimal_in_shortest_high, max_leading_padding_zeroes_in_precision_mode, 22 | max_trailing_padding_zeroes_in_precision_mode); 23 | } 24 | 25 | int dconv_d2s(double value, char* buf, int buflen, int* strlength) 26 | { 27 | StringBuilder sb(buf, buflen); 28 | int success = static_cast(d2s_instance->ToShortest(value, &sb)); 29 | *strlength = success ? 
sb.position() : -1; 30 | return success; 31 | } 32 | 33 | void dconv_d2s_free() 34 | { 35 | delete d2s_instance; 36 | d2s_instance = NULL; 37 | } 38 | 39 | void dconv_s2d_init(int flags, double empty_string_value, 40 | double junk_string_value, const char* infinity_symbol, 41 | const char* nan_symbol) 42 | { 43 | s2d_instance = new StringToDoubleConverter(flags, empty_string_value, 44 | junk_string_value, infinity_symbol, nan_symbol); 45 | } 46 | 47 | double dconv_s2d(const char* buffer, int length, int* processed_characters_count) 48 | { 49 | return s2d_instance->StringToDouble(buffer, length, processed_characters_count); 50 | } 51 | 52 | void dconv_s2d_free() 53 | { 54 | delete s2d_instance; 55 | s2d_instance = NULL; 56 | } 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /srsly/ujson/py_defines.h: -------------------------------------------------------------------------------- 1 | /* 2 | Developed by ESN, an Electronic Arts Inc. studio. 3 | Copyright (c) 2014, Electronic Arts Inc. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | * Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | * Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in the 12 | documentation and/or other materials provided with the distribution. 13 | * Neither the name of ESN, Electronic Arts Inc. nor the 14 | names of its contributors may be used to endorse or promote products 15 | derived from this software without specific prior written permission. 
16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 18 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 19 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 | DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. BE LIABLE 21 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 22 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 23 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 24 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | 28 | 29 | Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc) 30 | http://code.google.com/p/stringencoders/ 31 | Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved. 32 | 33 | Numeric decoder derived from from TCL library 34 | http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms 35 | * Copyright (c) 1988-1993 The Regents of the University of California. 36 | * Copyright (c) 1994 Sun Microsystems, Inc. 37 | */ 38 | 39 | #include <Python.h> 40 | 41 | #if PY_MAJOR_VERSION >= 3 42 | 43 | #define PyInt_Check PyLong_Check 44 | #define PyInt_AS_LONG PyLong_AsLong 45 | #define PyInt_FromLong PyLong_FromLong 46 | 47 | #define PyString_Check PyBytes_Check 48 | #define PyString_GET_SIZE PyBytes_GET_SIZE 49 | #define PyString_AS_STRING PyBytes_AS_STRING 50 | 51 | #define PyString_FromString PyUnicode_FromString 52 | 53 | #endif 54 | -------------------------------------------------------------------------------- /srsly/ujson/ujson.c: -------------------------------------------------------------------------------- 1 | /* 2 | Developed by ESN, an Electronic Arts Inc. studio. 3 | Copyright (c) 2014, Electronic Arts Inc. 
4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | * Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | * Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in the 12 | documentation and/or other materials provided with the distribution. 13 | * Neither the name of ESN, Electronic Arts Inc. nor the 14 | names of its contributors may be used to endorse or promote products 15 | derived from this software without specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 18 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 19 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 | DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. BE LIABLE 21 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 22 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 23 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 24 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | 28 | 29 | Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc) 30 | http://code.google.com/p/stringencoders/ 31 | Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved. 32 | 33 | Numeric decoder derived from from TCL library 34 | http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms 35 | * Copyright (c) 1988-1993 The Regents of the University of California. 
36 | * Copyright (c) 1994 Sun Microsystems, Inc. 37 | */ 38 | 39 | #include "py_defines.h" 40 | #include "version.h" 41 | 42 | /* objToJSON */ 43 | PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs); 44 | void initObjToJSON(void); 45 | 46 | /* JSONToObj */ 47 | PyObject* JSONToObj(PyObject* self, PyObject *args, PyObject *kwargs); 48 | 49 | /* objToJSONFile */ 50 | PyObject* objToJSONFile(PyObject* self, PyObject *args, PyObject *kwargs); 51 | 52 | /* JSONFileToObj */ 53 | PyObject* JSONFileToObj(PyObject* self, PyObject *args, PyObject *kwargs); 54 | 55 | 56 | #define ENCODER_HELP_TEXT "Use ensure_ascii=false to output UTF-8. Pass in double_precision to alter the maximum digit precision of doubles. Set encode_html_chars=True to encode < > & as unicode escape sequences. Set escape_forward_slashes=False to prevent escaping / characters." 57 | 58 | static PyMethodDef ujsonMethods[] = { 59 | {"encode", (PyCFunction) objToJSON, METH_VARARGS | METH_KEYWORDS, "Converts arbitrary object recursively into JSON. " ENCODER_HELP_TEXT}, 60 | {"decode", (PyCFunction) JSONToObj, METH_VARARGS | METH_KEYWORDS, "Converts JSON as string to dict object structure. Use precise_float=True to use high precision float decoder."}, 61 | {"dumps", (PyCFunction) objToJSON, METH_VARARGS | METH_KEYWORDS, "Converts arbitrary object recursively into JSON. " ENCODER_HELP_TEXT}, 62 | {"loads", (PyCFunction) JSONToObj, METH_VARARGS | METH_KEYWORDS, "Converts JSON as string to dict object structure. Use precise_float=True to use high precision float decoder."}, 63 | {"dump", (PyCFunction) objToJSONFile, METH_VARARGS | METH_KEYWORDS, "Converts arbitrary object recursively into JSON file. " ENCODER_HELP_TEXT}, 64 | {"load", (PyCFunction) JSONFileToObj, METH_VARARGS | METH_KEYWORDS, "Converts JSON as file to dict object structure. 
Use precise_float=True to use high precision float decoder."}, 65 | {NULL, NULL, 0, NULL} /* Sentinel */ 66 | }; 67 | 68 | #if PY_MAJOR_VERSION >= 3 69 | 70 | static struct PyModuleDef moduledef = { 71 | PyModuleDef_HEAD_INIT, 72 | "ujson", 73 | 0, /* m_doc */ 74 | -1, /* m_size */ 75 | ujsonMethods, /* m_methods */ 76 | NULL, /* m_reload */ 77 | NULL, /* m_traverse */ 78 | NULL, /* m_clear */ 79 | NULL /* m_free */ 80 | }; 81 | 82 | #define PYMODINITFUNC PyObject *PyInit_ujson(void) 83 | #define PYMODULE_CREATE() PyModule_Create(&moduledef) 84 | #define MODINITERROR return NULL 85 | 86 | #else 87 | 88 | #define PYMODINITFUNC PyMODINIT_FUNC initujson(void) 89 | #define PYMODULE_CREATE() Py_InitModule("ujson", ujsonMethods) 90 | #define MODINITERROR return 91 | 92 | #endif 93 | 94 | PYMODINITFUNC 95 | { 96 | PyObject *module; 97 | PyObject *version_string; 98 | 99 | initObjToJSON(); 100 | module = PYMODULE_CREATE(); 101 | 102 | if (module == NULL) 103 | { 104 | MODINITERROR; 105 | } 106 | 107 | version_string = PyString_FromString (UJSON_VERSION); 108 | PyModule_AddObject (module, "__version__", version_string); 109 | 110 | #if PY_MAJOR_VERSION >= 3 111 | return module; 112 | #endif 113 | } 114 | -------------------------------------------------------------------------------- /srsly/ujson/version.h: -------------------------------------------------------------------------------- 1 | /* 2 | Developed by ESN, an Electronic Arts Inc. studio. 3 | Copyright (c) 2014, Electronic Arts Inc. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | * Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 
10 | * Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in the 12 | documentation and/or other materials provided with the distribution. 13 | * Neither the name of ESN, Electronic Arts Inc. nor the 14 | names of its contributors may be used to endorse or promote products 15 | derived from this software without specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 18 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 19 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 | DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. BE LIABLE 21 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 22 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 23 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 24 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | 28 | 29 | Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc) 30 | http://code.google.com/p/stringencoders/ 31 | Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved. 32 | 33 | Numeric decoder derived from from TCL library 34 | http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms 35 | * Copyright (c) 1988-1993 The Regents of the University of California. 36 | * Copyright (c) 1994 Sun Microsystems, Inc. 
from pathlib import Path
from typing import Union, Dict, Any, List, Tuple
from collections import OrderedDict


# fmt: off
FilePath = Union[str, Path]
# Superficial JSON input/output types
# https://github.com/python/typing/issues/182#issuecomment-186684288
JSONOutput = Union[str, int, float, bool, None, Dict[str, Any], List[Any]]
JSONOutputBin = Union[bytes, str, int, float, bool, None, Dict[str, Any], List[Any]]
# For input, we also accept tuples, ordered dicts etc.
JSONInput = Union[str, int, float, bool, None, Dict[str, Any], List[Any], Tuple[Any, ...], OrderedDict]
JSONInputBin = Union[bytes, str, int, float, bool, None, Dict[str, Any], List[Any], Tuple[Any, ...], OrderedDict]
YAMLInput = JSONInput
YAMLOutput = JSONOutput
# fmt: on


def force_path(location, require_exists=True):
    """Coerce a location to a ``Path`` and optionally check that it exists.

    location: A ``Path`` (returned as-is) or anything ``Path()`` accepts.
    require_exists: When true, a location that does not exist is an error.
    RETURNS (Path): The location as a ``Path`` object.
    RAISES (ValueError): If ``require_exists`` is set and the path is missing.
    """
    location = location if isinstance(location, Path) else Path(location)
    missing = require_exists and not location.exists()
    if missing:
        raise ValueError(f"Can't read file: {location}")
    return location


def force_string(location):
    """Return the location as a string.

    location: A string (returned unchanged) or any object, which is
        converted with ``str()``.
    RETURNS (str): The string form of the location.
    """
    return location if isinstance(location, str) else str(location)