├── .flake8 ├── .github ├── FUNDING.yml ├── ISSUE_TEMPLATE.md ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── python-package.yml │ └── release.yml ├── .gitignore ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── MIGRATING_FROM_OLDER_VERSIONS.rst ├── README.rst ├── benchmark ├── bytebuffer_bench.py └── read_s3.py ├── ci_helpers ├── README.txt ├── doctest.py ├── helpers.sh ├── run_benchmarks.py ├── run_integration_tests.py └── test_missing_dependencies.py ├── extending.md ├── help.txt ├── howto.md ├── integration-tests ├── README.md ├── initialize_s3_bucket.py ├── requirements.txt ├── test_184.py ├── test_207.py ├── test_azure.py ├── test_ftp.py ├── test_gcs.py ├── test_hdfs.py ├── test_http.py ├── test_minio.py ├── test_s3.py ├── test_s3_buffering.py ├── test_s3_ported.py ├── test_s3_readline.py ├── test_ssh.py ├── test_version_id.py └── test_webhdfs.py ├── pyproject.toml ├── release ├── README.md ├── annotate_pr.py ├── check_preamble.py ├── doctest.sh ├── hijack_pr.py ├── merge.sh ├── prepare.sh ├── update_help_txt.sh └── update_release_notes.py ├── sampledata └── hello.zip ├── setup.py ├── smart_open ├── __init__.py ├── azure.py ├── bytebuffer.py ├── compression.py ├── concurrency.py ├── constants.py ├── doctools.py ├── ftp.py ├── gcs.py ├── hdfs.py ├── http.py ├── local_file.py ├── s3.py ├── smart_open_lib.py ├── ssh.py ├── tests │ ├── __init__.py │ ├── fixtures │ │ ├── __init__.py │ │ ├── good_transport.py │ │ ├── missing_deps_transport.py │ │ └── no_schemes_transport.py │ ├── test_azure.py │ ├── test_bytebuffer.py │ ├── test_compression.py │ ├── test_data │ │ ├── 1984.txt.bz2 │ │ ├── 1984.txt.gz │ │ ├── 1984.txt.gzip │ │ ├── 1984.txt.xz │ │ ├── cp852.tsv.txt │ │ ├── crime-and-punishment.txt │ │ ├── crime-and-punishment.txt.gz │ │ ├── crime-and-punishment.txt.xz │ │ ├── crlf_at_1k_boundary.warc.gz │ │ └── ssh.cfg │ ├── test_gcs.py │ ├── test_hdfs.py │ ├── test_http.py │ ├── test_package.py │ ├── test_s3.py │ ├── test_s3_version.py │ ├── test_smart_open.py │ ├── test_ssh.py │ ├── test_transport.py │ └── test_utils.py ├── transport.py ├── utils.py ├── version.py └── webhdfs.py └── update_helptext.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | # E121: Continuation line under-indented for hanging indent 3 | # E123: Continuation line missing indentation or outdented 4 | # E125: Continuation line with same indent as next logical line 5 | # E128: Continuation line under-indented for visual indent 6 | # E226: Missing whitespace around arithmetic operator 7 | # W503: Line break occurred before a binary operator 8 | ignore=E121,E123,E125,E128,E226,W503 9 | max-line-length=110 -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [piskvorky] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single 
Otechie username 12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 13 | 14 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | #### Problem description 2 | 3 | Be sure your description clearly answers the following questions: 4 | 5 | - What are you trying to achieve? 6 | - What is the expected result? 7 | - What are you seeing instead? 8 | 9 | #### Steps/code to reproduce the problem 10 | 11 | In order for us to be able to solve your problem, we have to be able to reproduce it on our end. 12 | Without reproducing the problem, it is unlikely that we'll be able to help you. 13 | 14 | Include full tracebacks, logs and datasets if necessary. 15 | Please keep the examples minimal ([minimal reproducible example](https://stackoverflow.com/help/minimal-reproducible-example)). 16 | 17 | #### Versions 18 | 19 | Please provide the output of: 20 | 21 | ```python 22 | import platform, sys, smart_open 23 | print(platform.platform()) 24 | print("Python", sys.version) 25 | print("smart_open", smart_open.__version__) 26 | ``` 27 | 28 | #### Checklist 29 | 30 | Before you create the issue, please make sure you have: 31 | 32 | - [ ] Described the problem clearly 33 | - [ ] Provided a minimal reproducible example, including any required data 34 | - [ ] Provided the version numbers of the relevant software 35 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | > Please **pick a concise, informative and complete title** for your PR. 2 | > 3 | > The title is important because it will appear in [our change log](https://github.com/RaRe-Technologies/smart_open/blob/master/CHANGELOG.md). 4 | 5 | ### Motivation 6 | 7 | > Please explain the motivation behind this PR. 8 | > 9 | > If you're fixing a bug, link to the issue using a [supported keyword](https://docs.github.com/en/issues/tracking-your-work-with-issues/using-issues/linking-a-pull-request-to-an-issue) like "Fixes #{issue_number}". 10 | > 11 | > If you're adding a new feature, then consider opening a ticket and discussing it with the maintainers before you actually do the hard work. 12 | 13 | Fixes #{issue_number} 14 | 15 | ### Tests 16 | 17 | > If you're fixing a bug, consider [test-driven development](https://en.wikipedia.org/wiki/Test-driven_development): 18 | > 19 | > 1. Create a unit test that demonstrates the bug. The test should **fail**. 20 | > 2. Implement your bug fix. 21 | > 3. The test you created should now **pass**. 22 | > 23 | > If you're implementing a new feature, include unit tests for it. 24 | > 25 | > Make sure all existing unit tests pass. 26 | > You can run them locally using: 27 | > 28 | > pytest smart_open 29 | > 30 | > If there are any failures, please fix them before creating the PR (or mark it as WIP, see below). 31 | 32 | ### Work in progress 33 | 34 | > If you're still working on your PR, mark the PR as [draft PR](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/changing-the-stage-of-a-pull-request). 35 | > 36 | > We'll skip reviewing it for the time being. 37 | > 38 | > Once it's ready, mark the PR as "ready for review", and ping one of the maintainers (e.g. mpenkov). 
39 | 40 | ### Checklist 41 | 42 | > Before you mark the PR as "ready for review", please make sure you have: 43 | 44 | - [ ] Picked a concise, informative and complete title 45 | - [ ] Clearly explained the motivation behind the PR 46 | - [ ] Linked to any existing issues that your PR will be solving 47 | - [ ] Included tests for any new functionality 48 | - [ ] Run `python update_helptext.py` in case there are API changes 49 | - [ ] Checked that all unit tests pass 50 | 51 | ### Workflow 52 | 53 | > Please avoid rebasing and force-pushing to the branch of the PR once a review is in progress. 54 | > 55 | > Rebasing can make your commits look a bit cleaner, but it also makes life more difficult for the reviewer, because they are no longer able to distinguish between code that has already been reviewed and unreviewed code. 56 |
-------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | on: [push, pull_request] 3 | concurrency: # https://stackoverflow.com/questions/66335225#comment133398800_72408109 4 | group: ${{ github.workflow }}-${{ github.ref || github.run_id }} 5 | cancel-in-progress: ${{ github.event_name == 'pull_request' }} 6 | jobs: 7 | linters: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v2 11 | 12 | - name: Set up Python 3.11 13 | uses: actions/setup-python@v2 14 | with: 15 | python-version: "3.11" 16 | 17 | - name: Install dependencies 18 | run: pip install flake8 -e .[all] 19 | 20 | - name: Run flake8 linter (source) 21 | run: flake8 --show-source smart_open 22 | 23 | - name: "Check whether help.txt update was forgotten" 24 | if: github.event_name == 'pull_request' 25 | run: | 26 | python update_helptext.py 27 | test ! "$(git diff)" && echo "no changes" || ( git diff && echo 'looks like "python update_helptext.py" was forgotten' && exit 1 ) 28 | 29 | unit_tests: 30 | needs: [linters] 31 | runs-on: ${{ matrix.os }} 32 | strategy: 33 | matrix: 34 | include: 35 | - {python-version: '3.8', os: ubuntu-22.04} 36 | - {python-version: '3.9', os: ubuntu-22.04} 37 | - {python-version: '3.10', os: ubuntu-22.04} 38 | - {python-version: '3.11', os: ubuntu-22.04} 39 | - {python-version: '3.12', os: ubuntu-22.04} 40 | - {python-version: '3.13', os: ubuntu-22.04} 41 | 42 | - {python-version: '3.8', os: windows-2019} 43 | - {python-version: '3.9', os: windows-2019} 44 | - {python-version: '3.10', os: windows-2019} 45 | - {python-version: '3.11', os: windows-2019} 46 | - {python-version: '3.12', os: windows-2019} 47 | - {python-version: '3.13', os: windows-2019} 48 | steps: 49 | - uses: actions/checkout@v2 50 | 51 | - uses: actions/setup-python@v2 52 | with: 53 | python-version: ${{ matrix.python-version }} 54 | 55 | - name: Install smart_open without dependencies 56 | run: pip install -e .
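# Installing without extras first lets the next step verify that a bare
# `import smart_open` succeeds even when optional transport dependencies
# (such as boto3) are absent.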
57 | 58 | - name: Check that smart_open imports without dependencies 59 | run: python -c 'import smart_open' 60 | 61 | - name: Install smart_open and its dependencies 62 | run: pip install -e .[test] 63 | 64 | - name: Run unit tests 65 | run: pytest smart_open -v -rfxECs --durations=20 66 | 67 | doctest: 68 | needs: [linters,unit_tests] 69 | runs-on: ${{ matrix.os }} 70 | strategy: 71 | matrix: 72 | include: 73 | - {python-version: '3.8', os: ubuntu-22.04} 74 | - {python-version: '3.9', os: ubuntu-22.04} 75 | - {python-version: '3.10', os: ubuntu-22.04} 76 | - {python-version: '3.11', os: ubuntu-22.04} 77 | - {python-version: '3.12', os: ubuntu-22.04} 78 | - {python-version: '3.13', os: ubuntu-22.04} 79 | 80 | # 81 | # Some of the doctests don't pass on Windows because of Windows-specific 82 | # character encoding issues. 83 | # 84 | # - {python-version: '3.7', os: windows-2019} 85 | # - {python-version: '3.8', os: windows-2019} 86 | # - {python-version: '3.9', os: windows-2019} 87 | # - {python-version: '3.10', os: windows-2019} 88 | # - {python-version: '3.11', os: windows-2019} 89 | # - {python-version: '3.12', os: windows-2019} 90 | # - {python-version: '3.13', os: windows-2019} 91 | 92 | steps: 93 | - uses: actions/checkout@v2 94 | 95 | - uses: actions/setup-python@v2 96 | with: 97 | python-version: ${{ matrix.python-version }} 98 | 99 | - name: Install smart_open and its dependencies 100 | run: pip install -e .[test] 101 | 102 | - name: Run doctests 103 | run: python ci_helpers/doctest.py 104 | env: 105 | AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} 106 | AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 107 | 108 | integration: 109 | needs: [linters,unit_tests] 110 | runs-on: ${{ matrix.os }} 111 | strategy: 112 | matrix: 113 | include: 114 | - {python-version: '3.8', os: ubuntu-22.04} 115 | - {python-version: '3.9', os: ubuntu-22.04} 116 | - {python-version: '3.10', os: ubuntu-22.04} 117 | - {python-version: '3.11', os: ubuntu-22.04} 118 | - {python-version: '3.12', os: ubuntu-22.04} 119 | - {python-version: '3.13', os: ubuntu-22.04} 120 | 121 | # Not sure why we exclude these, perhaps for historical reasons? 
122 | # 123 | # - {python-version: '3.7', os: windows-2019} 124 | # - {python-version: '3.8', os: windows-2019} 125 | # - {python-version: '3.9', os: windows-2019} 126 | # - {python-version: '3.10', os: windows-2019} 127 | # - {python-version: '3.11', os: windows-2019} 128 | # - {python-version: '3.12', os: windows-2019} 129 | # - {python-version: '3.13', os: windows-2019} 130 | 131 | steps: 132 | - uses: actions/checkout@v2 133 | 134 | - uses: actions/setup-python@v2 135 | with: 136 | python-version: ${{ matrix.python-version }} 137 | 138 | - name: Install smart_open and its dependencies 139 | run: pip install -e .[test] 140 | 141 | - run: bash ci_helpers/helpers.sh enable_moto_server 142 | if: ${{ matrix.moto_server }} 143 | 144 | - name: Start vsftpd 145 | timeout-minutes: 2 146 | run: | 147 | sudo apt-get install vsftpd 148 | sudo bash ci_helpers/helpers.sh create_ftp_ftps_servers 149 | 150 | - name: Run integration tests 151 | run: python ci_helpers/run_integration_tests.py 152 | env: 153 | AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} 154 | AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 155 | 156 | - run: bash ci_helpers/helpers.sh disable_moto_server 157 | if: ${{ matrix.moto_server }} 158 | 159 | - run: sudo bash ci_helpers/helpers.sh delete_ftp_ftps_servers 160 | 161 | benchmarks: 162 | needs: [linters,unit_tests] 163 | runs-on: ${{ matrix.os }} 164 | strategy: 165 | matrix: 166 | include: 167 | - {python-version: '3.8', os: ubuntu-22.04} 168 | - {python-version: '3.9', os: ubuntu-22.04} 169 | - {python-version: '3.10', os: ubuntu-22.04} 170 | - {python-version: '3.11', os: ubuntu-22.04} 171 | - {python-version: '3.12', os: ubuntu-22.04} 172 | - {python-version: '3.13', os: ubuntu-22.04} 173 | 174 | # - {python-version: '3.7', os: windows-2019} 175 | # - {python-version: '3.8', os: windows-2019} 176 | # - {python-version: '3.9', os: windows-2019} 177 | # - {python-version: '3.10', os: windows-2019} 178 | # - {python-version: '3.11', os: windows-2019} 179 | # - {python-version: '3.12', os: windows-2019} 180 | # - {python-version: '3.13', os: windows-2019} 181 | 182 | steps: 183 | - uses: actions/checkout@v2 184 | 185 | - uses: actions/setup-python@v2 186 | with: 187 | python-version: ${{ matrix.python-version }} 188 | 189 | - name: Install smart_open and its dependencies 190 | run: pip install -e .[test] 191 | 192 | - name: Run benchmarks 193 | run: python ci_helpers/run_benchmarks.py 194 | env: 195 | SO_BUCKET: smart-open 196 | AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} 197 | AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 198 | 199 | # 200 | # The test_coverage environment in tox.ini generates coverage data and 201 | # saves it to disk. This step uploads that data. We do it 202 | # separately from the tox env because the upload can fail for various 203 | # reasons (e.g. https://github.com/lemurheavy/coveralls-public/issues/1392) 204 | # and we don't want it to break the build. 205 | # 206 | # Looks like there's a github action for this 207 | # (https://github.com/coverallsapp/github-action/issues/30) but it does 208 | # not work with pytest output. 
209 | # 210 | # - name: Upload code coverage to coveralls.io 211 | # if: ${{ matrix.coveralls }} 212 | # continue-on-error: true 213 | # env: 214 | # GITHUB_TOKEN: ${{ github.token }} 215 | # run: | 216 | # pip install coveralls 217 | # coveralls 218 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release to PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*.*.*' 7 | jobs: 8 | tarball: 9 | if: github.event_name == 'push' 10 | timeout-minutes: 1 11 | runs-on: ubuntu-20.04 12 | env: 13 | PYPI_USERNAME: ${{ secrets.PYPI_USERNAME }} 14 | PYPI_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 15 | steps: 16 | - uses: actions/checkout@v1 17 | 18 | - uses: actions/setup-python@v1 19 | with: 20 | python-version: "3.8.x" 21 | 22 | # https://github.community/t/how-to-get-just-the-tag-name/16241/4 23 | - name: Extract the version number 24 | id: get_version 25 | run: | 26 | echo ::set-output name=V::$(python smart_open/version.py) 27 | 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | python -m venv venv 32 | . venv/bin/activate 33 | pip install twine wheel 34 | 35 | - name: Build and upload tarball to PyPI 36 | run: | 37 | . venv/bin/activate 38 | python setup.py sdist 39 | twine upload dist/smart_open-${{ steps.get_version.outputs.V }}.tar.gz -u ${{ env.PYPI_USERNAME }} -p ${{ env.PYPI_PASSWORD }} 40 | 41 | - name: Build and upload wheel to PyPI 42 | run: | 43 | . venv/bin/activate 44 | python setup.py bdist_wheel 45 | twine upload dist/smart_open-${{ steps.get_version.outputs.V }}-py3-none-any.whl -u ${{ env.PYPI_USERNAME }} -p ${{ env.PYPI_PASSWORD }} 46 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | env/ 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | lib/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | 25 | # PyInstaller 26 | # Usually these files are written by a python script from a template 27 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
28 | *.manifest 29 | *.spec 30 | 31 | # Installer logs 32 | pip-log.txt 33 | pip-delete-this-directory.txt 34 | 35 | # Unit test / coverage reports 36 | htmlcov/ 37 | .tox/ 38 | .coverage 39 | .cache 40 | nosetests.xml 41 | coverage.xml 42 | 43 | # Translations 44 | *.mo 45 | *.pot 46 | 47 | # Django stuff: 48 | *.log 49 | 50 | # Sphinx documentation 51 | docs/_build/ 52 | 53 | # PyBuilder 54 | target/ 55 | 56 | # vim 57 | *.swp 58 | *.swo 59 | 60 | # PyCharm 61 | .idea/ 62 | 63 | # VSCode 64 | .vscode/ 65 | 66 | # env files 67 | .env 68 | .venv 69 |
-------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Quickstart 2 | 3 | Clone the repo and use your Python installation to create a venv: 4 | 5 | ```sh 6 | git clone git@github.com:RaRe-Technologies/smart_open.git 7 | cd smart_open 8 | python -m venv .venv 9 | ``` 10 | 11 | Activate the venv to start working and install the test dependencies: 12 | 13 | ```sh 14 | source .venv/bin/activate 15 | pip install -e ".[test]" 16 | ``` 17 | 18 | Tests should pass: 19 | 20 | ```sh 21 | pytest 22 | ``` 23 | 24 | That's it! When you're done, deactivate the venv: 25 | 26 | ```sh 27 | deactivate 28 | ``` 29 |
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015 Radim Řehůřek 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 |
-------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.rst 3 | include MIGRATING_FROM_OLDER_VERSIONS.rst 4 | include CHANGELOG.md 5 |
-------------------------------------------------------------------------------- /MIGRATING_FROM_OLDER_VERSIONS.rst: -------------------------------------------------------------------------------- 1 | Migrating to the new compression parameter 2 | ========================================== 3 | 4 | smart_open versions 6.0.0 and above no longer support the ``ignore_ext`` parameter. 5 | Use the ``compression`` parameter instead: 6 | 7 | ..
code-block:: python 8 | 9 | fin = smart_open.open("/path/file.gz", ignore_ext=True) # No 10 | fin = smart_open.open("/path/file.gz", compression="disable") # Yes 11 | 12 | fin = smart_open.open("/path/file.gz", ignore_ext=False) # No 13 | fin = smart_open.open("/path/file.gz") # Yes 14 | fin = smart_open.open("/path/file.gz", compression="infer_from_extension") # Yes, if you want to be explicit 15 | 16 | fin = smart_open.open("/path/file", compression=".gz") # Yes 17 | 18 | 19 | Migrating to the new client-based S3 API 20 | ======================================== 21 | 22 | Versions of smart_open prior to 5.0.0 used the boto3 `resource API`_ for communicating with S3. 23 | This API was easy to integrate for smart_open developers, but this came at a cost: it was not thread- or multiprocess-safe. 24 | Furthermore, as smart_open supported more and more options, the transport parameter list grew, making it less maintainable. 25 | 26 | Starting with version 5.0.0, smart_open uses the `client API`_ instead of the resource API. 27 | Functionally, very little changes for the smart_open user. 28 | The only difference is in passing transport parameters to the S3 backend. 29 | 30 | More specifically, the following S3 transport parameters are no longer supported: 31 | 32 | - `multipart_upload_kwargs` 33 | - `object_kwargs` 34 | - `resource` 35 | - `resource_kwargs` 36 | - `session` 37 | - `singlepart_upload_kwargs` 38 | 39 | **If you weren't using the above parameters, nothing changes for you.** 40 | 41 | However, if you were using any of the above, then you need to adjust your code. 42 | Here are some quick recipes. 43 | 44 | If you were previously passing `session`, then construct an S3 client from the session and pass that instead. 45 | For example, before: 46 | 47 | .. code-block:: python 48 | 49 | smart_open.open('s3://bucket/key', transport_params={'session': session}) 50 | 51 | After: 52 | 53 | .. code-block:: python 54 | 55 | smart_open.open('s3://bucket/key', transport_params={'client': session.client('s3')}) 56 | 57 | If you were passing `resource`, then replace the resource with a client, and pass that instead. 58 | For example, before: 59 | 60 | .. code-block:: python 61 | 62 | resource = session.resource('s3', **resource_kwargs) 63 | smart_open.open('s3://bucket/key', transport_params={'resource': resource}) 64 | 65 | After: 66 | 67 | .. code-block:: python 68 | 69 | client = session.client('s3') 70 | smart_open.open('s3://bucket/key', transport_params={'client': client}) 71 | 72 | If you were passing any of the `*_kwargs` parameters, you will need to include them in `client_kwargs`, keeping in mind the following transformations. 73 | 74 | ========================== ====================================== ========================== 75 | Parameter name Resource API method Client API function 76 | ========================== ====================================== ========================== 77 | `multipart_upload_kwargs` `S3.Object.initiate_multipart_upload`_ `S3.Client.create_multipart_upload`_ 78 | `object_kwargs` `S3.Object.get`_ `S3.Client.get_object`_ 79 | `resource_kwargs` S3.resource `S3.client`_ 80 | `singlepart_upload_kwargs` `S3.Object.put`_ `S3.Client.put_object`_ 81 | ========================== ====================================== ========================== 82 | 83 | Most of the above is self-explanatory, with the exception of `resource_kwargs`. 84 | These were previously used mostly for passing a custom endpoint URL.
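For the other ``*_kwargs`` parameters the pattern is the same: the old keyword arguments move into ``client_kwargs``, keyed by the client method they belong to. As a minimal sketch (``ServerSideEncryption`` is just an illustrative parameter accepted by ``S3.Client.create_multipart_upload``):

.. code-block:: python

    # Before: kwargs went straight to S3.Object.initiate_multipart_upload
    tp = {'multipart_upload_kwargs': {'ServerSideEncryption': 'AES256'}}

    # After: nest them under client_kwargs, keyed by the client method name
    tp = {
        'client_kwargs': {
            'S3.Client.create_multipart_upload': {'ServerSideEncryption': 'AES256'},
        },
    }
    smart_open.open('s3://bucket/key', 'wb', transport_params=tp)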
85 | 86 | The `client_kwargs` dict can thus contain the following members: 87 | 88 | - `S3.Client`: initializer parameters, e.g. those to pass directly to the `boto3.client` function, such as `endpoint_url`. 89 | - `S3.Client.create_multipart_upload` 90 | - `S3.Client.get_object` 91 | - `S3.Client.put_object` 92 | 93 | Here's a before-and-after example for connecting to a custom endpoint. Before: 94 | 95 | .. code-block:: python 96 | 97 | session = boto3.Session(profile_name='digitalocean') 98 | resource_kwargs = {'endpoint_url': 'https://ams3.digitaloceanspaces.com'} 99 | with open('s3://bucket/key.txt', 'wb', transport_params={'resource_kwargs': resource_kwargs}) as fout: 100 | fout.write(b'here we stand') 101 | 102 | After: 103 | 104 | .. code-block:: python 105 | 106 | session = boto3.Session(profile_name='digitalocean') 107 | client = session.client('s3', endpoint_url='https://ams3.digitaloceanspaces.com') 108 | with open('s3://bucket/key.txt', 'wb', transport_params={'client': client}) as fout: 109 | fout.write(b'here we stand') 110 | 111 | See `README `_ and `HOWTO `_ for more examples. 112 | 113 | .. _resource API: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#service-resource 114 | .. _S3.Object.initiate_multipart_upload: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Object.initiate_multipart_upload 115 | .. _S3.Object.get: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.ObjectSummary.get 116 | .. _S3.Object.put: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.ObjectSummary.put 117 | 118 | .. _client API: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#client 119 | .. _S3.Client: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#client 120 | .. _S3.Client.create_multipart_upload: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.create_multipart_upload 121 | .. _S3.Client.get_object: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.get_object 122 | .. _S3.Client.put_object: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.put_object 123 | 124 | Migrating to the new dependency management subsystem 125 | ==================================================== 126 | 127 | Smart_open has grown over the years to cover a lot of different storage systems, each with a different set of library dependencies. Not everybody needs *all* of them, so to make each smart_open installation leaner and faster, version 3.0.0 introduced a new, backward-incompatible installation method: 128 | 129 | * smart_open < 3.0.0: All dependencies were installed by default. No way to select just a subset during installation. 130 | * smart_open >= 3.0.0: No dependencies installed by default. Install the ones you need with e.g. ``pip install smart_open[s3]`` (only AWS), or ``smart_open[all]`` (install everything = same behaviour as < 3.0.0; use this for backward compatibility). 131 | 132 | You can read more about the motivation and internal discussions for this change `here `_. 133 | 134 | Migrating to the new ``open`` function 135 | ====================================== 136 | 137 | Since 1.8.1, there is a ``smart_open.open`` function that replaces ``smart_open.smart_open``.
The new function offers several advantages over the old one: 139 | 140 | - 100% compatible with the built-in ``open`` function (aka ``io.open``): it accepts all 141 | the parameters that the built-in ``open`` accepts. 142 | - The default open mode is now "r", the same as for the built-in ``open``. 143 | The default for the old ``smart_open.smart_open`` function used to be "rb". 144 | - Fully documented keyword parameters (try ``help("smart_open.open")``) 145 | 146 | The instructions below will help you migrate to the new function painlessly. 147 | 148 | First, update your imports: 149 | 150 | .. code-block:: python 151 | 152 | >>> from smart_open import smart_open # before 153 | >>> from smart_open import open # after 154 | 155 | In general, ``smart_open`` uses ``io.open`` directly, where possible, so if your 156 | code already uses ``open`` for local file I/O, then it will continue to work. 157 | If you want to continue using the built-in ``open`` function for e.g. debugging, 158 | then you can ``import smart_open`` and use ``smart_open.open``. 159 | 160 | **The default read mode is now "r" (read text).** 161 | If your code was implicitly relying on the default mode being "rb" (read 162 | binary), you'll need to update it and pass "rb" explicitly. 163 | 164 | Before: 165 | 166 | .. code-block:: python 167 | 168 | >>> import smart_open 169 | >>> smart_open.smart_open('s3://commoncrawl/robots.txt').read(32) # 'rb' used to be the default 170 | b'User-Agent: *\nDisallow: /' 171 | 172 | After: 173 | 174 | .. code-block:: python 175 | 176 | >>> import smart_open 177 | >>> smart_open.open('s3://commoncrawl/robots.txt', 'rb').read(32) 178 | b'User-Agent: *\nDisallow: /' 179 | 180 | The ``ignore_extension`` keyword parameter is now called ``ignore_ext``. 181 | It behaves identically otherwise. 182 | 183 | The most significant change is in the handling of keyword parameters for the 184 | transport layer, e.g. HTTP, S3, etc. The old function accepted these directly: 185 | 186 | .. code-block:: python 187 | 188 | >>> url = 's3://smart-open-py37-benchmark-results/test.txt' 189 | >>> session = boto3.Session(profile_name='smart_open') 190 | >>> smart_open.smart_open(url, 'r', session=session).read(32) 191 | 'first line\nsecond line\nthird lin' 192 | 193 | The new function accepts a ``transport_params`` keyword argument. It's a dict. 194 | Put your transport parameters in that dictionary. 195 | 196 | .. code-block:: python 197 | 198 | >>> url = 's3://smart-open-py37-benchmark-results/test.txt' 199 | >>> params = {'session': boto3.Session(profile_name='smart_open')} 200 | >>> open(url, 'r', transport_params=params).read(32) 201 | 'first line\nsecond line\nthird lin' 202 | 203 | Renamed parameters: 204 | 205 | - ``s3_upload`` -> ``multipart_upload_kwargs`` 206 | - ``s3_session`` -> ``session`` 207 | 208 | Removed parameters: 209 | 210 | - ``profile_name`` 211 | 212 | **The profile_name parameter has been removed.** 213 | Pass an entire ``boto3.Session`` object instead. 214 | 215 | Before: 216 | 217 | .. code-block:: python 218 | 219 | >>> url = 's3://smart-open-py37-benchmark-results/test.txt' 220 | >>> smart_open.smart_open(url, 'r', profile_name='smart_open').read(32) 221 | 'first line\nsecond line\nthird lin' 222 | 223 | After: 224 | 225 | ..
code-block:: python 226 | 227 | >>> url = 's3://smart-open-py37-benchmark-results/test.txt' 228 | >>> params = {'session': boto3.Session(profile_name='smart_open')} 229 | >>> open(url, 'r', transport_params=params).read(32) 230 | 'first line\nsecond line\nthird lin' 231 | 232 | See ``help("smart_open.open")`` for the full list of acceptable parameter names, 233 | or view the help online `here `__. 234 | 235 | If you pass an invalid parameter name, the ``smart_open.open`` function will warn you about it. 236 | Keep an eye on your logs for WARNING messages from ``smart_open``. 237 |
-------------------------------------------------------------------------------- /benchmark/bytebuffer_bench.py: -------------------------------------------------------------------------------- 1 | import time 2 | import sys 3 | 4 | import smart_open 5 | from smart_open.bytebuffer import ByteBuffer 6 | 7 | 8 | def raw_bytebuffer_benchmark(): 9 | buffer = ByteBuffer() 10 | 11 | start = time.time() 12 | for _ in range(10_000): 13 | assert buffer.fill([b"X" * 1000]) == 1000 14 | return time.time() - start 15 | 16 | 17 | def file_read_benchmark(filename): 18 | file = smart_open.open(filename, mode="rb") 19 | 20 | start = time.time() 21 | read = file.read(100_000_000) 22 | end = time.time() 23 | 24 | if len(read) < 100_000_000: 25 | print("File smaller than 100MB") 26 | 27 | return end - start 28 | 29 | 30 | print("Raw ByteBuffer benchmark:", raw_bytebuffer_benchmark()) 31 | 32 | if len(sys.argv) > 1: 33 | bench_result = file_read_benchmark(sys.argv[1]) 34 | print("File read benchmark", bench_result) 35 |
-------------------------------------------------------------------------------- /benchmark/read_s3.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import boto3 4 | import smart_open 5 | 6 | urls = [line.strip() for line in sys.stdin] 7 | 8 | tp = {} 9 | if 'create_session_and_resource' in sys.argv: 10 | tp['session'] = boto3.Session() 11 | tp['resource'] = tp['session'].resource('s3') 12 | elif 'create_resource' in sys.argv: 13 | tp['resource'] = boto3.resource('s3') 14 | elif 'create_session' in sys.argv: 15 | tp['session'] = boto3.Session() 16 | 17 | for url in urls: 18 | smart_open.open(url, transport_params=tp).read() 19 |
-------------------------------------------------------------------------------- /ci_helpers/README.txt: -------------------------------------------------------------------------------- 1 | This subdirectory contains helper scripts for our continuous integration workflow files. 2 | 3 | They are designed to be platform-independent: they run on both Linux and Windows. 4 |
-------------------------------------------------------------------------------- /ci_helpers/doctest.py: -------------------------------------------------------------------------------- 1 | """Runs the doctests if the AWS credentials are available. 2 | 3 | Without the credentials, the tests are skipped entirely, because otherwise they would fail.
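The credentials are read from the AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables.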
4 | """ 5 | import os 6 | import subprocess 7 | 8 | if os.environ.get('AWS_ACCESS_KEY_ID') and os.environ.get('AWS_SECRET_ACCESS_KEY'): 9 | subprocess.check_call(['python', '-m', 'doctest', 'README.rst', '-v']) 10 | -------------------------------------------------------------------------------- /ci_helpers/helpers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | enable_moto_server(){ 7 | moto_server -p5000 2>/dev/null& 8 | } 9 | 10 | create_ftp_ftps_servers(){ 11 | # 12 | # Must be run as root 13 | # 14 | home_dir=/home/user 15 | user=user 16 | pass=123 17 | ftp_port=21 18 | ftps_port=90 19 | 20 | mkdir $home_dir 21 | useradd -p $(echo $pass | openssl passwd -1 -stdin) -d $home_dir $user 22 | chown $user:$user $home_dir 23 | openssl req -x509 -nodes -new -sha256 -days 10240 -newkey rsa:2048 -keyout /etc/vsftpd.key -out /etc/vsftpd.pem -subj "/C=ZA/CN=localhost" 24 | chmod 755 /etc/vsftpd.key 25 | chmod 755 /etc/vsftpd.pem 26 | 27 | server_setup=''' 28 | listen=YES 29 | listen_ipv6=NO 30 | write_enable=YES 31 | pasv_enable=YES 32 | pasv_min_port=40000 33 | pasv_max_port=40009 34 | chroot_local_user=YES 35 | allow_writeable_chroot=YES''' 36 | 37 | additional_ssl_setup=''' 38 | rsa_cert_file=/etc/vsftpd.pem 39 | rsa_private_key_file=/etc/vsftpd.key 40 | ssl_enable=YES 41 | allow_anon_ssl=NO 42 | force_local_data_ssl=NO 43 | force_local_logins_ssl=NO 44 | require_ssl_reuse=NO 45 | ''' 46 | 47 | cp /etc/vsftpd.conf /etc/vsftpd-ssl.conf 48 | echo -e "$server_setup\nlisten_port=${ftp_port}" >> /etc/vsftpd.conf 49 | echo -e "$server_setup\nlisten_port=${ftps_port}\n$additional_ssl_setup" >> /etc/vsftpd-ssl.conf 50 | 51 | service vsftpd restart 52 | vsftpd /etc/vsftpd-ssl.conf & 53 | } 54 | 55 | disable_moto_server(){ 56 | lsof -i tcp:5000 | tail -n1 | cut -f2 -d" " | xargs kill -9 57 | } 58 | 59 | delete_ftp_ftps_servers(){ 60 | service vsftpd stop 61 | } 62 | 63 | "$@" 64 | -------------------------------------------------------------------------------- /ci_helpers/run_benchmarks.py: -------------------------------------------------------------------------------- 1 | """Runs benchmarks. 2 | 3 | We only do this is AWS credentials are available, because without them, it 4 | is impossible to run the benchmarks at all. 5 | """ 6 | import os 7 | import platform 8 | import uuid 9 | import subprocess 10 | 11 | import smart_open 12 | 13 | if os.environ.get('AWS_ACCESS_KEY_ID') and os.environ.get('AWS_SECRET_ACCESS_KEY'): 14 | 15 | required = ('SO_BUCKET', ) 16 | for varname in required: 17 | assert varname in os.environ, 'the following env vars must be set: %s' % ', '.join(required) 18 | 19 | os.environ['PYTEST_ADDOPTS'] = "--reruns 3 --reruns-delay 1" 20 | 21 | commit_hash = subprocess.check_output( 22 | ['git', 'rev-parse', 'HEAD'] 23 | ).decode('utf-8').strip() 24 | 25 | # 26 | # This is a temporary key that test_s3 will use for I/O. 
27 | # 28 | os.environ['SO_KEY'] = str(uuid.uuid4()) 29 | subprocess.check_call( 30 | [ 31 | 'pytest', 32 | '-v', 33 | 'integration-tests/test_s3.py', 34 | '--benchmark-save=%s' % commit_hash, 35 | ] 36 | ) 37 | 38 | url = 's3://%s/benchmark-results/%s' % ( 39 | os.environ['SO_BUCKET'], 40 | commit_hash, 41 | ) 42 | for root, subdirs, files in os.walk('.benchmarks'): 43 | for f in files: 44 | if f.endswith('%s.json' % commit_hash): 45 | out_url = '%s/%s.json' % (url, platform.python_version()) 46 | with open(os.path.join(root, f), 'rt') as fin: 47 | with smart_open.open(out_url, 'wt') as fout: 48 | fout.write(fin.read()) 49 |
-------------------------------------------------------------------------------- /ci_helpers/run_integration_tests.py: -------------------------------------------------------------------------------- 1 | """Runs integration tests.""" 2 | import os 3 | import subprocess 4 | 5 | os.environ['PYTEST_ADDOPTS'] = "--reruns 3 --reruns-delay 1" 6 | 7 | subprocess.check_call( 8 | [ 9 | 'pytest', 10 | 'integration-tests/test_207.py', 11 | 'integration-tests/test_http.py', 12 | 'integration-tests/test_ftp.py' 13 | ] 14 | ) 15 | 16 | if os.environ.get('AWS_ACCESS_KEY_ID') and os.environ.get('AWS_SECRET_ACCESS_KEY'): 17 | subprocess.check_call(['pytest', '-v', 'integration-tests/test_s3_ported.py']) 18 |
-------------------------------------------------------------------------------- /ci_helpers/test_missing_dependencies.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | 4 | os.environ['SMART_OPEN_TEST_MISSING_DEPS'] = '1' 5 | command = [ 6 | 'pytest', 7 | 'smart_open/tests/test_package.py', 8 | '-v', 9 | '--cov', 'smart_open', 10 | '--cov-report', 'term-missing', 11 | ] 12 | subprocess.check_call(command) 13 |
-------------------------------------------------------------------------------- /extending.md: -------------------------------------------------------------------------------- 1 | # Extending `smart_open` 2 | 3 | This document targets potential contributors to `smart_open`. 4 | Currently, there are two main directions for extending existing `smart_open` functionality: 5 | 6 | 1. Add a new transport mechanism 7 | 2. Add a new compression format 8 | 9 | The first is by far the more challenging, and also the more welcome. 10 | 11 | ## New transport mechanisms 12 | 13 | Each transport mechanism lives in its own submodule. 14 | For example, currently we have: 15 | 16 | - [smart_open.local_file](smart_open/local_file.py) 17 | - [smart_open.s3](smart_open/s3.py) 18 | - [smart_open.ssh](smart_open/ssh.py) 19 | - ... and others 20 | 21 | So, to implement a new transport mechanism, you need to create a new module. 22 | Your module must expose the following (see [smart_open.http](smart_open/http.py) for the full implementation): 23 | 24 | ```python 25 | SCHEMA = ... 26 | """The name of the mechanism, e.g. s3, ssh, etc. 27 | 28 | This is the part that goes before the `://` in a URL, e.g. `s3://`.""" 29 | 30 | URI_EXAMPLES = ('xxx://foo/bar', 'zzz://baz/boz') 31 | """This will appear in the documentation of the `parse_uri` function.""" 32 | 33 | MISSING_DEPS = False 34 | """Wrap transport-specific imports in a try/except and set this to True if 35 | any imports are not found. Setting MISSING_DEPS to True will cause the library 36 | to suggest installing its dependencies with an example pip command. 37 | 38 | If your transport has no external dependencies, you can omit this variable.
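A typical guard looks like this (a sketch only; `paramiko` stands in
for whatever library your transport actually needs):

    try:
        import paramiko
    except ImportError:
        MISSING_DEPS = True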
39 | """ 40 | 41 | def parse_uri(uri_as_str): 42 | """Parse the specified URI into a dict. 43 | 44 | At a bare minimum, the dict must have `schema` member. 45 | """ 46 | return dict(schema=XXX_SCHEMA, ...) 47 | 48 | 49 | def open_uri(uri_as_str, mode, transport_params): 50 | """Return a file-like object pointing to the URI. 51 | 52 | Parameters: 53 | 54 | uri_as_str: str 55 | The URI to open 56 | mode: str 57 | Either "rb" or "wb". You don't need to implement text modes, 58 | `smart_open` does that for you, outside of the transport layer. 59 | transport_params: dict 60 | Any additional parameters to pass to the `open` function (see below). 61 | 62 | """ 63 | # 64 | # Parse the URI using parse_uri 65 | # Consolidate the parsed URI with transport_params, if needed 66 | # Pass everything to the open function (see below). 67 | # 68 | ... 69 | 70 | 71 | def open(..., mode, param1=None, param2=None, paramN=None): 72 | """This function does the hard work. 73 | 74 | The keyword parameters are the transport_params from the `open_uri` 75 | function. 76 | 77 | """ 78 | ... 79 | ``` 80 | 81 | Have a look at the existing mechanisms to see how they work. 82 | You may define other functions and classes as necessary for your implementation. 83 | 84 | Once your module is working, register it in the [smart_open.transport](smart_open/transport.py) submodule. 85 | The `register_transport()` function updates a mapping from schemes to the modules that implement functionality for them. 86 | 87 | Once you've registered your new transport module, the following will happen automagically: 88 | 89 | 1. `smart_open` will be able to open any URI supported by your module 90 | 2. The docstring for the `smart_open.open` function will contain a section 91 | detailing the parameters for your transport module. 92 | 3. The docstring for the `parse_uri` function will include the schemas and 93 | examples supported by your module. 94 | 95 | You can confirm the documentation changes by running: 96 | 97 | python -c 'help("smart_open")' 98 | 99 | and verify that documentation for your new submodule shows up. 100 | 101 | ### What's the difference between the `open_uri` and `open` functions? 102 | 103 | There are several key differences between the two. 104 | 105 | First, the parameters to `open_uri` are the same for _all transports_. 106 | On the other hand, the parameters to the `open` function can differ from transport to transport. 107 | 108 | Second, the responsibilities of the two functions are also different. 109 | The `open` function opens the remote object. 110 | The `open_uri` function deals with parsing transport-specific details out of the URI, and then delegates to `open`. 111 | 112 | The `open` function contains documentation for transport parameters. 113 | This documentation gets parsed by the `doctools` module and appears in various docstrings. 114 | 115 | Some of these differences are by design; others as a consequence of evolution. 116 | 117 | ## New compression mechanisms 118 | 119 | The compression layer is self-contained in the `smart_open.compression` submodule. 
120 | 121 | To add support for a new compressor: 122 | 123 | - Create a new function to handle your compression format (given an extension) 124 | - Add your compressor to the registry 125 | 126 | For example: 127 | 128 | ```python 129 | def _handle_xz(file_obj, mode): 130 | import lzma 131 | return lzma.LZMAFile(filename=file_obj, mode=mode, format=lzma.FORMAT_XZ) 132 | 133 | 134 | register_compressor('.xz', _handle_xz) 135 | ``` 136 | 137 | There are many compression formats out there, and supporting all of them is beyond the scope of `smart_open`. 138 | We want our code's functionality to cover the bare minimum required to satisfy 80% of our users. 139 | We leave the remaining 20% of users with the ability to deal with compression in their own code, using the trivial mechanism described above. 140 | 141 | Documentation 142 | ------------- 143 | 144 | Once you've contributed your extension, please add it to the documentation so that it is discoverable for other users. 145 | Some notable files: 146 | 147 | - setup.py: See the `description` keyword. Not all contributions will affect this. 148 | - README.rst 149 | - howto.md (if your extension solves a specific problem that doesn't get covered by other documentation) 150 | -------------------------------------------------------------------------------- /integration-tests/README.md: -------------------------------------------------------------------------------- 1 | This directory contains integration tests for smart_open. 2 | To run the tests, you need read/write access to an S3 bucket. 3 | Also, you need to install py.test and its benchmarks addon: 4 | 5 | pip install -r requirements.txt 6 | 7 | Then, to run the tests, run: 8 | 9 | SO_BUCKET=bucket SO_KEY=key py.test integration-tests/test_s3.py 10 | 11 | You may use any key name instead of "smart_open_test". 12 | It does not have to be an existing key. 13 | The tests will create temporary keys under `s3://SO_BUCKET/SO_KEY` and remove them at completion. 14 | 15 | The tests will take several minutes to complete. 16 | Each test will run several times to obtain summary statistics such as min, max, mean and median. 17 | This allows us to detect regressions in performance. 18 | Here is some example output (you need a wide screen to get the best of it): 19 | 20 | ``` 21 | $ SMART_OPEN_S3_URL=s3://bucket/smart_open_test py.test integration-tests/test_s3.py 22 | =============================================== test session starts ================================================ 23 | platform darwin -- Python 3.6.3, pytest-3.3.0, py-1.5.2, pluggy-0.6.0 24 | benchmark: 3.1.1 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000) 25 | rootdir: /Users/misha/git/smart_open, inifile: 26 | plugins: benchmark-3.1.1 27 | collected 6 items 28 | 29 | integration-tests/test_s3.py ...... 
[100%] 30 | 31 | 32 | --------------------------------------------------------------------------------------- benchmark: 6 tests -------------------------------------------------------------------------------------- 33 | Name (time in s) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations 34 | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- 35 | test_s3_readwrite_text 2.7593 (1.0) 3.4935 (1.0) 3.2203 (1.0) 0.3064 (1.0) 3.3202 (1.04) 0.4730 (1.0) 1;0 0.3105 (1.0) 5 1 36 | test_s3_readwrite_text_gzip 3.0242 (1.10) 4.6782 (1.34) 3.7079 (1.15) 0.8531 (2.78) 3.2001 (1.0) 1.5850 (3.35) 2;0 0.2697 (0.87) 5 1 37 | test_s3_readwrite_binary 3.0549 (1.11) 3.9062 (1.12) 3.5399 (1.10) 0.3516 (1.15) 3.4721 (1.09) 0.5532 (1.17) 2;0 0.2825 (0.91) 5 1 38 | test_s3_performance_gz 3.1885 (1.16) 5.2845 (1.51) 3.9298 (1.22) 0.8197 (2.68) 3.6974 (1.16) 0.9693 (2.05) 1;0 0.2545 (0.82) 5 1 39 | test_s3_readwrite_binary_gzip 3.3756 (1.22) 5.0423 (1.44) 4.1763 (1.30) 0.6381 (2.08) 4.0722 (1.27) 0.9209 (1.95) 2;0 0.2394 (0.77) 5 1 40 | test_s3_performance 7.6758 (2.78) 29.5266 (8.45) 18.8346 (5.85) 10.3003 (33.62) 21.1854 (6.62) 19.6234 (41.49) 3;0 0.0531 (0.17) 5 1 41 | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- 42 | 43 | Legend: 44 | Outliers: 1 Standard Deviation from Mean; 1.5 IQR (InterQuartile Range) from 1st Quartile and 3rd Quartile. 45 | OPS: Operations Per Second, computed as 1 / Mean 46 | ============================================ 6 passed in 285.14 seconds ============================================ 47 | ``` 48 | -------------------------------------------------------------------------------- /integration-tests/initialize_s3_bucket.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2020 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 7 | # 8 | """Prepare an S3 bucket for our integration tests. 9 | 10 | Once the bucket is initialized, the tests in test_s3_ported.py should pass. 11 | """ 12 | 13 | import gzip 14 | import io 15 | import sys 16 | 17 | import boto3 18 | 19 | 20 | def gzip_compress(data): 21 | # 22 | # gzip.compress does not exist under Py2 23 | # 24 | buf = io.BytesIO() 25 | with gzip.GzipFile(fileobj=buf, mode='wb') as fout: 26 | fout.write(data) 27 | return buf.getvalue() 28 | 29 | 30 | def _build_contents(): 31 | hello_bytes = u"hello wořld\nhow are you?".encode('utf8') 32 | yield 'hello.txt', hello_bytes 33 | yield 'multiline.txt', b'englishman\nin\nnew\nyork\n' 34 | yield 'hello.txt.gz', gzip_compress(hello_bytes) 35 | 36 | for i in range(100): 37 | key = 'iter_bucket/%02d.txt' % i 38 | body = '\n'.join("line%i%i" % (i, line_no) for line_no in range(10)).encode('utf8') 39 | yield key, body 40 | 41 | 42 | CONTENTS = dict(_build_contents()) 43 | 44 | 45 | def main(): 46 | bucket_name = sys.argv[1] 47 | 48 | bucket = boto3.resource('s3').Bucket(bucket_name) 49 | 50 | # 51 | # Assume the bucket exists. Creating it ourselves and dealing with 52 | # timing issues is too much of a PITA. 
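# Instead, we just empty it out, so that every run starts from a known, clean state.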
53 | # 54 | for key in bucket.objects.all(): 55 | key.delete() 56 | 57 | for (key, body) in CONTENTS.items(): 58 | bucket.put_object(Key=key, Body=body) 59 | 60 | 61 | if __name__ == '__main__': 62 | main() 63 | -------------------------------------------------------------------------------- /integration-tests/requirements.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | pytest_benchmark 3 | awscli 4 | -------------------------------------------------------------------------------- /integration-tests/test_184.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 7 | # 8 | import sys 9 | import time 10 | 11 | import smart_open 12 | 13 | open_fn = smart_open.smart_open 14 | # open_fn = open 15 | 16 | 17 | def report_time_iterate_rows(file_name, report_every=100000): 18 | start = time.time() 19 | last = start 20 | with open_fn(file_name, 'r') as f: 21 | for i, line in enumerate(f, start=1): 22 | if not (i % report_every): 23 | current = time.time() 24 | time_taken = current - last 25 | print('Time taken for %d rows: %.2f seconds, %.2f rows/s' % ( 26 | report_every, time_taken, report_every / time_taken)) 27 | last = current 28 | total = time.time() - start 29 | print('Total: %d rows, %.2f seconds, %.2f rows/s' % ( 30 | i, total, i / total)) 31 | 32 | 33 | report_time_iterate_rows(sys.argv[1]) 34 | -------------------------------------------------------------------------------- /integration-tests/test_207.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 
7 | # 8 | import os 9 | import sys 10 | import tempfile 11 | 12 | try: 13 | import numpy as np 14 | except ImportError: 15 | print("You really need numpy to proceed with this test") 16 | sys.exit(1) 17 | 18 | import smart_open 19 | 20 | 21 | def tofile(): 22 | dt = np.dtype([('time', [('min', int), ('sec', int)]), ('temp', float)]) 23 | x = np.zeros((1,), dtype=dt) 24 | 25 | with tempfile.NamedTemporaryFile(prefix='test_207', suffix='.dat', delete=False) as fout: 26 | x.tofile(fout.name) 27 | return fout.name 28 | 29 | 30 | def test(): 31 | try: 32 | path = tofile() 33 | with smart_open.smart_open(path, 'rb') as fin: 34 | loaded = np.fromfile(fin) 35 | del loaded 36 | return 0 37 | finally: 38 | os.unlink(path) 39 | return 1 40 | 41 | 42 | if __name__ == '__main__': 43 | sys.exit(test()) 44 | -------------------------------------------------------------------------------- /integration-tests/test_azure.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import io 3 | import os 4 | 5 | import azure.storage.blob 6 | 7 | from pytest import fixture 8 | 9 | import smart_open 10 | 11 | _AZURE_CONTAINER = os.environ.get('SO_AZURE_CONTAINER') 12 | _AZURE_STORAGE_CONNECTION_STRING = os.environ.get('AZURE_STORAGE_CONNECTION_STRING') 13 | _FILE_PREFIX = '%s://%s' % (smart_open.azure.SCHEME, _AZURE_CONTAINER) 14 | 15 | assert _AZURE_CONTAINER is not None, 'please set the SO_AZURE_CONTAINER environment variable' 16 | assert _AZURE_STORAGE_CONNECTION_STRING is not None, \ 17 | 'please set the AZURE_STORAGE_CONNECTION_STRING environment variable' 18 | 19 | 20 | @fixture 21 | def client(): 22 | # type: () -> azure.storage.blob.BlobServiceClient 23 | return azure.storage.blob.BlobServiceClient.from_connection_string(_AZURE_STORAGE_CONNECTION_STRING) 24 | 25 | 26 | def initialize_bucket(client): 27 | container_client = client.get_container_client(_AZURE_CONTAINER) 28 | blobs = container_client.list_blobs() 29 | for blob in blobs: 30 | container_client.delete_blob(blob=blob) 31 | 32 | 33 | def write_read(key, content, write_mode, read_mode, **kwargs): 34 | with smart_open.open(key, write_mode, **kwargs) as fout: 35 | fout.write(content) 36 | with smart_open.open(key, read_mode, **kwargs) as fin: 37 | return fin.read() 38 | 39 | 40 | def read_length_prefixed_messages(key, read_mode, **kwargs): 41 | result = io.BytesIO() 42 | 43 | with smart_open.open(key, read_mode, **kwargs) as fin: 44 | length_byte = fin.read(1) 45 | while len(length_byte): 46 | result.write(length_byte) 47 | msg = fin.read(ord(length_byte)) 48 | result.write(msg) 49 | length_byte = fin.read(1) 50 | return result.getvalue() 51 | 52 | 53 | def test_azure_readwrite_text(benchmark, client): 54 | initialize_bucket(client) 55 | 56 | key = _FILE_PREFIX + '/sanity.txt' 57 | text = 'с гранатою в кармане, с чекою в руке' 58 | actual = benchmark( 59 | write_read, key, text, 'w', 'r', encoding='utf-8', transport_params=dict(client=client) 60 | ) 61 | assert actual == text 62 | 63 | 64 | def test_azure_readwrite_text_gzip(benchmark, client): 65 | initialize_bucket(client) 66 | 67 | key = _FILE_PREFIX + '/sanity.txt.gz' 68 | text = 'не чайки здесь запели на знакомом языке' 69 | actual = benchmark( 70 | write_read, key, text, 'w', 'r', encoding='utf-8', transport_params=dict(client=client) 71 | ) 72 | assert actual == text 73 | 74 | 75 | def test_azure_readwrite_binary(benchmark, client): 76 | initialize_bucket(client) 77 | 78 | key = _FILE_PREFIX + '/sanity.txt' 79 | binary = b'this is a 
test' 80 | actual = benchmark(write_read, key, binary, 'wb', 'rb', transport_params=dict(client=client)) 81 | assert actual == binary 82 | 83 | 84 | def test_azure_readwrite_binary_gzip(benchmark, client): 85 | initialize_bucket(client) 86 | 87 | key = _FILE_PREFIX + '/sanity.txt.gz' 88 | binary = b'this is a test' 89 | actual = benchmark(write_read, key, binary, 'wb', 'rb', transport_params=dict(client=client)) 90 | assert actual == binary 91 | 92 | 93 | def test_azure_performance(benchmark, client): 94 | initialize_bucket(client) 95 | 96 | one_megabyte = io.BytesIO() 97 | for _ in range(1024*128): 98 | one_megabyte.write(b'01234567') 99 | one_megabyte = one_megabyte.getvalue() 100 | 101 | key = _FILE_PREFIX + '/performance.txt' 102 | actual = benchmark(write_read, key, one_megabyte, 'wb', 'rb', transport_params=dict(client=client)) 103 | assert actual == one_megabyte 104 | 105 | 106 | def test_azure_performance_gz(benchmark, client): 107 | initialize_bucket(client) 108 | 109 | one_megabyte = io.BytesIO() 110 | for _ in range(1024*128): 111 | one_megabyte.write(b'01234567') 112 | one_megabyte = one_megabyte.getvalue() 113 | 114 | key = _FILE_PREFIX + '/performance.txt.gz' 115 | actual = benchmark(write_read, key, one_megabyte, 'wb', 'rb', transport_params=dict(client=client)) 116 | assert actual == one_megabyte 117 | 118 | 119 | def test_azure_performance_small_reads(benchmark, client): 120 | initialize_bucket(client) 121 | 122 | ONE_MIB = 1024**2 123 | one_megabyte_of_msgs = io.BytesIO() 124 | msg = b'\x0f' + b'0123456789abcde' # a length-prefixed "message" 125 | for _ in range(0, ONE_MIB, len(msg)): 126 | one_megabyte_of_msgs.write(msg) 127 | one_megabyte_of_msgs = one_megabyte_of_msgs.getvalue() 128 | 129 | key = _FILE_PREFIX + '/many_reads_performance.bin' 130 | 131 | with smart_open.open(key, 'wb', transport_params=dict(client=client)) as fout: 132 | fout.write(one_megabyte_of_msgs) 133 | 134 | actual = benchmark( 135 | read_length_prefixed_messages, key, 'rb', buffering=ONE_MIB, transport_params=dict(client=client) 136 | ) 137 | assert actual == one_megabyte_of_msgs 138 | -------------------------------------------------------------------------------- /integration-tests/test_ftp.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | import gzip 3 | import pytest 4 | from smart_open import open 5 | import ssl 6 | from functools import partial 7 | 8 | # localhost has self-signed cert, see ci_helpers/helpers.sh:create_ftp_ftps_servers 9 | ssl.create_default_context = partial(ssl.create_default_context, cafile="/etc/vsftpd.pem") 10 | 11 | 12 | @pytest.fixture(params=[("ftp", 21), ("ftps", 90)]) 13 | def server_info(request): 14 | return request.param 15 | 16 | def test_nonbinary(server_info): 17 | server_type = server_info[0] 18 | port_num = server_info[1] 19 | file_contents = "Test Test \n new test \n another tests" 20 | appended_content1 = "Added \n to end" 21 | 22 | with open(f"{server_type}://user:123@localhost:{port_num}/file", "w") as f: 23 | f.write(file_contents) 24 | 25 | with open(f"{server_type}://user:123@localhost:{port_num}/file", "r") as f: 26 | read_contents = f.read() 27 | assert read_contents == file_contents 28 | 29 | with open(f"{server_type}://user:123@localhost:{port_num}/file", "a") as f: 30 | f.write(appended_content1) 31 | 32 | with open(f"{server_type}://user:123@localhost:{port_num}/file", "r") as f: 33 | read_contents = f.read() 34 | assert read_contents == file_contents + 
appended_content1 35 | 36 | def test_binary(server_info): 37 | server_type = server_info[0] 38 | port_num = server_info[1] 39 | file_contents = b"Test Test \n new test \n another tests" 40 | appended_content1 = b"Added \n to end" 41 | 42 | with open(f"{server_type}://user:123@localhost:{port_num}/file2", "wb") as f: 43 | f.write(file_contents) 44 | 45 | with open(f"{server_type}://user:123@localhost:{port_num}/file2", "rb") as f: 46 | read_contents = f.read() 47 | assert read_contents == file_contents 48 | 49 | with open(f"{server_type}://user:123@localhost:{port_num}/file2", "ab") as f: 50 | f.write(appended_content1) 51 | 52 | with open(f"{server_type}://user:123@localhost:{port_num}/file2", "rb") as f: 53 | read_contents = f.read() 54 | assert read_contents == file_contents + appended_content1 55 | 56 | def test_compression(server_info): 57 | server_type = server_info[0] 58 | port_num = server_info[1] 59 | file_contents = "Test Test \n new test \n another tests" 60 | appended_content1 = "Added \n to end" 61 | 62 | with open(f"{server_type}://user:123@localhost:{port_num}/file.gz", "w") as f: 63 | f.write(file_contents) 64 | 65 | with open(f"{server_type}://user:123@localhost:{port_num}/file.gz", "r") as f: 66 | read_contents = f.read() 67 | assert read_contents == file_contents 68 | 69 | with open(f"{server_type}://user:123@localhost:{port_num}/file.gz", "a") as f: 70 | f.write(appended_content1) 71 | 72 | with open(f"{server_type}://user:123@localhost:{port_num}/file.gz", "r") as f: 73 | read_contents = f.read() 74 | assert read_contents == file_contents + appended_content1 75 | 76 | # The FTP socket's makefile() returns a file whose name attribute is its fileno(), an int, 77 | # which can't be used to infer a compression extension, so the calls above would 78 | # silently skip compression (for both reading and writing) and still pass. 79 | # pytest suppresses the logging.warning('unable to transparently decompress...'), 80 | # so check here explicitly that the bytes on the server are gzip-compressed 81 | with open( 82 | f"{server_type}://user:123@localhost:{port_num}/file.gz", 83 | "rb", 84 | compression='disable', 85 | ) as f: 86 | read_contents = gzip.decompress(f.read()).decode() 87 | assert read_contents == file_contents + appended_content1 88 | 89 | def test_line_endings_non_binary(server_info): 90 | server_type = server_info[0] 91 | port_num = server_info[1] 92 | B_CLRF = b'\r\n' 93 | CLRF = '\r\n' 94 | file_contents = f"Test Test {CLRF} new test {CLRF} another tests{CLRF}" 95 | 96 | with open(f"{server_type}://user:123@localhost:{port_num}/file3", "w") as f: 97 | f.write(file_contents) 98 | 99 | with open(f"{server_type}://user:123@localhost:{port_num}/file3", "r") as f: 100 | for line in f: 101 | assert CLRF not in line 102 | 103 | with open(f"{server_type}://user:123@localhost:{port_num}/file3", "rb") as f: 104 | for line in f: 105 | assert B_CLRF in line 106 | 107 | def test_line_endings_binary(server_info): 108 | server_type = server_info[0] 109 | port_num = server_info[1] 110 | B_CLRF = b'\r\n' 111 | CLRF = '\r\n' 112 | file_contents = f"Test Test {CLRF} new test {CLRF} another tests{CLRF}".encode('utf-8') 113 | 114 | with open(f"{server_type}://user:123@localhost:{port_num}/file4", "wb") as f: 115 | f.write(file_contents) 116 | 117 | with open(f"{server_type}://user:123@localhost:{port_num}/file4", "r") as f: 118 | for line in f: 119 | assert CLRF not in line 120 | 121 | with open(f"{server_type}://user:123@localhost:{port_num}/file4", "rb") as f: 122 | for line in f: 123 |
assert B_CLRF in line 124 | -------------------------------------------------------------------------------- /integration-tests/test_gcs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import io 3 | import os 4 | import urllib.parse 5 | 6 | import google.cloud.storage 7 | 8 | import smart_open 9 | 10 | _GCS_URL = os.environ.get('SO_GCS_URL') 11 | assert _GCS_URL is not None, 'please set the SO_GCS_URL environment variable' 12 | 13 | 14 | def initialize_bucket(): 15 | client = google.cloud.storage.Client() 16 | parsed = urllib.parse.urlparse(_GCS_URL) 17 | bucket_name = parsed.netloc 18 | prefix = parsed.path 19 | bucket = client.get_bucket(bucket_name) 20 | blobs = bucket.list_blobs(prefix=prefix) 21 | for blob in blobs: 22 | blob.delete() 23 | 24 | 25 | def write_read(key, content, write_mode, read_mode, **kwargs): 26 | with smart_open.open(key, write_mode, **kwargs) as fout: 27 | fout.write(content) 28 | with smart_open.open(key, read_mode, **kwargs) as fin: 29 | return fin.read() 30 | 31 | 32 | def read_length_prefixed_messages(key, read_mode, **kwargs): 33 | result = io.BytesIO() 34 | 35 | with smart_open.open(key, read_mode, **kwargs) as fin: 36 | length_byte = fin.read(1) 37 | while len(length_byte): 38 | result.write(length_byte) 39 | msg = fin.read(ord(length_byte)) 40 | result.write(msg) 41 | length_byte = fin.read(1) 42 | return result.getvalue() 43 | 44 | 45 | def test_gcs_readwrite_text(benchmark): 46 | initialize_bucket() 47 | 48 | key = _GCS_URL + '/sanity.txt' 49 | text = 'с гранатою в кармане, с чекою в руке' 50 | actual = benchmark(write_read, key, text, 'w', 'r', encoding='utf-8') 51 | assert actual == text 52 | 53 | 54 | def test_gcs_readwrite_text_gzip(benchmark): 55 | initialize_bucket() 56 | 57 | key = _GCS_URL + '/sanity.txt.gz' 58 | text = 'не чайки здесь запели на знакомом языке' 59 | actual = benchmark(write_read, key, text, 'w', 'r', encoding='utf-8') 60 | assert actual == text 61 | 62 | 63 | def test_gcs_readwrite_binary(benchmark): 64 | initialize_bucket() 65 | 66 | key = _GCS_URL + '/sanity.txt' 67 | binary = b'this is a test' 68 | actual = benchmark(write_read, key, binary, 'wb', 'rb') 69 | assert actual == binary 70 | 71 | 72 | def test_gcs_readwrite_binary_gzip(benchmark): 73 | initialize_bucket() 74 | 75 | key = _GCS_URL + '/sanity.txt.gz' 76 | binary = b'this is a test' 77 | actual = benchmark(write_read, key, binary, 'wb', 'rb') 78 | assert actual == binary 79 | 80 | 81 | def test_gcs_performance(benchmark): 82 | initialize_bucket() 83 | 84 | one_megabyte = io.BytesIO() 85 | for _ in range(1024*128): 86 | one_megabyte.write(b'01234567') 87 | one_megabyte = one_megabyte.getvalue() 88 | 89 | key = _GCS_URL + '/performance.txt' 90 | actual = benchmark(write_read, key, one_megabyte, 'wb', 'rb') 91 | assert actual == one_megabyte 92 | 93 | 94 | def test_gcs_performance_gz(benchmark): 95 | initialize_bucket() 96 | 97 | one_megabyte = io.BytesIO() 98 | for _ in range(1024*128): 99 | one_megabyte.write(b'01234567') 100 | one_megabyte = one_megabyte.getvalue() 101 | 102 | key = _GCS_URL + '/performance.txt.gz' 103 | actual = benchmark(write_read, key, one_megabyte, 'wb', 'rb') 104 | assert actual == one_megabyte 105 | 106 | 107 | def test_gcs_performance_small_reads(benchmark): 108 | initialize_bucket() 109 | 110 | ONE_MIB = 1024**2 111 | one_megabyte_of_msgs = io.BytesIO() 112 | msg = b'\x0f' + b'0123456789abcde' # a length-prefixed "message" 113 | for _ in range(0, ONE_MIB, len(msg)): 114 | 
one_megabyte_of_msgs.write(msg) 115 | one_megabyte_of_msgs = one_megabyte_of_msgs.getvalue() 116 | 117 | key = _GCS_URL + '/many_reads_performance.bin' 118 | 119 | with smart_open.open(key, 'wb') as fout: 120 | fout.write(one_megabyte_of_msgs) 121 | 122 | actual = benchmark(read_length_prefixed_messages, key, 'rb', buffering=ONE_MIB) 123 | assert actual == one_megabyte_of_msgs 124 | -------------------------------------------------------------------------------- /integration-tests/test_hdfs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 7 | # 8 | """ 9 | Sample code for HDFS integration tests. 10 | Requires hadoop to be running on localhost, at the moment. 11 | """ 12 | import smart_open 13 | 14 | with smart_open.smart_open("hdfs://user/root/input/core-site.xml") as fin: 15 | print(fin.read()) 16 | 17 | with smart_open.smart_open("hdfs://user/root/input/test.txt") as fin: 18 | print(fin.read()) 19 | 20 | with smart_open.smart_open("hdfs://user/root/input/test.txt?user.name=root", 'wb') as fout: 21 | fout.write(b'hello world') 22 | -------------------------------------------------------------------------------- /integration-tests/test_http.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 7 | # 8 | from __future__ import unicode_literals 9 | 10 | import logging 11 | import unittest 12 | 13 | import smart_open 14 | 15 | GZIP_MAGIC = b'\x1f\x8b' 16 | BASE_URL = ('https://raw.githubusercontent.com/RaRe-Technologies/smart_open/' 17 | 'master/smart_open/tests/test_data/') 18 | 19 | 20 | class ReadTest(unittest.TestCase): 21 | def test_read_text(self): 22 | url = BASE_URL + 'crime-and-punishment.txt' 23 | with smart_open.smart_open(url, encoding='utf-8') as fin: 24 | text = fin.read() 25 | self.assertTrue(text.startswith('В начале июля, в чрезвычайно жаркое время,')) 26 | self.assertTrue(text.endswith('улизнуть, чтобы никто не видал.\n')) 27 | 28 | def test_read_binary(self): 29 | url = BASE_URL + 'crime-and-punishment.txt' 30 | with smart_open.smart_open(url, 'rb') as fin: 31 | text = fin.read() 32 | self.assertTrue(text.startswith('В начале июля, в чрезвычайно'.encode('utf-8'))) 33 | self.assertTrue(text.endswith('улизнуть, чтобы никто не видал.\n'.encode('utf-8'))) 34 | 35 | def test_read_gzip_text(self): 36 | url = BASE_URL + 'crime-and-punishment.txt.gz' 37 | with smart_open.smart_open(url, encoding='utf-8') as fin: 38 | text = fin.read() 39 | self.assertTrue(text.startswith('В начале июля, в чрезвычайно жаркое время,')) 40 | self.assertTrue(text.endswith('улизнуть, чтобы никто не видал.\n')) 41 | 42 | def test_read_gzip_binary(self): 43 | url = BASE_URL + 'crime-and-punishment.txt.gz' 44 | with smart_open.smart_open(url, 'rb', ignore_extension=True) as fin: 45 | binary = fin.read() 46 | self.assertTrue(binary.startswith(GZIP_MAGIC)) 47 | 48 | 49 | if __name__ == '__main__': 50 | logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG) 51 | unittest.main() 52 | -------------------------------------------------------------------------------- /integration-tests/test_minio.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 7 | # 8 | import logging 9 | import boto3 10 | 11 | from smart_open import open 12 | 13 | # 14 | # These are publicly available via play.min.io 15 | # 16 | KEY_ID = 'Q3AM3UQ867SPQQA43P2F' 17 | SECRET_KEY = 'zuf+tfteSlswRu7BJ86wekitnifILbZam1KYY3TG' 18 | ENDPOINT_URL = 'https://play.min.io:9000' 19 | 20 | 21 | def read_boto3(): 22 | """Read directly using boto3.""" 23 | session = get_minio_session() 24 | s3 = session.resource('s3', endpoint_url=ENDPOINT_URL) 25 | 26 | obj = s3.Object('smart-open-test', 'README.rst') 27 | data = obj.get()['Body'].read() 28 | logging.info('read %d bytes via boto3', len(data)) 29 | return data 30 | 31 | 32 | def read_smart_open(): 33 | url = 's3://Q3AM3UQ867SPQQA43P2F:zuf+tfteSlswRu7BJ86wekitnifILbZam1KYY3TG@play.min.io:9000@smart-open-test/README.rst' # noqa 34 | 35 | # 36 | # If the default region is not us-east-1, we need to construct our own 37 | # session. This is because smart_open will create a session in the default 38 | # region, which _must_ be us-east-1 for minio to work. 39 | # 40 | tp = {} 41 | if get_default_region() != 'us-east-1': 42 | logging.info('injecting custom session') 43 | tp['session'] = get_minio_session() 44 | with open(url, transport_params=tp) as fin: 45 | text = fin.read() 46 | logging.info('read %d characters via smart_open', len(text)) 47 | return text 48 | 49 | 50 | def get_minio_session(): 51 | return boto3.Session( 52 | region_name='us-east-1', 53 | aws_access_key_id=KEY_ID, 54 | aws_secret_access_key=SECRET_KEY, 55 | ) 56 | 57 | 58 | def get_default_region(): 59 | return boto3.Session().region_name 60 | 61 | 62 | def main(): 63 | logging.basicConfig(level=logging.INFO) 64 | from_boto3 = read_boto3() 65 | from_smart_open = read_smart_open() 66 | assert from_boto3.decode('utf-8') == from_smart_open 67 | 68 | 69 | if __name__ == '__main__': 70 | main() 71 | -------------------------------------------------------------------------------- /integration-tests/test_s3.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 7 | # 8 | 9 | from __future__ import unicode_literals 10 | import contextlib 11 | import io 12 | import os 13 | import random 14 | import string 15 | 16 | import boto3 17 | import smart_open 18 | 19 | _BUCKET = os.environ.get('SO_BUCKET') 20 | assert _BUCKET is not None, 'please set the SO_BUCKET environment variable' 21 | 22 | _KEY = os.environ.get('SO_KEY') 23 | assert _KEY is not None, 'please set the SO_KEY environment variable' 24 | 25 | 26 | # 27 | # https://stackoverflow.com/questions/13484726/safe-enough-8-character-short-unique-random-string 28 | # 29 | def _random_string(length=8): 30 | alphabet = string.ascii_lowercase + string.digits 31 | return ''.join(random.choices(alphabet, k=length)) 32 | 33 | 34 | @contextlib.contextmanager 35 | def temporary(): 36 | """Yields a URL that can be used for temporary writing. 37 | 38 | Removes all content under the URL when exiting.
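For illustration, usage looks roughly like this (a sketch; the b'hello'
payload is hypothetical):

    with temporary() as uri:
        with smart_open.open(uri, 'wb') as fout:
            fout.write(b'hello')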
39 | """ 40 | key = '%s/%s' % (_KEY, _random_string()) 41 | yield 's3://%s/%s' % (_BUCKET, key) 42 | boto3.resource('s3').Bucket(_BUCKET).objects.filter(Prefix=key).delete() 43 | 44 | 45 | def _test_case(function): 46 | def inner(benchmark): 47 | with temporary() as uri: 48 | return function(benchmark, uri) 49 | return inner 50 | 51 | 52 | def write_read(uri, content, write_mode, read_mode, encoding=None, s3_upload=None, **kwargs): 53 | write_params = dict(kwargs) 54 | write_params.update(s3_upload=s3_upload) 55 | with smart_open.open(uri, write_mode, encoding=encoding, transport_params=write_params) as fout: 56 | fout.write(content) 57 | with smart_open.open(uri, read_mode, encoding=encoding, transport_params=kwargs) as fin: 58 | actual = fin.read() 59 | return actual 60 | 61 | 62 | def read_length_prefixed_messages(uri, read_mode, encoding=None, **kwargs): 63 | with smart_open.open(uri, read_mode, encoding=encoding, transport_params=kwargs) as fin: 64 | actual = b'' 65 | length_byte = fin.read(1) 66 | while len(length_byte): 67 | actual += length_byte 68 | msg = fin.read(ord(length_byte)) 69 | actual += msg 70 | length_byte = fin.read(1) 71 | return actual 72 | 73 | 74 | @_test_case 75 | def test_s3_readwrite_text(benchmark, uri): 76 | text = 'с гранатою в кармане, с чекою в руке' 77 | actual = benchmark(write_read, uri, text, 'w', 'r', 'utf-8') 78 | assert actual == text 79 | 80 | 81 | @_test_case 82 | def test_s3_readwrite_text_gzip(benchmark, uri): 83 | text = 'не чайки здесь запели на знакомом языке' 84 | actual = benchmark(write_read, uri, text, 'w', 'r', 'utf-8') 85 | assert actual == text 86 | 87 | 88 | @_test_case 89 | def test_s3_readwrite_binary(benchmark, uri): 90 | binary = b'this is a test' 91 | actual = benchmark(write_read, uri, binary, 'wb', 'rb') 92 | assert actual == binary 93 | 94 | 95 | @_test_case 96 | def test_s3_readwrite_binary_gzip(benchmark, uri): 97 | binary = b'this is a test' 98 | actual = benchmark(write_read, uri, binary, 'wb', 'rb') 99 | assert actual == binary 100 | 101 | 102 | @_test_case 103 | def test_s3_performance(benchmark, uri): 104 | one_megabyte = io.BytesIO() 105 | for _ in range(1024*128): 106 | one_megabyte.write(b'01234567') 107 | one_megabyte = one_megabyte.getvalue() 108 | 109 | actual = benchmark(write_read, uri, one_megabyte, 'wb', 'rb') 110 | assert actual == one_megabyte 111 | 112 | 113 | @_test_case 114 | def test_s3_performance_gz(benchmark, uri): 115 | one_megabyte = io.BytesIO() 116 | for _ in range(1024*128): 117 | one_megabyte.write(b'01234567') 118 | one_megabyte = one_megabyte.getvalue() 119 | 120 | actual = benchmark(write_read, uri, one_megabyte, 'wb', 'rb') 121 | assert actual == one_megabyte 122 | 123 | 124 | @_test_case 125 | def test_s3_performance_small_reads(benchmark, uri): 126 | one_mib = 1024**2 127 | one_megabyte_of_msgs = io.BytesIO() 128 | msg = b'\x0f' + b'0123456789abcde' # a length-prefixed "message" 129 | for _ in range(0, one_mib, len(msg)): 130 | one_megabyte_of_msgs.write(msg) 131 | one_megabyte_of_msgs = one_megabyte_of_msgs.getvalue() 132 | 133 | with smart_open.open(uri, 'wb') as fout: 134 | fout.write(one_megabyte_of_msgs) 135 | 136 | actual = benchmark(read_length_prefixed_messages, uri, 'rb', buffer_size=one_mib) 137 | assert actual == one_megabyte_of_msgs 138 | 139 | 140 | @_test_case 141 | def test_s3_encrypted_file(benchmark, uri): 142 | text = 'с гранатою в кармане, с чекою в руке' 143 | s3_upload = {'ServerSideEncryption': 'AES256'} 144 | actual = benchmark(write_read, uri, text, 'w', 'r', 
'utf-8', s3_upload=s3_upload) 145 | assert actual == text 146 | -------------------------------------------------------------------------------- /integration-tests/test_s3_buffering.py: -------------------------------------------------------------------------------- 1 | from smart_open import open 2 | 3 | 4 | def read_bytes(url, limit): 5 | bytes_ = [] 6 | with open(url, 'rb') as fin: 7 | for i in range(limit): 8 | bytes_.append(fin.read(1)) 9 | 10 | return bytes_ 11 | 12 | 13 | def test(benchmark): 14 | # 15 | # This file is around 850MB. 16 | # 17 | url = ( 18 | 's3://commoncrawl/crawl-data/CC-MAIN-2019-51/segments/1575541319511.97' 19 | '/warc/CC-MAIN-20191216093448-20191216121448-00559.warc.gz' 20 | ) 21 | limit = 1000000 22 | bytes_ = benchmark(read_bytes, url, limit) 23 | assert len(bytes_) == limit 24 | -------------------------------------------------------------------------------- /integration-tests/test_s3_readline.py: -------------------------------------------------------------------------------- 1 | from smart_open import open 2 | 3 | 4 | def read_lines(url, limit): 5 | lines = [] 6 | with open(url, 'r', errors='ignore') as fin: 7 | for i, l in enumerate(fin): 8 | if i == limit: 9 | break 10 | lines.append(l) 11 | 12 | return lines 13 | 14 | 15 | def test(benchmark): 16 | # 17 | # This file is around 850MB. 18 | # 19 | url = ( 20 | 's3://commoncrawl/crawl-data/CC-MAIN-2019-51/segments/1575541319511.97' 21 | '/warc/CC-MAIN-20191216093448-20191216121448-00559.warc.gz' 22 | ) 23 | limit = 1000000 24 | lines = benchmark(read_lines, url, limit) 25 | assert len(lines) == limit 26 | -------------------------------------------------------------------------------- /integration-tests/test_ssh.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2022 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 
7 | # 8 | 9 | import os 10 | import tempfile 11 | import pytest 12 | 13 | import smart_open 14 | import smart_open.ssh 15 | 16 | 17 | def explode(*args, **kwargs): 18 | raise RuntimeError("this function should never have been called") 19 | 20 | 21 | @pytest.mark.skipif("SMART_OPEN_SSH" not in os.environ, reason="this test only works on the dev machine") 22 | def test(): 23 | with smart_open.open("ssh://misha@localhost/Users/misha/git/smart_open/README.rst") as fin: 24 | readme = fin.read() 25 | 26 | assert 'smart_open — utils for streaming large files in Python' in readme 27 | 28 | # 29 | # Ensure the cache is being used 30 | # 31 | assert ('localhost', 'misha') in smart_open.ssh._SSH 32 | 33 | try: 34 | connect_ssh = smart_open.ssh._connect_ssh 35 | smart_open.ssh._connect_ssh = explode 36 | 37 | with smart_open.open("ssh://misha@localhost/Users/misha/git/smart_open/howto.md") as fin: 38 | howto = fin.read() 39 | 40 | assert 'How-to Guides' in howto 41 | finally: 42 | smart_open.ssh._connect_ssh = connect_ssh 43 | -------------------------------------------------------------------------------- /integration-tests/test_version_id.py: -------------------------------------------------------------------------------- 1 | """Tests the version_id transport parameter for S3 against real S3.""" 2 | 3 | import boto3 4 | from smart_open import open 5 | 6 | BUCKET, KEY = 'smart-open-versioned', 'demo.txt' 7 | """We have a publicly-readable bucket with a versioned object.""" 8 | 9 | URL = 's3://%s/%s' % (BUCKET, KEY) 10 | 11 | 12 | def assert_equal(a, b): 13 | assert a == b, '%r != %r' % (a, b) 14 | 15 | 16 | def main(): 17 | versions = [ 18 | v.id for v in boto3.resource('s3').Bucket(BUCKET).object_versions.filter(Prefix=KEY) 19 | ] 20 | expected_versions = [ 21 | 'KiQpZPsKI5Dm2oJZy_RzskTOtl2snjBg', 22 | 'N0GJcE3TQCKtkaS.gF.MUBZS85Gs3hzn', 23 | ] 24 | assert_equal(versions, expected_versions) 25 | 26 | contents = [ 27 | open(URL, transport_params={'version_id': v}).read() 28 | for v in versions 29 | ] 30 | expected_contents = ['second version\n', 'first version\n'] 31 | assert_equal(contents, expected_contents) 32 | 33 | with open(URL) as fin: 34 | most_recent_contents = fin.read() 35 | assert_equal(most_recent_contents, expected_contents[0]) 36 | 37 | print('OK') 38 | 39 | 40 | if __name__ == '__main__': 41 | main() 42 | -------------------------------------------------------------------------------- /integration-tests/test_webhdfs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 7 | # 8 | """ 9 | Sample code for WebHDFS integration tests. 10 | To run it, a working WebHDFS deployment reachable from your network is 11 | needed: simply set the SO_WEBHDFS_BASE_URL environment variable to a 12 | WebHDFS URL you have write access to.
13 | 14 | For example, on Amazon EMR WebHDFS is accessible on driver port 14000, so 15 | it may look like: 16 | 17 | $ export SO_WEBHDFS_BASE_URL=webhdfs://hadoop@your-emr-driver:14000/tmp/ 18 | $ py.test integration-tests/test_webhdfs.py 19 | """ 20 | import json 21 | import os 22 | import smart_open 23 | from smart_open.webhdfs import WebHdfsException 24 | import pytest 25 | 26 | _SO_WEBHDFS_BASE_URL = os.environ.get("SO_WEBHDFS_BASE_URL") 27 | assert ( 28 | _SO_WEBHDFS_BASE_URL is not None 29 | ), "please set the SO_WEBHDFS_BASE_URL environment variable" 30 | 31 | 32 | def make_url(path): 33 | return "{base_url}/{path}".format( 34 | base_url=_SO_WEBHDFS_BASE_URL.rstrip("/"), path=path.lstrip("/") 35 | ) 36 | 37 | 38 | def test_write_and_read(): 39 | with smart_open.open(make_url("test2.txt"), "w") as f: 40 | f.write("write_test\n") 41 | with smart_open.open(make_url("test2.txt"), "r") as f: 42 | assert f.read() == "write_test\n" 43 | 44 | 45 | def test_binary_write_and_read(): 46 | with smart_open.open(make_url("test3.txt"), "wb") as f: 47 | f.write(b"binary_write_test\n") 48 | with smart_open.open(make_url("test3.txt"), "rb") as f: 49 | assert f.read() == b"binary_write_test\n" 50 | 51 | 52 | def test_not_found(): 53 | with pytest.raises(WebHdfsException) as exc_info: 54 | with smart_open.open(make_url("not_existing"), "r") as f: 55 | assert f.read() 56 | assert exc_info.value.status_code == 404 57 | 58 | 59 | def test_quoted_path(): 60 | with smart_open.open(make_url("test_%40_4.txt"), "w") as f: 61 | f.write("write_test\n") 62 | 63 | with smart_open.open(make_url("?op=LISTSTATUS"), "r") as f: 64 | data = json.load(f) 65 | filenames = [ 66 | entry["pathSuffix"] for entry in data["FileStatuses"]["FileStatus"] 67 | ] 68 | assert "test_@_4.txt" in filenames 69 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.pytest.ini_options] 2 | testpaths = ["smart_open"] 3 | -------------------------------------------------------------------------------- /release/README.md: -------------------------------------------------------------------------------- 1 | # Release Scripts 2 | 3 | This subdirectory contains various scripts for making a smart_open release. 4 | 5 | ## Prerequisites 6 | 7 | You need a GNU-like environment to run these scripts. I perform the releases 8 | using Ubuntu 18.04, but other OSes such as macOS should also work. The 9 | prerequisites are minimal: 10 | 11 | - bash 12 | - git with authentication set up (e.g. via ssh-agent) 13 | - virtualenv 14 | - pip 15 | 16 | All of the above are generally freely available, e.g. installable via apt in Ubuntu. 17 | 18 | ## Release Procedure 19 | 20 | First, check that the [latest commit](https://github.com/RaRe-Technologies/smart_open/commits/master) passed all CI. 21 | 22 | For the subsequent steps to work, you will need to be in the top-level directory of the repo (e.g. /home/misha/git/smart_open). 23 | 24 | Prepare the release, replacing 2.3.4 with the actual version of the new release: 25 | 26 | bash release/prepare.sh 2.3.4 27 | 28 | This will create a local release branch. 29 | Look around the branch and make sure everything is in order. 30 | Checklist: 31 | 32 | - [ ] Does smart_open/version.py contain the correct version number for the release? 33 | - [ ] Does the CHANGELOG.md contain a section detailing the new release?
34 | - [ ] Are there any PRs that should be in CHANGELOG.md, but currently aren't? 35 | 36 | If anything is out of order, make the appropriate changes and commit them to the release branch before proceeding. 37 | 38 | **This is the point of no return**. 39 | **Once you're happy with the release branch**, run: 40 | 41 | bash release/merge.sh 42 | 43 | Congratulations, at this stage, you are done! 44 | 45 | ## Troubleshooting 46 | 47 | Ideally, our CI should save you from major boo-boos along the way. 48 | If the build is broken, fix it before even thinking about doing a release. 49 | 50 | If anything is wrong with the local release branch (before you call merge.sh), for example: 51 | 52 | - Typo in CHANGELOG.md 53 | - Missing entries in CHANGELOG.md 54 | - Wrong version.py number 55 | 56 | then just fix it in the release branch before moving on. 57 | 58 | Otherwise, it's too late to fix anything for the current release. 59 | Make a bugfix release to fix the problem. 60 | -------------------------------------------------------------------------------- /release/annotate_pr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Helper script for including change log entries in an open PR. 3 | 4 | Automatically constructs the change log entry from the PR title. 5 | Copies the entry to the window manager clipboard. 6 | Opens the change log belonging to the specific PR in a browser window. 7 | All you have to do is paste and click "commit changes". 8 | """ 9 | import json 10 | import sys 11 | import webbrowser 12 | 13 | import smart_open 14 | 15 | 16 | def copy_to_clipboard(text): 17 | try: 18 | import pyperclip 19 | except ImportError: 20 | print('pyperclip is missing.', file=sys.stderr) 21 | print('copy-paste the following text manually:', file=sys.stderr) 22 | print('\t', text, file=sys.stderr) 23 | else: 24 | pyperclip.copy(text) 25 | 26 | 27 | prid = int(sys.argv[1]) 28 | url = "https://api.github.com/repos/RaRe-Technologies/smart_open/pulls/%d" % prid 29 | with smart_open.open(url) as fin: 30 | prinfo = json.load(fin) 31 | 32 | prinfo['user_login'] = prinfo['user']['login'] 33 | prinfo['user_html_url'] = prinfo['user']['html_url'] 34 | text = '- %(title)s (PR [#%(number)s](%(html_url)s), [@%(user_login)s](%(user_html_url)s))' % prinfo 35 | copy_to_clipboard(text) 36 | 37 | prinfo['head_repo_html_url'] = prinfo['head']['repo']['html_url'] 38 | prinfo['head_ref'] = prinfo['head']['ref'] 39 | edit_url = '%(head_repo_html_url)s/edit/%(head_ref)s/CHANGELOG.md' % prinfo 40 | webbrowser.open(edit_url) 41 | -------------------------------------------------------------------------------- /release/check_preamble.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 7 | # 8 | 9 | """Checks preambles of Python script files. 10 | 11 | We want to ensure they all contain the appropriate license and copyright. 12 | 13 | For the purposes of this script, the *preamble* is defined as the first 14 | lines of the file starting with a hash (#). Any line that does not start 15 | with a hash ends the preamble. 16 | 17 | Usage:: 18 | 19 | python check_preamble.py --replace /path/to/template.py script.py 20 | 21 | The above command reads the preamble from ``template.py``, and then copies 22 | that preamble into ``script.py``. 
If ``script.py`` already contains a 23 | preamble, then the existing preamble will be replaced **entirely**. 24 | 25 | Processing entire subdirectories with one command:: 26 | 27 | find subdir1 subdir2 -iname "*.py" | xargs -n 1 python check_preamble.py --replace template.py 28 | 29 | """ 30 | import argparse 31 | import logging 32 | import os 33 | import sys 34 | 35 | 36 | def extract_preamble(fin): 37 | end_preamble = False 38 | preamble, body = [], [] 39 | 40 | for line in fin: 41 | if end_preamble: 42 | body.append(line) 43 | elif line.startswith('#'): 44 | preamble.append(line) 45 | else: 46 | end_preamble = True 47 | body.append(line) 48 | 49 | return preamble, body 50 | 51 | 52 | def main(): 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument('path', help='the path of the file to check') 55 | parser.add_argument('--replace', help='replace the preamble with the one from this file') 56 | parser.add_argument('--loglevel', default=logging.INFO) 57 | args = parser.parse_args() 58 | 59 | logging.basicConfig(level=args.loglevel) 60 | 61 | with open(args.path) as fin: 62 | preamble, body = extract_preamble(fin) 63 | 64 | for line in preamble: 65 | logging.info('%s: %s', args.path, line.rstrip()) 66 | 67 | if not args.replace: 68 | sys.exit(0) 69 | 70 | with open(args.replace) as fin: 71 | preamble, _ = extract_preamble(fin) 72 | 73 | if os.access(args.path, os.X_OK): 74 | preamble.insert(0, '#!/usr/bin/env python\n') 75 | 76 | with open(args.path, 'w') as fout: 77 | for line in preamble + body: 78 | fout.write(line) 79 | 80 | 81 | if __name__ == '__main__': 82 | main() 83 | -------------------------------------------------------------------------------- /release/doctest.sh: -------------------------------------------------------------------------------- 1 | script_dir="$(dirname "${BASH_SOURCE[0]}")" 2 | 3 | export AWS_ACCESS_KEY_ID=$(aws --profile smart_open configure get aws_access_key_id) 4 | export AWS_SECRET_ACCESS_KEY=$(aws --profile smart_open configure get aws_secret_access_key) 5 | 6 | # 7 | # Using the current environment, which has smart_open installed. 8 | # 9 | cd "$script_dir/.." 10 | python -m doctest README.rst 11 | -------------------------------------------------------------------------------- /release/hijack_pr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Hijack a PR to add commits as a maintainer. 3 | 4 | This is a two-step process: 5 | 6 | 1. Add a git remote that points to the contributor's repo 7 | 2. Check out the actual contribution by reference 8 | 9 | As a maintainer, you can add changes by making new commits and pushing them 10 | back to the remote. 
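A usage sketch, assuming a hypothetical PR number of 123:

    python release/hijack_pr.py 123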
11 | """ 12 | import json 13 | import subprocess 14 | import sys 15 | 16 | import smart_open 17 | 18 | prid = int(sys.argv[1]) 19 | url = f"https://api.github.com/repos/RaRe-Technologies/smart_open/pulls/{prid}" 20 | with smart_open.open(url) as fin: 21 | prinfo = json.load(fin) 22 | 23 | user = prinfo['head']['user']['login'] 24 | ssh_url = prinfo['head']['repo']['ssh_url'] 25 | 26 | remotes = subprocess.check_output(['git', 'remote']).strip().decode('utf-8').split('\n') 27 | if user not in remotes: 28 | subprocess.check_call(['git', 'remote', 'add', user, ssh_url]) 29 | 30 | subprocess.check_call(['git', 'fetch', user]) 31 | 32 | ref = prinfo['head']['ref'] 33 | subprocess.check_call(['git', 'checkout', f'{user}/{ref}']) 34 | subprocess.check_call(['git', 'switch', '-c', f'{ref}']) 35 | -------------------------------------------------------------------------------- /release/merge.sh: -------------------------------------------------------------------------------- 1 | # 2 | # This script performs the following tasks: 3 | # 4 | # - Merges the current release branch into master 5 | # - Applies a tag to master 6 | # - Merges 7 | # - Pushes the updated master branch and its tag to upstream 8 | # 9 | # - develop: Our development branch. We merge all PRs into this branch. 10 | # - release-$version: A local branch containing commits specific to this release. 11 | # This is a local-only branch, we never push this anywhere. 12 | # - master: Our "clean" release branch. Contains tags. 13 | # 14 | # The relationships between the three branches are illustrated below: 15 | # 16 | # github.com PRs 17 | # \ 18 | # develop --+--+----------------------------------+--- 19 | # \ / 20 | # (new branch) \ commits (CHANGELOG.md, etc) / 21 | # \ v / 22 | # release ---*-----X (delete branch) / (merge 2) 23 | # \ / 24 | # (merge 1) \ TAG / 25 | # \ v / 26 | # master -------------------+------*-----+----------- 27 | # 28 | # Use it like this: 29 | # 30 | # bash release/merge.sh 31 | # 32 | # Expects smart_open/version.py to be correctly incremented for the new release. 33 | # 34 | set -euo pipefail 35 | 36 | cd "$(dirname "${BASH_SOURCE[0]}")/.." 37 | 38 | version="$(python smart_open/version.py)" 39 | 40 | read -p "Push version $version to github.com and PyPI? yes or no: " reply 41 | if [ "$reply" != "yes" ] 42 | then 43 | echo "aborted by user" 44 | exit 1 45 | fi 46 | 47 | # 48 | # Delete the local develop branch in case one is left lying around. 49 | # 50 | set +e 51 | git branch -D develop 52 | git branch -D master 53 | set -e 54 | 55 | git checkout upstream/master -b master 56 | git merge --no-ff release-${version} 57 | git tag -a "v${version}" -m "v${version}" 58 | 59 | git checkout upstream/develop -b develop 60 | git merge --no-ff master 61 | 62 | # 63 | # N.B. these push steps are non-reversible. 64 | # 65 | git checkout master 66 | git push --tags upstream master 67 | 68 | git checkout develop 69 | dev_version="$version.dev0" 70 | sed --in-place="" -e s/$(python smart_open/version.py)/$dev_version/ smart_open/version.py 71 | git commit smart_open/version.py -m "bump version to $dev_version" 72 | git push upstream develop 73 | 74 | python release/update_release_notes.py "$version" 75 | -------------------------------------------------------------------------------- /release/prepare.sh: -------------------------------------------------------------------------------- 1 | # 2 | # Prepare a new release of smart_open. 
Use it like this: 3 | # 4 | # bash release/prepare.sh 1.2.3 5 | # 6 | # where 1.2.3 is the new version to release. 7 | # 8 | # Does the following: 9 | # 10 | # - Creates a clean virtual environment 11 | # - Creates a local release git branch 12 | # - Bumps VERSION accordingly 13 | # - Opens CHANGELOG.md for editing, commits updates 14 | # 15 | # Once you're happy, run merge.sh to continue with the release. 16 | # 17 | set -euxo pipefail 18 | 19 | version="$1" 20 | echo "version: $version" 21 | 22 | script_dir="$(dirname "${BASH_SOURCE[0]}")" 23 | cd "$script_dir/.." 24 | 25 | git fetch upstream 26 | 27 | # 28 | # Delete the release branch in case one is left lying around. 29 | # 30 | git checkout upstream/develop 31 | set +e 32 | git branch -D release-"$version" 33 | set -e 34 | 35 | git checkout upstream/develop -b release-"$version" 36 | sed --in-place="" -e "s/$(python smart_open/version.py)/$version/" smart_open/version.py 37 | git commit smart_open/version.py -m "bump version to $version" 38 | 39 | echo "Next, update CHANGELOG.md." 40 | echo "Consider running summarize_pr.sh for each PR merged since the last release." 41 | read -p "Press Enter to continue..." 42 | 43 | ${EDITOR:-vim} CHANGELOG.md 44 | set +e 45 | git commit CHANGELOG.md -m "updated CHANGELOG.md for version $version" 46 | set -e 47 | 48 | echo "Have a look at the current branch, and if all looks good, run merge.sh" 49 | -------------------------------------------------------------------------------- /release/update_help_txt.sh: -------------------------------------------------------------------------------- 1 | script_dir="$(dirname "${BASH_SOURCE[0]}")" 2 | 3 | # 4 | # Using the current environment, which has smart_open installed. 5 | # 6 | cd "$script_dir/.." 7 | python -c 'help("smart_open")' > help.txt 8 | git commit help.txt -m "updated help.txt" 9 | -------------------------------------------------------------------------------- /release/update_release_notes.py: -------------------------------------------------------------------------------- 1 | """Helper script for updating the release notes. 2 | 3 | Copies the change log to the window manager clipboard. 4 | Opens the release notes using the browser. 5 | All you have to do is paste and click "commit changes". 
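A usage sketch (the version number is illustrative; merge.sh invokes this
script with the actual release version):

    python release/update_release_notes.py 2.3.4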
6 | """ 7 | import os 8 | import sys 9 | import webbrowser 10 | 11 | version = sys.argv[1] 12 | curr_dir = os.path.dirname(__file__) 13 | 14 | 15 | def copy_to_clipboard(text): 16 | try: 17 | import pyperclip 18 | except ImportError: 19 | print('pyperclip is missing.', file=sys.stderr) 20 | print('copy-paste the contents of CHANGELOG.md manually', file=sys.stderr) 21 | else: 22 | pyperclip.copy(text) 23 | 24 | 25 | with open(os.path.join(curr_dir, '../CHANGELOG.md')) as fin: 26 | copy_to_clipboard(fin.read()) 27 | 28 | 29 | url = "https://github.com/RaRe-Technologies/smart_open/releases/tag/v%s" % version 30 | webbrowser.open(url) 31 | -------------------------------------------------------------------------------- /sampledata/hello.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piskvorky/smart_open/4bed3466b9db62fefa32249cfce40a479b586866/sampledata/hello.zip -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2015 Radim Rehurek 5 | # 6 | # This code is distributed under the terms and conditions 7 | # from the MIT License (MIT). 8 | 9 | 10 | import io 11 | import os 12 | 13 | from setuptools import setup, find_packages 14 | 15 | 16 | def _get_version(): 17 | curr_dir = os.path.dirname(os.path.abspath(__file__)) 18 | with open(os.path.join(curr_dir, 'smart_open', 'version.py')) as fin: 19 | line = fin.readline().strip() 20 | parts = line.split(' ') 21 | assert len(parts) == 3 22 | assert parts[0] == '__version__' 23 | assert parts[1] == '=' 24 | return parts[2].strip('\'"') 25 | 26 | 27 | # 28 | # We cannot do "from smart_open.version import __version__" because that will 29 | # require the dependencies for smart_open to already be in place, and that is 30 | # not necessarily the case when running setup.py for the first time. 
31 | # 32 | __version__ = _get_version() 33 | 34 | 35 | def read(fname): 36 | return io.open(os.path.join(os.path.dirname(__file__), fname), encoding='utf-8').read() 37 | 38 | base_deps = ['wrapt'] 39 | aws_deps = ['boto3'] 40 | gcs_deps = ['google-cloud-storage>=2.6.0'] 41 | azure_deps = ['azure-storage-blob', 'azure-common', 'azure-core'] 42 | http_deps = ['requests'] 43 | ssh_deps = ['paramiko'] 44 | zst_deps = ['zstandard'] 45 | 46 | all_deps = aws_deps + gcs_deps + azure_deps + http_deps + ssh_deps + zst_deps 47 | tests_require = all_deps + [ 48 | 'moto[server]', 49 | 'responses', 50 | 'pytest', 51 | 'pytest-rerunfailures', 52 | 'pytest_benchmark', 53 | 'awscli', 54 | 'pyopenssl', 55 | 'numpy', 56 | ] 57 | 58 | setup( 59 | name='smart_open', 60 | version=__version__, 61 | description='Utils for streaming large files (S3, HDFS, GCS, Azure Blob Storage, gzip, bz2...)', 62 | long_description=read('README.rst'), 63 | packages=find_packages(exclude=["smart_open.tests*"]), 64 | author='Radim Rehurek', 65 | author_email='me@radimrehurek.com', 66 | maintainer='Radim Rehurek', 67 | maintainer_email='me@radimrehurek.com', 68 | 69 | url='https://github.com/piskvorky/smart_open', 70 | download_url='http://pypi.python.org/pypi/smart_open', 71 | 72 | keywords='file streaming, s3, hdfs, gcs, azure blob storage', 73 | 74 | license='MIT', 75 | platforms='any', 76 | 77 | install_requires=base_deps, 78 | tests_require=tests_require, 79 | extras_require={ 80 | 'test': tests_require, 81 | 's3': aws_deps, 82 | 'gcs': gcs_deps, 83 | 'azure': azure_deps, 84 | 'all': all_deps, 85 | 'http': http_deps, 86 | 'webhdfs': http_deps, 87 | 'ssh': ssh_deps, 88 | 'zst': zst_deps, 89 | }, 90 | python_requires=">=3.7,<4.0", 91 | 92 | test_suite="smart_open.tests", 93 | 94 | classifiers=[ 95 | 'Development Status :: 5 - Production/Stable', 96 | 'Environment :: Console', 97 | 'Intended Audience :: Developers', 98 | 'License :: OSI Approved :: MIT License', 99 | 'Operating System :: OS Independent', 100 | 'Programming Language :: Python :: 3.7', 101 | 'Programming Language :: Python :: 3.8', 102 | 'Programming Language :: Python :: 3.9', 103 | 'Programming Language :: Python :: 3.10', 104 | 'Programming Language :: Python :: 3.11', 105 | 'Programming Language :: Python :: 3.12', 106 | 'Programming Language :: Python :: 3.13', 107 | 'Topic :: System :: Distributed Computing', 108 | 'Topic :: Database :: Front-Ends', 109 | ], 110 | ) 111 | -------------------------------------------------------------------------------- /smart_open/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 7 | # 8 | 9 | """ 10 | Utilities for streaming to/from several file-like data storages: S3 / HDFS / local 11 | filesystem / compressed files, and many more, using a simple, Pythonic API. 12 | 13 | The streaming makes heavy use of generators and pipes, to avoid loading 14 | full file contents into memory, allowing work with arbitrarily large files. 
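For example, a minimal sketch (the local path here is hypothetical; any
supported URI, such as an s3:// or hdfs:// one, works the same way):

    from smart_open import open

    with open('example.txt', 'w') as fout:
        fout.write('hello world')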
15 | 16 | The main functions are: 17 | 18 | * `open()`, which opens the given file for reading/writing 19 | * `parse_uri()` 20 | * `s3_iter_bucket()`, which goes over all keys in an S3 bucket in parallel 21 | * `register_compressor()`, which registers callbacks for transparent compressor handling 22 | 23 | """ 24 | 25 | import logging 26 | 27 | # 28 | # Prevent regression of #474 and #475 29 | # 30 | logger = logging.getLogger(__name__) 31 | logger.addHandler(logging.NullHandler()) 32 | 33 | from smart_open import version  # noqa: E402 34 | from .smart_open_lib import open, parse_uri, smart_open, register_compressor  # noqa: E402 35 | 36 | _WARNING = """smart_open.s3_iter_bucket is deprecated and will stop functioning 37 | in a future version. Please import iter_bucket from the smart_open.s3 module instead: 38 | 39 | from smart_open.s3 import iter_bucket as s3_iter_bucket 40 | 41 | """ 42 | _WARNED = False 43 | 44 | 45 | def s3_iter_bucket( 46 | bucket_name, 47 | prefix='', 48 | accept_key=None, 49 | key_limit=None, 50 | workers=16, 51 | retries=3, 52 | **session_kwargs 53 | ): 54 | """Deprecated. Use smart_open.s3.iter_bucket instead.""" 55 | global _WARNED 56 | from .s3 import iter_bucket 57 | if not _WARNED: 58 | logger.warning(_WARNING) 59 | _WARNED = True 60 | return iter_bucket( 61 | bucket_name=bucket_name, 62 | prefix=prefix, 63 | accept_key=accept_key, 64 | key_limit=key_limit, 65 | workers=workers, 66 | retries=retries, 67 | session_kwargs=session_kwargs 68 | ) 69 | 70 | 71 | __all__ = [ 72 | 'open', 73 | 'parse_uri', 74 | 'register_compressor', 75 | 's3_iter_bucket', 76 | 'smart_open', 77 | ] 78 | 79 | __version__ = version.__version__ 80 | -------------------------------------------------------------------------------- /smart_open/bytebuffer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 7 | # 8 | """Implements ByteBuffer class for amortizing network transfer overhead.""" 9 | 10 | import io 11 | 12 | 13 | class ByteBuffer(object): 14 | """Implements a byte buffer that allows callers to read data with minimal 15 | copying, and has a fast __len__ method. The buffer is parametrized by its 16 | chunk_size, which is the number of bytes that it will read in from the 17 | supplied reader or iterable when the buffer is being filled. As the 18 | primary use case for this buffer is to amortize the overhead costs of 19 | transferring data over the network (rather than capping memory 20 | consumption), it leads to more predictable performance to always read the 21 | same amount of bytes each time the buffer is filled, hence the chunk_size 22 | parameter instead of some fixed capacity. 23 | 24 | The bytes are stored in a bytestring, and previously-read bytes are freed 25 | when the buffer is next filled (by slicing the bytestring into a smaller 26 | copy). 27 | 28 | Example 29 | ------- 30 | 31 | Note that the doctest below only passes in Python 3, due to the bytestring 32 | literals in the expected values.
33 | 34 | >>> buf = ByteBuffer(chunk_size = 8) 35 | >>> message_bytes = iter([b'Hello, W', b'orld!']) 36 | >>> buf.fill(message_bytes) 37 | 8 38 | >>> len(buf) # only chunk_size bytes are filled 39 | 8 40 | >>> buf.peek() 41 | b'Hello, W' 42 | >>> len(buf) # peek() does not change read position 43 | 8 44 | >>> buf.read(6) 45 | b'Hello,' 46 | >>> len(buf) # read() does change read position 47 | 2 48 | >>> buf.fill(message_bytes) 49 | 5 50 | >>> buf.read() 51 | b' World!' 52 | >>> len(buf) 53 | 0 54 | """ 55 | 56 | def __init__(self, chunk_size=io.DEFAULT_BUFFER_SIZE): 57 | """Create a ByteBuffer instance that reads chunk_size bytes when filled. 58 | Note that the buffer has no maximum size. 59 | 60 | Parameters 61 | ---------- 62 | chunk_size: int, optional 63 | The number of bytes that will be read from the supplied reader 64 | or iterable when filling the buffer. 65 | """ 66 | self._chunk_size = chunk_size 67 | self.empty() 68 | 69 | def __len__(self): 70 | """Return the number of unread bytes in the buffer as an int""" 71 | return len(self._bytes) - self._pos 72 | 73 | def read(self, size=-1): 74 | """Read bytes from the buffer and advance the read position. Returns 75 | the bytes in a bytestring. 76 | 77 | Parameters 78 | ---------- 79 | size: int, optional 80 | Maximum number of bytes to read. If negative or not supplied, read 81 | all unread bytes in the buffer. 82 | 83 | Returns 84 | ------- 85 | bytes 86 | """ 87 | part = self.peek(size) 88 | self._pos += len(part) 89 | return part 90 | 91 | def peek(self, size=-1): 92 | """Get bytes from the buffer without advancing the read position. 93 | Returns the bytes in a bytestring. 94 | 95 | Parameters 96 | ---------- 97 | size: int, optional 98 | Maximum number of bytes to return. If negative or not supplied, 99 | return all unread bytes in the buffer. 100 | 101 | Returns 102 | ------- 103 | bytes 104 | """ 105 | if size < 0 or size > len(self): 106 | size = len(self) 107 | 108 | part = bytes(self._bytes[self._pos:self._pos+size]) 109 | return part 110 | 111 | def empty(self): 112 | """Remove all bytes from the buffer""" 113 | self._bytes = bytearray() 114 | self._pos = 0 115 | 116 | def fill(self, source, size=-1): 117 | """Fill the buffer with bytes from source until one of these 118 | conditions is met: 119 | * size bytes have been read from source (if size >= 0); 120 | * chunk_size bytes have been read from source; 121 | * no more bytes can be read from source; 122 | Returns the number of new bytes added to the buffer. 123 | Note: all previously-read bytes in the buffer are removed. 124 | 125 | Parameters 126 | ---------- 127 | source: a file-like object, or iterable/list that contains bytes 128 | The source of bytes to fill the buffer with. If this argument has 129 | the `read` attribute, it's assumed to be a file-like object and 130 | `read` is called to get the bytes; otherwise it's assumed to be an 131 | iterable or list that contains bytes, and a for loop is used to get 132 | the bytes. 133 | size: int, optional 134 | The number of bytes to try to read from source. If not supplied, 135 | negative, or larger than the buffer's chunk_size, then chunk_size 136 | bytes are read. Note that if source is an iterable or list, then 137 | it's possible that more than size bytes will be read if iterating 138 | over source produces more than one byte at a time. 139 | 140 | Returns 141 | ------- 142 | int, the number of new bytes added to the buffer.
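Example
-------
A short sketch (the payload is illustrative); with the default negative
size, at most chunk_size bytes are read from the source:

>>> buf = ByteBuffer(chunk_size=8)
>>> buf.fill(io.BytesIO(b'some data'))
8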
143 | """ 144 | size = size if size >= 0 else self._chunk_size 145 | size = min(size, self._chunk_size) 146 | 147 | if self._pos != 0: 148 | self._bytes = self._bytes[self._pos:] 149 | self._pos = 0 150 | 151 | if hasattr(source, 'read'): 152 | new_bytes = source.read(size) 153 | else: 154 | new_bytes = bytearray() 155 | for more_bytes in source: 156 | new_bytes += more_bytes 157 | if len(new_bytes) >= size: 158 | break 159 | 160 | self._bytes += new_bytes 161 | return len(new_bytes) 162 | 163 | def readline(self, terminator): 164 | """Read a line from this buffer efficiently. 165 | 166 | A line is a contiguous sequence of bytes that ends with either: 167 | 168 | 1. The ``terminator`` character 169 | 2. The end of the buffer itself 170 | 171 | :param byte terminator: The line terminator character. 172 | :rtype: bytes 173 | 174 | """ 175 | index = self._bytes.find(terminator, self._pos) 176 | if index == -1: 177 | size = len(self) 178 | else: 179 | size = index - self._pos + 1 180 | return self.read(size) 181 | -------------------------------------------------------------------------------- /smart_open/compression.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2020 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 7 | # 8 | """Implements the compression layer of the ``smart_open`` library.""" 9 | import io 10 | import logging 11 | import os.path 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | _COMPRESSOR_REGISTRY = {} 16 | 17 | NO_COMPRESSION = 'disable' 18 | """Use no compression. Read/write the data as-is.""" 19 | INFER_FROM_EXTENSION = 'infer_from_extension' 20 | """Determine the compression to use from the file extension. 21 | 22 | See get_supported_extensions(). 23 | """ 24 | 25 | 26 | def get_supported_compression_types(): 27 | """Return the list of supported compression types available to open. 28 | 29 | See compression paratemeter to smart_open.open(). 30 | """ 31 | return [NO_COMPRESSION, INFER_FROM_EXTENSION] + get_supported_extensions() 32 | 33 | 34 | def get_supported_extensions(): 35 | """Return the list of file extensions for which we have registered compressors.""" 36 | return sorted(_COMPRESSOR_REGISTRY.keys()) 37 | 38 | 39 | def register_compressor(ext, callback): 40 | """Register a callback for transparently decompressing files with a specific extension. 41 | 42 | Parameters 43 | ---------- 44 | ext: str 45 | The extension. Must include the leading period, e.g. ``.gz``. 46 | callback: callable 47 | The callback. It must accept two position arguments, file_obj and mode. 48 | This function will be called when ``smart_open`` is opening a file with 49 | the specified extension. 50 | 51 | Examples 52 | -------- 53 | 54 | Instruct smart_open to use the `lzma` module whenever opening a file 55 | with a .xz extension (see README.rst for the complete example showing I/O): 56 | 57 | >>> def _handle_xz(file_obj, mode): 58 | ... import lzma 59 | ... 
return lzma.LZMAFile(filename=file_obj, mode=mode, format=lzma.FORMAT_XZ) 60 | >>> 61 | >>> register_compressor('.xz', _handle_xz) 62 | 63 | """ 64 | if not (ext and ext[0] == '.'): 65 | raise ValueError('ext must be a string starting with ., not %r' % ext) 66 | ext = ext.lower() 67 | if ext in _COMPRESSOR_REGISTRY: 68 | logger.warning('overriding existing compression handler for %r', ext) 69 | _COMPRESSOR_REGISTRY[ext] = callback 70 | 71 | 72 | def tweak_close(outer, inner): 73 | """Ensure that closing the `outer` stream closes the `inner` stream as well. 74 | 75 | Deprecated: smart_open.open().__exit__ now always calls __exit__ on the 76 | underlying filestream. 77 | 78 | Use this when your compression library's `close` method does not 79 | automatically close the underlying filestream. See 80 | https://github.com/piskvorky/smart_open/issues/630 for an 81 | explanation why that is a problem for smart_open. 82 | """ 83 | outer_close = outer.close 84 | 85 | def close_both(*args): 86 | nonlocal inner 87 | try: 88 | outer_close() 89 | finally: 90 | if inner: 91 | inner, fp = None, inner 92 | fp.close() 93 | 94 | outer.close = close_both 95 | 96 | 97 | def _handle_bz2(file_obj, mode): 98 | from bz2 import BZ2File 99 | result = BZ2File(file_obj, mode) 100 | return result 101 | 102 | 103 | def _handle_gzip(file_obj, mode): 104 | import gzip 105 | result = gzip.GzipFile(fileobj=file_obj, mode=mode) 106 | return result 107 | 108 | 109 | def _handle_zstd(file_obj, mode): 110 | import zstandard # type: ignore 111 | result = zstandard.open(filename=file_obj, mode=mode) 112 | # zstandard.open returns an io.TextIOWrapper in text mode, but otherwise 113 | # returns a raw stream reader/writer, and we need the `io` wrapper 114 | # to make FileLikeProxy work correctly. 115 | # 116 | # See: 117 | # 118 | # https://github.com/indygreg/python-zstandard/blob/d7d81e79dbe74feb22fb73405ebfb3e20f4c4653/zstandard/__init__.py#L169-L174 119 | if "b" in mode and "w" in mode: 120 | result = io.BufferedWriter(result) 121 | elif "b" in mode and "r" in mode: 122 | result = io.BufferedReader(result) 123 | return result 124 | 125 | 126 | def compression_wrapper(file_obj, mode, compression=INFER_FROM_EXTENSION, filename=None): 127 | """ 128 | Wrap `file_obj` with an appropriate [de]compression mechanism based on its file extension. 129 | 130 | If the filename extension isn't recognized, simply return the original `file_obj` unchanged. 131 | 132 | `file_obj` must either be a filehandle object, or a class which behaves like one. 133 | 134 | If `filename` is specified, it will be used to extract the extension. 135 | If not, the `file_obj.name` attribute is used as the filename. 136 | 137 | """ 138 | if compression == NO_COMPRESSION: 139 | return file_obj 140 | elif compression == INFER_FROM_EXTENSION: 141 | try: 142 | filename = (filename or file_obj.name).lower() 143 | except (AttributeError, TypeError): 144 | logger.warning( 145 | 'unable to transparently decompress %r because it ' 146 | 'seems to lack a string-like .name', file_obj 147 | ) 148 | return file_obj 149 | _, compression = os.path.splitext(filename) 150 | 151 | if compression in _COMPRESSOR_REGISTRY and mode.endswith('+'): 152 | raise ValueError('transparent (de)compression unsupported for mode %r' % mode) 153 | 154 | try: 155 | callback = _COMPRESSOR_REGISTRY[compression] 156 | except KeyError: 157 | return file_obj 158 | else: 159 | return callback(file_obj, mode) 160 | 161 | 162 | # 163 | # NB. avoid using lambda here to make stack traces more readable. 
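#
# As a usage sketch (the filename is illustrative): opening 'example.txt.gz'
# in binary mode routes the raw stream through _handle_gzip above, roughly:
#
#     wrapped = compression_wrapper(raw_fileobj, 'rb')  # a gzip.GzipFile
#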
164 | # 165 | register_compressor('.bz2', _handle_bz2) 166 | register_compressor('.gz', _handle_gzip) 167 | register_compressor('.zst', _handle_zstd) 168 | -------------------------------------------------------------------------------- /smart_open/concurrency.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2020 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 7 | # 8 | 9 | """Common functionality for concurrent processing. The main entry point is :func:`create_pool`.""" 10 | 11 | import contextlib 12 | import logging 13 | import warnings 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | # AWS Lambda environments do not support multiprocessing.Queue or multiprocessing.Pool. 18 | # However they do support Threads and therefore concurrent.futures's ThreadPoolExecutor. 19 | # We use this flag to allow python 2 backward compatibility, where concurrent.futures doesn't exist. 20 | _CONCURRENT_FUTURES = False 21 | try: 22 | import concurrent.futures 23 | _CONCURRENT_FUTURES = True 24 | except ImportError: 25 | warnings.warn("concurrent.futures could not be imported and won't be used") 26 | 27 | # Multiprocessing is unavailable in App Engine (and possibly other sandboxes). 28 | # The only method currently relying on it is iter_bucket, which is instructed 29 | # whether to use it by the MULTIPROCESSING flag. 30 | _MULTIPROCESSING = False 31 | try: 32 | import multiprocessing.pool 33 | _MULTIPROCESSING = True 34 | except ImportError: 35 | warnings.warn("multiprocessing could not be imported and won't be used") 36 | 37 | 38 | class DummyPool(object): 39 | """A class that mimics multiprocessing.pool.Pool for our purposes.""" 40 | def imap_unordered(self, function, items): 41 | return map(function, items) 42 | 43 | def terminate(self): 44 | pass 45 | 46 | 47 | class ConcurrentFuturesPool(object): 48 | """A class that mimics multiprocessing.pool.Pool but uses concurrent futures instead of processes.""" 49 | def __init__(self, max_workers): 50 | self.executor = concurrent.futures.ThreadPoolExecutor(max_workers) 51 | 52 | def imap_unordered(self, function, items): 53 | futures = [self.executor.submit(function, item) for item in items] 54 | for future in concurrent.futures.as_completed(futures): 55 | yield future.result() 56 | 57 | def terminate(self): 58 | self.executor.shutdown(wait=True) 59 | 60 | 61 | @contextlib.contextmanager 62 | def create_pool(processes=1): 63 | if _MULTIPROCESSING and processes: 64 | logger.info("creating multiprocessing pool with %i workers", processes) 65 | pool = multiprocessing.pool.Pool(processes=processes) 66 | elif _CONCURRENT_FUTURES and processes: 67 | logger.info("creating concurrent futures pool with %i workers", processes) 68 | pool = ConcurrentFuturesPool(max_workers=processes) 69 | else: 70 | logger.info("creating dummy pool") 71 | pool = DummyPool() 72 | yield pool 73 | pool.terminate() 74 | -------------------------------------------------------------------------------- /smart_open/constants.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2020 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 
7 | # 8 | 9 | """Some universal constants that are common to I/O operations.""" 10 | 11 | 12 | READ_BINARY = 'rb' 13 | 14 | WRITE_BINARY = 'wb' 15 | 16 | BINARY_MODES = (READ_BINARY, WRITE_BINARY) 17 | 18 | BINARY_NEWLINE = b'\n' 19 | 20 | WHENCE_START = 0 21 | 22 | WHENCE_CURRENT = 1 23 | 24 | WHENCE_END = 2 25 | 26 | WHENCE_CHOICES = (WHENCE_START, WHENCE_CURRENT, WHENCE_END) 27 | -------------------------------------------------------------------------------- /smart_open/doctools.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 7 | # 8 | 9 | """Common functions for working with docstrings. 10 | 11 | For internal use only. 12 | """ 13 | 14 | import contextlib 15 | import inspect 16 | import io 17 | import os.path 18 | import re 19 | 20 | from . import compression 21 | from . import transport 22 | 23 | PLACEHOLDER = ' smart_open/doctools.py magic goes here' 24 | 25 | 26 | def extract_kwargs(docstring): 27 | """Extract keyword argument documentation from a function's docstring. 28 | 29 | Parameters 30 | ---------- 31 | docstring: str 32 | The docstring to extract keyword arguments from. 33 | 34 | Returns 35 | ------- 36 | list of (str, str, list of str) 37 | 38 | str 39 | The name of the keyword argument. 40 | str 41 | Its type. 42 | list of str 43 | Its documentation as a list of lines. 44 | 45 | Notes 46 | ----- 47 | The implementation is rather fragile. It expects the following: 48 | 49 | 1. The parameters are under an underlined Parameters section 50 | 2. Keyword parameters have the literal ", optional" after the type 51 | 3. Names and types are not indented 52 | 4. Descriptions are indented with 4 spaces 53 | 5. The Parameters section ends with an empty line. 54 | 55 | Examples 56 | -------- 57 | 58 | >>> docstring = '''The foo function. 59 | ... Parameters 60 | ... ---------- 61 | ... bar: str, optional 62 | ... This parameter is the bar. 63 | ... baz: int, optional 64 | ... This parameter is the baz. 65 | ... 66 | ... ''' 67 | >>> kwargs = extract_kwargs(docstring) 68 | >>> kwargs[0] 69 | ('bar', 'str, optional', ['This parameter is the bar.']) 70 | 71 | """ 72 | if not docstring: 73 | return [] 74 | 75 | lines = inspect.cleandoc(docstring).split('\n') 76 | kwargs = [] 77 | 78 | # 79 | # 1. Find the underlined 'Parameters' section 80 | # 2. Once there, continue parsing parameters until we hit an empty line 81 | # 82 | while lines and lines[0] != 'Parameters': 83 | lines.pop(0) 84 | 85 | if not lines: 86 | return [] 87 | 88 | lines.pop(0) 89 | lines.pop(0) 90 | 91 | for line in lines: 92 | if not line.strip(): # stop at the first empty line encountered 93 | break 94 | is_arg_line = not line.startswith(' ') 95 | if is_arg_line: 96 | name, type_ = line.split(':', 1) 97 | name, type_, description = name.strip(), type_.strip(), [] 98 | kwargs.append([name, type_, description]) 99 | continue 100 | is_description_line = line.startswith(' ') 101 | if is_description_line: 102 | kwargs[-1][-1].append(line.strip()) 103 | 104 | return kwargs 105 | 106 | 107 | def to_docstring(kwargs, lpad=''): 108 | """Reconstruct a docstring from keyword argument info. 109 | 110 | Basically reverses :func:`extract_kwargs`. 111 | 112 | Parameters 113 | ---------- 114 | kwargs: list 115 | Output from the extract_kwargs function 116 | lpad: str, optional 117 | Padding string (from the left).
118 | 119 | Returns 120 | ------- 121 | str 122 | The docstring snippet documenting the keyword arguments. 123 | 124 | Examples 125 | -------- 126 | 127 | >>> kwargs = [ 128 | ... ('bar', 'str, optional', ['This parameter is the bar.']), 129 | ... ('baz', 'int, optional', ['This parameter is the baz.']), 130 | ... ] 131 | >>> print(to_docstring(kwargs), end='') 132 | bar: str, optional 133 | This parameter is the bar. 134 | baz: int, optional 135 | This parameter is the baz. 136 | 137 | """ 138 | buf = io.StringIO() 139 | for name, type_, description in kwargs: 140 | buf.write('%s%s: %s\n' % (lpad, name, type_)) 141 | for line in description: 142 | buf.write('%s %s\n' % (lpad, line)) 143 | return buf.getvalue() 144 | 145 | 146 | def extract_examples_from_readme_rst(indent=' '): 147 | """Extract examples from this project's README.rst file. 148 | 149 | Parameters 150 | ---------- 151 | indent: str 152 | Prepend each line with this string. Should contain some number of spaces. 153 | 154 | Returns 155 | ------- 156 | str 157 | The examples. 158 | 159 | Notes 160 | ----- 161 | Quite fragile, depends on named labels inside the README.rst file. 162 | """ 163 | curr_dir = os.path.dirname(os.path.abspath(__file__)) 164 | readme_path = os.path.join(curr_dir, '..', 'README.rst') 165 | try: 166 | with open(readme_path) as fin: 167 | lines = list(fin) 168 | start = lines.index('.. _doctools_before_examples:\n') 169 | end = lines.index(".. _doctools_after_examples:\n") 170 | lines = lines[start+4:end-2] 171 | return ''.join([indent + re.sub('^ ', '', line) for line in lines]) 172 | except Exception: 173 | return indent + 'See README.rst' 174 | 175 | 176 | def tweak_open_docstring(f): 177 | buf = io.StringIO() 178 | seen = set() 179 | 180 | root_path = os.path.dirname(os.path.dirname(__file__)) 181 | 182 | with contextlib.redirect_stdout(buf): 183 | print(' smart_open supports the following transport mechanisms:') 184 | print() 185 | for scheme, submodule in sorted(transport._REGISTRY.items()): 186 | if scheme == transport.NO_SCHEME or submodule in seen: 187 | continue 188 | seen.add(submodule) 189 | 190 | try: 191 | schemes = submodule.SCHEMES 192 | except AttributeError: 193 | schemes = [scheme] 194 | 195 | relpath = os.path.relpath(submodule.__file__, start=root_path) 196 | heading = '%s (%s)' % ("/".join(schemes), relpath) 197 | print(' %s' % heading) 198 | print(' %s' % ('~' * len(heading))) 199 | print(' %s' % submodule.__doc__.split('\n')[0]) 200 | print() 201 | 202 | kwargs = extract_kwargs(submodule.open.__doc__) 203 | if kwargs: 204 | print(to_docstring(kwargs, lpad=u' ')) 205 | 206 | print(' Examples') 207 | print(' --------') 208 | print() 209 | print(extract_examples_from_readme_rst()) 210 | 211 | print(' This function also supports transparent compression and decompression ') 212 | print(' using the following codecs:') 213 | print() 214 | for extension in compression.get_supported_extensions(): 215 | print(' * %s' % extension) 216 | print() 217 | print(' The function depends on the file extension to determine the appropriate codec.') 218 | 219 | # 220 | # The docstring can be None if -OO was passed to the interpreter. 
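#
# A small round-trip sketch (illustrative only; it assumes the four-space
# description indent shown in the doctests above): to_docstring reverses
# extract_kwargs, so chaining the two reproduces the original snippet.
#
# >>> snippet = 'bar: str, optional\n    This parameter is the bar.\n'
# >>> docstring = 'f.\nParameters\n----------\n' + snippet + '\n'
# >>> to_docstring(extract_kwargs(docstring)) == snippet
# True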
221 | # 222 | if f.__doc__: 223 | f.__doc__ = f.__doc__.replace(PLACEHOLDER, buf.getvalue()) 224 | 225 | 226 | def tweak_parse_uri_docstring(f): 227 | buf = io.StringIO() 228 | seen = set() 229 | schemes = [] 230 | examples = [] 231 | 232 | for scheme, submodule in sorted(transport._REGISTRY.items()): 233 | if scheme == transport.NO_SCHEME or submodule in seen: 234 | continue 235 | 236 | seen.add(submodule) 237 | 238 | try: 239 | examples.extend(submodule.URI_EXAMPLES) 240 | except AttributeError: 241 | pass 242 | 243 | try: 244 | schemes.extend(submodule.SCHEMES) 245 | except AttributeError: 246 | schemes.append(scheme) 247 | 248 | with contextlib.redirect_stdout(buf): 249 | print(' Supported URI schemes are:') 250 | print() 251 | for scheme in schemes: 252 | print(' * %s' % scheme) 253 | print() 254 | print(' Valid URI examples::') 255 | print() 256 | for example in examples: 257 | print(' * %s' % example) 258 | 259 | if f.__doc__: 260 | f.__doc__ = f.__doc__.replace(PLACEHOLDER, buf.getvalue()) 261 | -------------------------------------------------------------------------------- /smart_open/ftp.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 7 | # 8 | 9 | """Implements I/O streams over FTP.""" 10 | 11 | import logging 12 | import ssl 13 | import urllib.parse 14 | import smart_open.utils 15 | from ftplib import FTP, FTP_TLS, error_reply 16 | import types 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | SCHEMES = ("ftp", "ftps") 21 | 22 | """Supported URL schemes.""" 23 | 24 | DEFAULT_PORT = 21 25 | 26 | URI_EXAMPLES = ( 27 | "ftp://username@host/path/file", 28 | "ftp://username:password@host/path/file", 29 | "ftp://username:password@host:port/path/file", 30 | "ftps://username@host/path/file", 31 | "ftps://username:password@host/path/file", 32 | "ftps://username:password@host:port/path/file", 33 | ) 34 | 35 | 36 | def _unquote(text): 37 | return text and urllib.parse.unquote(text) 38 | 39 | 40 | def parse_uri(uri_as_string): 41 | split_uri = urllib.parse.urlsplit(uri_as_string) 42 | assert split_uri.scheme in SCHEMES 43 | return dict( 44 | scheme=split_uri.scheme, 45 | uri_path=_unquote(split_uri.path), 46 | user=_unquote(split_uri.username), 47 | host=split_uri.hostname, 48 | port=int(split_uri.port or DEFAULT_PORT), 49 | password=_unquote(split_uri.password), 50 | ) 51 | 52 | 53 | def open_uri(uri, mode, transport_params): 54 | smart_open.utils.check_kwargs(open, transport_params) 55 | parsed_uri = parse_uri(uri) 56 | uri_path = parsed_uri.pop("uri_path") 57 | scheme = parsed_uri.pop("scheme") 58 | secure_conn = True if scheme == "ftps" else False 59 | return open( 60 | uri_path, 61 | mode, 62 | secure_connection=secure_conn, 63 | transport_params=transport_params, 64 | **parsed_uri, 65 | ) 66 | 67 | 68 | def convert_transport_params_to_args(transport_params): 69 | supported_keywords = [ 70 | "timeout", 71 | "source_address", 72 | "encoding", 73 | ] 74 | unsupported_keywords = [k for k in transport_params if k not in supported_keywords] 75 | kwargs = {k: v for (k, v) in transport_params.items() if k in supported_keywords} 76 | 77 | if unsupported_keywords: 78 | logger.warning( 79 | "ignoring unsupported ftp keyword arguments: %r", unsupported_keywords 80 | ) 81 | 82 | return kwargs 83 | 84 | 85 | def _connect(hostname, username, port, password, secure_connection, 
transport_params): 86 | kwargs = convert_transport_params_to_args(transport_params) 87 | if secure_connection: 88 | ssl_context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH) 89 | ftp = FTP_TLS(context=ssl_context, **kwargs) 90 | else: 91 | ftp = FTP(**kwargs) 92 | try: 93 | ftp.connect(hostname, port) 94 | except Exception as e: 95 | logger.error("Unable to connect to FTP server: try checking the host and port!") 96 | raise e 97 | try: 98 | ftp.login(username, password) 99 | except error_reply as e: 100 | logger.error( 101 | "Unable to log in to FTP server: try checking the username and password!" 102 | ) 103 | raise e 104 | if secure_connection: 105 | ftp.prot_p() 106 | return ftp 107 | 108 | 109 | def open( 110 | path, 111 | mode="rb", 112 | host=None, 113 | user=None, 114 | password=None, 115 | port=DEFAULT_PORT, 116 | secure_connection=False, 117 | transport_params=None, 118 | ): 119 | """Open a file for reading or writing via FTP/FTPS. 120 | 121 | Parameters 122 | ---------- 123 | path: str 124 | The path on the remote server 125 | mode: str 126 | Must be "rb", "wb" or "ab" 127 | host: str 128 | The host to connect to 129 | user: str 130 | The username to use for the connection 131 | password: str 132 | The password for the specified username 133 | port: int 134 | The port to connect to 135 | secure_connection: bool 136 | True for FTPS, False for FTP 137 | transport_params: dict 138 | Additional parameters for the FTP connection. 139 | Currently supported parameters: timeout, source_address, encoding. 140 | """ 141 | if not host: 142 | raise ValueError("you must specify the host to connect to") 143 | if not user: 144 | raise ValueError("you must specify the user") 145 | if not transport_params: 146 | transport_params = {} 147 | conn = _connect(host, user, port, password, secure_connection, transport_params) 148 | mode_to_ftp_cmds = { 149 | "rb": ("RETR", "rb"), 150 | "wb": ("STOR", "wb"), 151 | "ab": ("APPE", "wb"), 152 | } 153 | try: 154 | ftp_mode, file_obj_mode = mode_to_ftp_cmds[mode] 155 | except KeyError: 156 | raise ValueError(f"unsupported mode: {mode!r}") 157 | 158 | conn.voidcmd("TYPE I") 159 | socket = conn.transfercmd(f"{ftp_mode} {path}") 160 | fobj = socket.makefile(file_obj_mode) 161 | 162 | def full_close(self): 163 | self.orig_close() 164 | self.socket.close() 165 | self.conn.close() 166 | 167 | fobj.orig_close = fobj.close 168 | fobj.socket = socket 169 | fobj.conn = conn 170 | fobj.close = types.MethodType(full_close, fobj) 171 | return fobj 172 | -------------------------------------------------------------------------------- /smart_open/gcs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT).
7 | # 8 | """Implements file-like objects for reading and writing to/from GCS.""" 9 | 10 | import logging 11 | import warnings 12 | 13 | try: 14 | import google.cloud.exceptions 15 | import google.cloud.storage 16 | import google.auth.transport.requests 17 | except ImportError: 18 | MISSING_DEPS = True 19 | 20 | import smart_open.bytebuffer 21 | import smart_open.utils 22 | 23 | from smart_open import constants 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | SCHEME = "gs" 28 | """Supported scheme for GCS""" 29 | 30 | _DEFAULT_MIN_PART_SIZE = 50 * 1024**2 31 | """Default minimum part size for GCS multipart uploads""" 32 | 33 | _DEFAULT_WRITE_OPEN_KWARGS = {'ignore_flush': True} 34 | 35 | 36 | def parse_uri(uri_as_string): 37 | sr = smart_open.utils.safe_urlsplit(uri_as_string) 38 | assert sr.scheme == SCHEME 39 | bucket_id = sr.netloc 40 | blob_id = sr.path.lstrip('/') 41 | return dict(scheme=SCHEME, bucket_id=bucket_id, blob_id=blob_id) 42 | 43 | 44 | def open_uri(uri, mode, transport_params): 45 | parsed_uri = parse_uri(uri) 46 | kwargs = smart_open.utils.check_kwargs(open, transport_params) 47 | return open(parsed_uri['bucket_id'], parsed_uri['blob_id'], mode, **kwargs) 48 | 49 | 50 | def warn_deprecated(parameter_name): 51 | message = f"Parameter {parameter_name} is deprecated; it no longer has any effect" 52 | warnings.warn(message, UserWarning) 53 | 54 | 55 | def open( 56 | bucket_id, 57 | blob_id, 58 | mode, 59 | buffer_size=None, 60 | min_part_size=_DEFAULT_MIN_PART_SIZE, 61 | client=None, # type: google.cloud.storage.Client 62 | get_blob_kwargs=None, 63 | blob_properties=None, 64 | blob_open_kwargs=None, 65 | ): 66 | """Open a GCS blob for reading or writing. 67 | 68 | Parameters 69 | ---------- 70 | bucket_id: str 71 | The name of the bucket this object resides in. 72 | blob_id: str 73 | The name of the blob within the bucket. 74 | mode: str 75 | The mode for opening the object. Must be either "rb" or "wb". 76 | buffer_size: 77 | deprecated 78 | min_part_size: int, optional 79 | The minimum part size for multipart uploads. For writing only. 80 | client: google.cloud.storage.Client, optional 81 | The GCS client to use when working with google-cloud-storage. 82 | get_blob_kwargs: dict, optional 83 | Additional keyword arguments to propagate to the bucket.get_blob 84 | method of the google-cloud-storage library. For reading only. 85 | blob_properties: dict, optional 86 | Set properties on blob before writing. For writing only. 87 | blob_open_kwargs: dict, optional 88 | Additional keyword arguments to propagate to the blob.open method 89 | of the google-cloud-storage library.
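    Examples
    --------
    A minimal sketch (illustrative only: the bucket and blob names below are
    hypothetical, and valid GCS credentials are assumed to be available):

    >>> with open('my-bucket', 'hello.txt', 'wb') as fout:  # doctest: +SKIP
    ...     fout.write(b'hello world')
    >>> with open('my-bucket', 'hello.txt', 'rb') as fin:  # doctest: +SKIP
    ...     print(fin.read())
    b'hello world'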
90 | 91 | """ 92 | if blob_open_kwargs is None: 93 | blob_open_kwargs = {} 94 | 95 | if buffer_size is not None: 96 | warn_deprecated('buffer_size') 97 | 98 | if mode in (constants.READ_BINARY, 'r', 'rt'): 99 | _blob = Reader(bucket=bucket_id, 100 | key=blob_id, 101 | client=client, 102 | get_blob_kwargs=get_blob_kwargs, 103 | blob_open_kwargs=blob_open_kwargs) 104 | 105 | elif mode in (constants.WRITE_BINARY, 'w', 'wt'): 106 | _blob = Writer(bucket=bucket_id, 107 | blob=blob_id, 108 | min_part_size=min_part_size, 109 | client=client, 110 | blob_properties=blob_properties, 111 | blob_open_kwargs=blob_open_kwargs) 112 | 113 | else: 114 | raise NotImplementedError(f'GCS support for mode {mode} not implemented') 115 | 116 | return _blob 117 | 118 | 119 | def Reader(bucket, 120 | key, 121 | buffer_size=None, 122 | line_terminator=None, 123 | client=None, 124 | get_blob_kwargs=None, 125 | blob_open_kwargs=None): 126 | 127 | if get_blob_kwargs is None: 128 | get_blob_kwargs = {} 129 | if blob_open_kwargs is None: 130 | blob_open_kwargs = {} 131 | if client is None: 132 | client = google.cloud.storage.Client() 133 | if buffer_size is not None: 134 | warn_deprecated('buffer_size') 135 | if line_terminator is not None: 136 | warn_deprecated('line_terminator') 137 | 138 | bkt = client.bucket(bucket) 139 | blob = bkt.get_blob(key, **get_blob_kwargs) 140 | 141 | if blob is None: 142 | raise google.cloud.exceptions.NotFound(f'blob {key} not found in {bucket}') 143 | 144 | return blob.open('rb', **blob_open_kwargs) 145 | 146 | 147 | def Writer(bucket, 148 | blob, 149 | min_part_size=None, 150 | client=None, 151 | blob_properties=None, 152 | blob_open_kwargs=None): 153 | 154 | if blob_open_kwargs is None: 155 | blob_open_kwargs = {} 156 | if blob_properties is None: 157 | blob_properties = {} 158 | if client is None: 159 | client = google.cloud.storage.Client() 160 | 161 | blob_open_kwargs = {**_DEFAULT_WRITE_OPEN_KWARGS, **blob_open_kwargs} 162 | 163 | g_blob = client.bucket(bucket).blob( 164 | blob, 165 | chunk_size=min_part_size, 166 | ) 167 | 168 | for k, v in blob_properties.items(): 169 | setattr(g_blob, k, v) 170 | 171 | _blob = g_blob.open('wb', **blob_open_kwargs) 172 | 173 | # backwards-compatibility: terminate() was deprecated upstream, see https://cloud.google.com/storage/docs/resumable-uploads 174 | _blob.terminate = lambda: None 175 | 176 | return _blob 177 | -------------------------------------------------------------------------------- /smart_open/hdfs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT).
7 | # 8 | 9 | """Implements reading and writing to/from HDFS.""" 10 | 11 | import io 12 | import logging 13 | import subprocess 14 | import urllib.parse 15 | 16 | from smart_open import utils 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | SCHEMES = ('hdfs', 'viewfs') 21 | 22 | URI_EXAMPLES = ( 23 | 'hdfs:///path/file', 24 | 'hdfs://path/file', 25 | 'viewfs:///path/file', 26 | 'viewfs://path/file', 27 | ) 28 | 29 | 30 | def parse_uri(uri_as_string): 31 | split_uri = urllib.parse.urlsplit(uri_as_string) 32 | assert split_uri.scheme in SCHEMES 33 | 34 | uri_path = split_uri.netloc + split_uri.path 35 | if not uri_path: 36 | raise RuntimeError("invalid HDFS URI: %r" % uri_as_string) 37 | uri_path = "/" + uri_path.lstrip("/") 38 | 39 | return dict(scheme=split_uri.scheme, uri_path=uri_path) 40 | 41 | 42 | def open_uri(uri, mode, transport_params): 43 | utils.check_kwargs(open, transport_params) 44 | 45 | parsed_uri = parse_uri(uri) 46 | fobj = open(parsed_uri['uri_path'], mode) 47 | fobj.name = parsed_uri['uri_path'].split('/')[-1] 48 | return fobj 49 | 50 | 51 | def open(uri, mode): 52 | if mode == 'rb': 53 | return CliRawInputBase(uri) 54 | elif mode == 'wb': 55 | return CliRawOutputBase(uri) 56 | else: 57 | raise NotImplementedError('hdfs support for mode %r not implemented' % mode) 58 | 59 | 60 | class CliRawInputBase(io.RawIOBase): 61 | """Reads bytes from HDFS via the "hdfs dfs" command-line interface. 62 | 63 | Implements the io.RawIOBase interface of the standard library. 64 | """ 65 | _sub = None # so `closed` property works in case __init__ fails and __del__ is called 66 | 67 | def __init__(self, uri): 68 | self._uri = uri 69 | self._sub = subprocess.Popen(["hdfs", "dfs", '-cat', self._uri], stdout=subprocess.PIPE) 70 | 71 | # 72 | # This member is part of the io.BufferedIOBase interface. 73 | # 74 | self.raw = None 75 | 76 | # 77 | # Override some methods from io.IOBase. 78 | # 79 | def close(self): 80 | """Flush and close this stream.""" 81 | logger.debug("close: called") 82 | if not self.closed: 83 | self._sub.terminate() 84 | self._sub = None 85 | 86 | @property 87 | def closed(self): 88 | return self._sub is None 89 | 90 | def readable(self): 91 | """Return True if the stream can be read from.""" 92 | return self._sub is not None 93 | 94 | def seekable(self): 95 | """If False, seek(), tell() and truncate() will raise IOError.""" 96 | return False 97 | 98 | # 99 | # io.RawIOBase methods. 100 | # 101 | def detach(self): 102 | """Unsupported.""" 103 | raise io.UnsupportedOperation 104 | 105 | def read(self, size=-1): 106 | """Read up to size bytes from the object and return them.""" 107 | return self._sub.stdout.read(size) 108 | 109 | def read1(self, size=-1): 110 | """This is the same as read().""" 111 | return self.read(size=size) 112 | 113 | def readinto(self, b): 114 | """Read up to len(b) bytes into b, and return the number of bytes 115 | read.""" 116 | data = self.read(len(b)) 117 | if not data: 118 | return 0 119 | b[:len(data)] = data 120 | return len(data) 121 | 122 | 123 | class CliRawOutputBase(io.RawIOBase): 124 | """Writes bytes to HDFS via the "hdfs dfs" command-line interface. 125 | 126 | Implements the io.RawIOBase interface of the standard library.
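
    A rough usage sketch (illustrative only; assumes an ``hdfs`` executable
    on the PATH and a writable destination path):

        with CliRawOutputBase('hdfs:///tmp/hello.txt') as fout:
            fout.write(b'hello world')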
127 | """ 128 | _sub = None # so `closed` property works in case __init__ fails and __del__ is called 129 | 130 | def __init__(self, uri): 131 | self._uri = uri 132 | self._sub = subprocess.Popen(["hdfs", "dfs", '-put', '-f', '-', self._uri], 133 | stdin=subprocess.PIPE) 134 | 135 | # 136 | # This member is part of the io.BufferedIOBase interface. 137 | # 138 | self.raw = None 139 | 140 | def close(self): 141 | logger.debug("close: called") 142 | if not self.closed: 143 | self.flush() 144 | self._sub.stdin.close() 145 | self._sub.wait() 146 | self._sub = None 147 | 148 | @property 149 | def closed(self): 150 | return self._sub is None 151 | 152 | def flush(self): 153 | self._sub.stdin.flush() 154 | 155 | def writable(self): 156 | """Return True if this object is writable.""" 157 | return self._sub is not None 158 | 159 | def seekable(self): 160 | """If False, seek(), tell() and truncate() will raise IOError.""" 161 | return False 162 | 163 | def write(self, b): 164 | self._sub.stdin.write(b) 165 | 166 | # 167 | # io.IOBase methods. 168 | # 169 | def detach(self): 170 | raise io.UnsupportedOperation("detach() not supported") 171 | -------------------------------------------------------------------------------- /smart_open/local_file.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2020 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 7 | # 8 | """Implements the transport for the file:// scheme.""" 9 | import io 10 | import os.path 11 | 12 | SCHEME = 'file' 13 | 14 | URI_EXAMPLES = ( 15 | './local/path/file', 16 | '~/local/path/file', 17 | 'local/path/file', 18 | './local/path/file.gz', 19 | 'file:///home/user/file', 20 | 'file:///home/user/file.bz2', 21 | ) 22 | 23 | 24 | open = io.open 25 | 26 | 27 | def parse_uri(uri_as_string): 28 | local_path = extract_local_path(uri_as_string) 29 | return dict(scheme=SCHEME, uri_path=local_path) 30 | 31 | 32 | def open_uri(uri_as_string, mode, transport_params): 33 | parsed_uri = parse_uri(uri_as_string) 34 | fobj = io.open(parsed_uri['uri_path'], mode) 35 | return fobj 36 | 37 | 38 | def extract_local_path(uri_as_string): 39 | if uri_as_string.startswith('file://'): 40 | local_path = uri_as_string.replace('file://', '', 1) 41 | else: 42 | local_path = uri_as_string 43 | return os.path.expanduser(local_path) 44 | -------------------------------------------------------------------------------- /smart_open/ssh.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 7 | # 8 | 9 | """Implements I/O streams over SSH. 10 | 11 | Examples 12 | -------- 13 | 14 | >>> with open('/proc/version_signature', host='1.2.3.4') as conn: 15 | ...
print(conn.read()) 16 | b'Ubuntu 4.4.0-1061.70-aws 4.4.131' 17 | 18 | Similarly, from a command line:: 19 | 20 | $ python -c "from smart_open import ssh;print(ssh.open('/proc/version_signature', host='1.2.3.4').read())" 21 | b'Ubuntu 4.4.0-1061.70-aws 4.4.131' 22 | 23 | """ 24 | 25 | import getpass 26 | import os 27 | import logging 28 | import urllib.parse 29 | 30 | from typing import ( 31 | Dict, 32 | Callable, 33 | Tuple, 34 | ) 35 | 36 | try: 37 | import paramiko 38 | except ImportError: 39 | MISSING_DEPS = True 40 | 41 | import smart_open.utils 42 | 43 | logger = logging.getLogger(__name__) 44 | 45 | # 46 | # Global storage for SSH connections. 47 | # 48 | _SSH = {} 49 | 50 | SCHEMES = ("ssh", "scp", "sftp") 51 | """Supported URL schemes.""" 52 | 53 | DEFAULT_PORT = 22 54 | 55 | URI_EXAMPLES = ( 56 | 'ssh://username@host/path/file', 57 | 'ssh://username@host//path/file', 58 | 'scp://username@host/path/file', 59 | 'sftp://username@host/path/file', 60 | ) 61 | 62 | # 63 | # Global storage for SSH config files. 64 | # 65 | _SSH_CONFIG_FILES = [os.path.expanduser("~/.ssh/config")] 66 | 67 | 68 | def _unquote(text): 69 | return text and urllib.parse.unquote(text) 70 | 71 | 72 | def _str2bool(string): 73 | if string == "no": 74 | return False 75 | if string == "yes": 76 | return True 77 | raise ValueError(f"Expected 'yes' / 'no', got {string}.") 78 | 79 | 80 | # 81 | # The parameter names used by Paramiko (and smart_open) differ slightly from 82 | # those used in ~/.ssh/config, so we use a mapping to bridge the gap. 83 | # 84 | # The keys are option names as they appear in Paramiko (and smart_open) 85 | # The values are tuples containing: 86 | # 87 | # 1. their corresponding names in the ~/.ssh/config file 88 | # 2. a callable to convert the parameter value from a string to the appropriate type 89 | # 90 | _PARAMIKO_CONFIG_MAP: Dict[str, Tuple[str, Callable]] = { 91 | "timeout": ("connecttimeout", float), 92 | "compress": ("compression", _str2bool), 93 | "gss_auth": ("gssapiauthentication", _str2bool), 94 | "gss_kex": ("gssapikeyexchange", _str2bool), 95 | "gss_deleg_creds": ("gssapidelegatecredentials", _str2bool), 96 | "gss_trust_dns": ("gssapitrustdns", _str2bool), 97 | } 98 | 99 | 100 | def parse_uri(uri_as_string): 101 | split_uri = urllib.parse.urlsplit(uri_as_string) 102 | assert split_uri.scheme in SCHEMES 103 | return dict( 104 | scheme=split_uri.scheme, 105 | uri_path=_unquote(split_uri.path), 106 | user=_unquote(split_uri.username), 107 | host=split_uri.hostname, 108 | port=int(split_uri.port) if split_uri.port else None, 109 | password=_unquote(split_uri.password), 110 | ) 111 | 112 | 113 | def open_uri(uri, mode, transport_params): 114 | kwargs = smart_open.utils.check_kwargs(open, transport_params) 115 | parsed_uri = parse_uri(uri) 116 | uri_path = parsed_uri.pop('uri_path') 117 | parsed_uri.pop('scheme') 118 | return open(uri_path, mode, **parsed_uri, **kwargs) 119 | 120 | 121 | def _connect_ssh(hostname, username, port, password, connect_kwargs): 122 | ssh = paramiko.SSHClient() 123 | ssh.load_system_host_keys() 124 | ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) 125 | kwargs = (connect_kwargs or {}).copy() 126 | if 'key_filename' not in kwargs: 127 | kwargs.setdefault('password', password) 128 | kwargs.setdefault('username', username) 129 | ssh.connect(hostname, port, **kwargs) 130 | return ssh 131 | 132 | 133 | def _maybe_fetch_config(host, username=None, password=None, port=None, connect_kwargs=None): 134 | # If all fields are set, return as-is.
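    #
    # Illustrative sketch (hypothetical host alias): given a ~/.ssh/config entry
    #
    #     Host myalias
    #         HostName real-host.example.com
    #         User deploy
    #         Port 2222
    #
    # a call like _maybe_fetch_config('myalias') resolves to roughly
    # ('real-host.example.com', 'deploy', None, 2222, {}).
    #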
135 | if not any(arg is None for arg in (host, username, password, port, connect_kwargs)): 136 | return host, username, password, port, connect_kwargs 137 | 138 | if not host: 139 | raise ValueError('you must specify the host to connect to') 140 | 141 | # Attempt to load an OpenSSH config. 142 | # 143 | # Connections configured in this way are not guaranteed to perform exactly 144 | # as they do in typical usage due to mismatches between the set of OpenSSH 145 | # configuration options and those that Paramiko supports. We provide a best 146 | # attempt, and support: 147 | # 148 | # - hostname -> address resolution 149 | # - username inference 150 | # - port inference 151 | # - identityfile inference 152 | # - connection timeout inference 153 | # - compression selection 154 | # - GSS configuration 155 | # 156 | connect_params = (connect_kwargs or {}).copy() 157 | config_files = [f for f in _SSH_CONFIG_FILES if os.path.exists(f)] 158 | # 159 | # This is the actual name of the host. The input host may actually be an 160 | # alias. 161 | # 162 | actual_hostname = "" 163 | 164 | for config_filename in config_files: 165 | try: 166 | cfg = paramiko.SSHConfig.from_path(config_filename) 167 | except PermissionError: 168 | continue 169 | 170 | if host not in cfg.get_hostnames(): 171 | continue 172 | 173 | cfg = cfg.lookup(host) 174 | if username is None: 175 | username = cfg.get("user", None) 176 | 177 | if not actual_hostname: 178 | actual_hostname = cfg["hostname"] 179 | 180 | if port is None: 181 | try: 182 | port = int(cfg["port"]) 183 | except (KeyError, ValueError): 184 | # 185 | # Nb. ignore missing/invalid port numbers 186 | # 187 | pass 188 | 189 | # 190 | # Special case, as we can have multiple identity files, so we check 191 | # that the identityfile list has len > 0. This should be redundant, but 192 | # keeping it for safety. 193 | # 194 | if connect_params.get("key_filename") is None: 195 | identityfile = cfg.get("identityfile", []) 196 | if len(identityfile): 197 | connect_params["key_filename"] = identityfile 198 | 199 | for param_name, (sshcfg_name, from_str) in _PARAMIKO_CONFIG_MAP.items(): 200 | if connect_params.get(param_name) is None and sshcfg_name in cfg: 201 | connect_params[param_name] = from_str(cfg[sshcfg_name]) 202 | 203 | # 204 | # Continue working through other config files, if there are any, 205 | # as they may contain more options for our host 206 | # 207 | 208 | if port is None: 209 | port = DEFAULT_PORT 210 | 211 | if not username: 212 | username = getpass.getuser() 213 | 214 | if actual_hostname: 215 | host = actual_hostname 216 | 217 | return host, username, password, port, connect_params 218 | 219 | 220 | def open( 221 | path, 222 | mode="r", 223 | host=None, 224 | user=None, 225 | password=None, 226 | port=None, 227 | connect_kwargs=None, 228 | prefetch_kwargs=None, 229 | buffer_size=-1, 230 | ): 231 | """Open a file on a remote machine over SSH. 232 | 233 | Expects authentication to be already set up via existing keys on the local machine. 234 | 235 | Parameters 236 | ---------- 237 | path: str 238 | The path to the file to open on the remote machine. 239 | mode: str, optional 240 | The mode to use for opening the file. 241 | host: str, optional 242 | The hostname of the remote machine. May not be None. 243 | user: str, optional 244 | The username to use to log in to the remote machine. 245 | If None, defaults to the name of the current user. 246 | password: str, optional 247 | The password to use to log in to the remote machine.
248 | port: int, optional 249 | The port to connect to. 250 | connect_kwargs: dict, optional 251 | Any additional settings to be passed to paramiko.SSHClient.connect. 252 | prefetch_kwargs: dict, optional 253 | Any additional settings to be passed to paramiko.SFTPFile.prefetch. 254 | The presence of this dict (even if empty) triggers prefetching. 255 | buffer_size: int, optional 256 | Passed to the bufsize argument of paramiko.SFTPClient.open. 257 | 258 | Returns 259 | ------- 260 | A file-like object. 261 | 262 | Important 263 | --------- 264 | If you specify a previously unseen host, then its host key will be added to 265 | the local ~/.ssh/known_hosts *automatically*. 266 | 267 | If ``username`` or ``password`` are specified in *both* the uri and 268 | ``transport_params``, ``transport_params`` will take precedence 269 | """ 270 | host, user, password, port, connect_kwargs = _maybe_fetch_config( 271 | host, user, password, port, connect_kwargs 272 | ) 273 | 274 | key = (host, user) 275 | 276 | attempts = 2 277 | for attempt in range(attempts): 278 | try: 279 | ssh = _SSH[key] 280 | # Validate that the cached connection is still an active connection 281 | # and if not, refresh the connection 282 | if not ssh.get_transport().active: 283 | ssh.close() 284 | ssh = _SSH[key] = _connect_ssh(host, user, port, password, connect_kwargs) 285 | except KeyError: 286 | ssh = _SSH[key] = _connect_ssh(host, user, port, password, connect_kwargs) 287 | 288 | try: 289 | transport = ssh.get_transport() 290 | sftp_client = transport.open_sftp_client() 291 | break 292 | except paramiko.SSHException as ex: 293 | connection_timed_out = ex.args and ex.args[0] == 'SSH session not active' 294 | if attempt == attempts - 1 or not connection_timed_out: 295 | raise 296 | 297 | # 298 | # Try again. Delete the connection from the cache to force a 299 | # reconnect in the next attempt. 300 | # 301 | del _SSH[key] 302 | 303 | fobj = sftp_client.open(path, mode=mode, bufsize=buffer_size) 304 | fobj.name = path 305 | if prefetch_kwargs is not None: 306 | fobj.prefetch(**prefetch_kwargs) 307 | return fobj 308 | -------------------------------------------------------------------------------- /smart_open/tests/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 7 | # 8 | -------------------------------------------------------------------------------- /smart_open/tests/fixtures/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piskvorky/smart_open/4bed3466b9db62fefa32249cfce40a479b586866/smart_open/tests/fixtures/__init__.py -------------------------------------------------------------------------------- /smart_open/tests/fixtures/good_transport.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """A no-op transport that registers scheme 'foo'""" 3 | import io 4 | 5 | SCHEME = "foo" 6 | open = io.open 7 | 8 | 9 | def parse_uri(uri_as_string): # pragma: no cover 10 | ... 11 | 12 | 13 | def open_uri(uri_as_string, mode, transport_params): # pragma: no cover 14 | ... 
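#
# A hedged sketch of how a transport module like this gets wired in (it
# assumes smart_open.transport.register_transport accepts a dotted module
# path, as this repo's registry code suggests):
#
# >>> import smart_open.transport
# >>> smart_open.transport.register_transport('smart_open.tests.fixtures.good_transport')
# >>> smart_open.transport.get_transport('foo')  # doctest: +SKIP
# <module 'smart_open.tests.fixtures.good_transport' ...>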
15 | -------------------------------------------------------------------------------- /smart_open/tests/fixtures/missing_deps_transport.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Transport that has missing deps""" 3 | import io 4 | 5 | 6 | try: 7 | import this_module_does_not_exist_but_we_need_it # noqa 8 | except ImportError: 9 | MISSING_DEPS = True 10 | 11 | SCHEME = "missing" 12 | open = io.open 13 | 14 | 15 | def parse_uri(uri_as_string): # pragma: no cover 16 | ... 17 | 18 | 19 | def open_uri(uri_as_string, mode, transport_params): # pragma: no cover 20 | ... 21 | -------------------------------------------------------------------------------- /smart_open/tests/fixtures/no_schemes_transport.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """A transport that is missing the required SCHEME/SCHEMAS attributes""" 3 | import io 4 | 5 | open = io.open 6 | 7 | 8 | def parse_uri(uri_as_string): # pragma: no cover 9 | ... 10 | 11 | 12 | def open_uri(uri_as_string, mode, transport_params): # pragma: no cover 13 | ... 14 | -------------------------------------------------------------------------------- /smart_open/tests/test_bytebuffer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 7 | # 8 | import io 9 | import random 10 | import unittest 11 | 12 | import smart_open.bytebuffer 13 | 14 | 15 | CHUNK_SIZE = 1024 16 | 17 | 18 | def int2byte(i): 19 | return bytes((i, )) 20 | 21 | 22 | def random_byte_string(length=CHUNK_SIZE): 23 | rand_bytes = [int2byte(random.randint(0, 255)) for _ in range(length)] 24 | return b''.join(rand_bytes) 25 | 26 | 27 | def bytebuffer_and_random_contents(): 28 | buf = smart_open.bytebuffer.ByteBuffer(CHUNK_SIZE) 29 | contents = random_byte_string(CHUNK_SIZE) 30 | content_reader = io.BytesIO(contents) 31 | buf.fill(content_reader) 32 | 33 | return [buf, contents] 34 | 35 | 36 | class ByteBufferTest(unittest.TestCase): 37 | def test_len(self): 38 | buf = smart_open.bytebuffer.ByteBuffer(CHUNK_SIZE) 39 | self.assertEqual(len(buf), 0) 40 | 41 | contents = b'foo bar baz' 42 | buf._bytes = contents 43 | self.assertEqual(len(buf), len(contents)) 44 | 45 | pos = 4 46 | buf._pos = pos 47 | self.assertEqual(len(buf), len(contents) - pos) 48 | 49 | def test_fill_from_reader(self): 50 | buf = smart_open.bytebuffer.ByteBuffer(CHUNK_SIZE) 51 | contents = random_byte_string(CHUNK_SIZE) 52 | content_reader = io.BytesIO(contents) 53 | 54 | bytes_filled = buf.fill(content_reader) 55 | self.assertEqual(bytes_filled, CHUNK_SIZE) 56 | self.assertEqual(len(buf), CHUNK_SIZE) 57 | self.assertEqual(buf._bytes, contents) 58 | 59 | def test_fill_from_iterable(self): 60 | buf = smart_open.bytebuffer.ByteBuffer(CHUNK_SIZE) 61 | contents = random_byte_string(CHUNK_SIZE) 62 | contents_iter = (contents[i:i+8] for i in range(0, CHUNK_SIZE, 8)) 63 | 64 | bytes_filled = buf.fill(contents_iter) 65 | self.assertEqual(bytes_filled, CHUNK_SIZE) 66 | self.assertEqual(len(buf), CHUNK_SIZE) 67 | self.assertEqual(buf._bytes, contents) 68 | 69 | def test_fill_from_list(self): 70 | buf = smart_open.bytebuffer.ByteBuffer(CHUNK_SIZE) 71 | contents = random_byte_string(CHUNK_SIZE) 72 | contents_list = [contents[i:i+7] for i in range(0, CHUNK_SIZE, 7)] 73 | 74 | 
bytes_filled = buf.fill(contents_list) 75 | self.assertEqual(bytes_filled, CHUNK_SIZE) 76 | self.assertEqual(len(buf), CHUNK_SIZE) 77 | self.assertEqual(buf._bytes, contents) 78 | 79 | def test_fill_multiple(self): 80 | buf = smart_open.bytebuffer.ByteBuffer(CHUNK_SIZE) 81 | long_contents = random_byte_string(CHUNK_SIZE * 4) 82 | long_content_reader = io.BytesIO(long_contents) 83 | 84 | first_bytes_filled = buf.fill(long_content_reader) 85 | self.assertEqual(first_bytes_filled, CHUNK_SIZE) 86 | 87 | second_bytes_filled = buf.fill(long_content_reader) 88 | self.assertEqual(second_bytes_filled, CHUNK_SIZE) 89 | self.assertEqual(len(buf), 2 * CHUNK_SIZE) 90 | 91 | def test_fill_size(self): 92 | buf = smart_open.bytebuffer.ByteBuffer(CHUNK_SIZE) 93 | contents = random_byte_string(CHUNK_SIZE * 2) 94 | content_reader = io.BytesIO(contents) 95 | fill_size = int(CHUNK_SIZE / 2) 96 | 97 | bytes_filled = buf.fill(content_reader, size=fill_size) 98 | 99 | self.assertEqual(bytes_filled, fill_size) 100 | self.assertEqual(len(buf), fill_size) 101 | 102 | second_bytes_filled = buf.fill(content_reader, size=CHUNK_SIZE+1) 103 | self.assertEqual(second_bytes_filled, CHUNK_SIZE) 104 | self.assertEqual(len(buf), fill_size + CHUNK_SIZE) 105 | 106 | def test_fill_reader_exhaustion(self): 107 | buf = smart_open.bytebuffer.ByteBuffer(CHUNK_SIZE) 108 | short_content_size = int(CHUNK_SIZE / 4) 109 | short_contents = random_byte_string(short_content_size) 110 | short_content_reader = io.BytesIO(short_contents) 111 | 112 | bytes_filled = buf.fill(short_content_reader) 113 | self.assertEqual(bytes_filled, short_content_size) 114 | self.assertEqual(len(buf), short_content_size) 115 | 116 | def test_fill_iterable_exhaustion(self): 117 | buf = smart_open.bytebuffer.ByteBuffer(CHUNK_SIZE) 118 | short_content_size = int(CHUNK_SIZE / 4) 119 | short_contents = random_byte_string(short_content_size) 120 | short_contents_iter = (short_contents[i:i+8] 121 | for i in range(0, short_content_size, 8)) 122 | 123 | bytes_filled = buf.fill(short_contents_iter) 124 | self.assertEqual(bytes_filled, short_content_size) 125 | self.assertEqual(len(buf), short_content_size) 126 | 127 | def test_empty(self): 128 | buf, _ = bytebuffer_and_random_contents() 129 | 130 | self.assertEqual(len(buf), CHUNK_SIZE) 131 | buf.empty() 132 | self.assertEqual(len(buf), 0) 133 | 134 | def test_peek(self): 135 | buf, contents = bytebuffer_and_random_contents() 136 | 137 | self.assertEqual(buf.peek(), contents) 138 | self.assertEqual(len(buf), CHUNK_SIZE) 139 | self.assertEqual(buf.peek(64), contents[0:64]) 140 | self.assertEqual(buf.peek(CHUNK_SIZE * 10), contents) 141 | 142 | def test_read(self): 143 | buf, contents = bytebuffer_and_random_contents() 144 | 145 | self.assertEqual(buf.read(), contents) 146 | self.assertEqual(len(buf), 0) 147 | self.assertEqual(buf.read(), b'') 148 | 149 | def test_read_size(self): 150 | buf, contents = bytebuffer_and_random_contents() 151 | read_size = 128 152 | 153 | self.assertEqual(buf.read(read_size), contents[:read_size]) 154 | self.assertEqual(len(buf), CHUNK_SIZE - read_size) 155 | 156 | self.assertEqual(buf.read(CHUNK_SIZE*2), contents[read_size:]) 157 | self.assertEqual(len(buf), 0) 158 | 159 | def test_readline(self): 160 | """Does the readline function work as expected in the simple case?""" 161 | expected = (b'this is the very first line\n', b'and this the second') 162 | buf = smart_open.bytebuffer.ByteBuffer() 163 | buf.fill(io.BytesIO(b''.join(expected))) 164 | 165 | first_line = buf.readline(b'\n') 166 | 
self.assertEqual(expected[0], first_line) 167 | 168 | second_line = buf.readline(b'\n') 169 | self.assertEqual(expected[1], second_line) 170 | 171 | def test_readline_middle(self): 172 | """Does the readline function work when we're in the middle of the buffer?""" 173 | expected = (b'this is the very first line\n', b'and this the second') 174 | buf = smart_open.bytebuffer.ByteBuffer() 175 | buf.fill(io.BytesIO(b''.join(expected))) 176 | 177 | buf.read(5) 178 | first_line = buf.readline(b'\n') 179 | self.assertEqual(expected[0][5:], first_line) 180 | 181 | buf.read(5) 182 | second_line = buf.readline(b'\n') 183 | self.assertEqual(expected[1][5:], second_line) 184 | 185 | def test_readline_terminator(self): 186 | """Does the readline function respect the terminator parameter?""" 187 | buf = smart_open.bytebuffer.ByteBuffer() 188 | buf.fill(io.BytesIO(b'one!two.three,')) 189 | expected = [b'one!', b'two.', b'three,'] 190 | actual = [buf.readline(b'!'), buf.readline(b'.'), buf.readline(b',')] 191 | self.assertEqual(expected, actual) 192 | -------------------------------------------------------------------------------- /smart_open/tests/test_compression.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2020 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 7 | # 8 | import gzip 9 | import io 10 | 11 | import pytest 12 | import zstandard as zstd 13 | 14 | import smart_open.compression 15 | 16 | plain = 'доброе утро планета!'.encode() 17 | 18 | 19 | def label(thing, name): 20 | setattr(thing, 'name', name) 21 | return thing 22 | 23 | 24 | @pytest.mark.parametrize( 25 | 'fileobj,compression,filename', 26 | [ 27 | (io.BytesIO(plain), 'disable', None), 28 | (io.BytesIO(plain), 'disable', ''), 29 | (io.BytesIO(plain), 'infer_from_extension', 'file.txt'), 30 | (io.BytesIO(plain), 'infer_from_extension', 'file.TXT'), 31 | (io.BytesIO(plain), '.unknown', ''), 32 | (io.BytesIO(gzip.compress(plain)), 'infer_from_extension', 'file.gz'), 33 | (io.BytesIO(gzip.compress(plain)), 'infer_from_extension', 'file.GZ'), 34 | (label(io.BytesIO(gzip.compress(plain)), 'file.gz'), 'infer_from_extension', ''), 35 | (io.BytesIO(gzip.compress(plain)), '.gz', 'file.gz'), 36 | (io.BytesIO(zstd.ZstdCompressor().compress(plain)), 'infer_from_extension', 'file.zst'), 37 | (io.BytesIO(zstd.ZstdCompressor().compress(plain)), 'infer_from_extension', 'file.ZST'), 38 | (label(io.BytesIO(zstd.ZstdCompressor().compress(plain)), 'file.zst'), 'infer_from_extension', ''), 39 | (io.BytesIO(zstd.ZstdCompressor().compress(plain)), '.zst', 'file.zst'), 40 | ] 41 | ) 42 | def test_compression_wrapper_read(fileobj, compression, filename): 43 | wrapped = smart_open.compression.compression_wrapper(fileobj, 'rb', compression, filename) 44 | assert wrapped.read() == plain 45 | -------------------------------------------------------------------------------- /smart_open/tests/test_data/1984.txt.bz2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piskvorky/smart_open/4bed3466b9db62fefa32249cfce40a479b586866/smart_open/tests/test_data/1984.txt.bz2 -------------------------------------------------------------------------------- /smart_open/tests/test_data/1984.txt.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/piskvorky/smart_open/4bed3466b9db62fefa32249cfce40a479b586866/smart_open/tests/test_data/1984.txt.gz -------------------------------------------------------------------------------- /smart_open/tests/test_data/1984.txt.gzip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piskvorky/smart_open/4bed3466b9db62fefa32249cfce40a479b586866/smart_open/tests/test_data/1984.txt.gzip -------------------------------------------------------------------------------- /smart_open/tests/test_data/1984.txt.xz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piskvorky/smart_open/4bed3466b9db62fefa32249cfce40a479b586866/smart_open/tests/test_data/1984.txt.xz -------------------------------------------------------------------------------- /smart_open/tests/test_data/cp852.tsv.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piskvorky/smart_open/4bed3466b9db62fefa32249cfce40a479b586866/smart_open/tests/test_data/cp852.tsv.txt -------------------------------------------------------------------------------- /smart_open/tests/test_data/crime-and-punishment.txt: -------------------------------------------------------------------------------- 1 | В начале июля, в чрезвычайно жаркое время, под вечер, один молодой человек вышел из своей каморки, которую нанимал от жильцов в С -- м переулке, на улицу и медленно, как бы в нерешимости, отправился к К -- ну мосту. 2 | Он благополучно избегнул встречи с своею хозяйкой на лестнице. Каморка его приходилась под самою кровлей высокого пятиэтажного дома и походила более на шкаф, чем на квартиру. Квартирная же хозяйка его, у которой он нанимал эту каморку с обедом и прислугой, помещалась одною лестницей ниже, в отдельной квартире, и каждый раз, при выходе на улицу, ему непременно надо было проходить мимо хозяйкиной кухни, почти всегда настежь отворенной на лестницу. И каждый раз молодой человек, проходя мимо, чувствовал какое-то болезненное и трусливое ощущение, которого стыдился и от которого морщился. Он был должен кругом хозяйке и боялся с нею встретиться. 3 | Не то чтоб он был так труслив и забит, совсем даже напротив; но с некоторого времени он был в раздражительном и напряженном состоянии, похожем на ипохондрию. Он до того углубился в себя и уединился от всех, что боялся даже всякой встречи, не только встречи с хозяйкой. Он был задавлен бедностью; но даже стесненное положение перестало в последнее время тяготить его. Насущными делами своими он совсем перестал и не хотел заниматься. Никакой хозяйки, в сущности, он не боялся, что бы та ни замышляла против него. Но останавливаться на лестнице, слушать всякий вздор про всю эту обыденную дребедень, до которой ему нет никакого дела, все эти приставания о платеже, угрозы, жалобы, и при этом самому изворачиваться, извиняться, лгать, -- нет уж, лучше проскользнуть как-нибудь кошкой по лестнице и улизнуть, чтобы никто не видал. 
4 | -------------------------------------------------------------------------------- /smart_open/tests/test_data/crime-and-punishment.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piskvorky/smart_open/4bed3466b9db62fefa32249cfce40a479b586866/smart_open/tests/test_data/crime-and-punishment.txt.gz -------------------------------------------------------------------------------- /smart_open/tests/test_data/crime-and-punishment.txt.xz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piskvorky/smart_open/4bed3466b9db62fefa32249cfce40a479b586866/smart_open/tests/test_data/crime-and-punishment.txt.xz -------------------------------------------------------------------------------- /smart_open/tests/test_data/crlf_at_1k_boundary.warc.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piskvorky/smart_open/4bed3466b9db62fefa32249cfce40a479b586866/smart_open/tests/test_data/crlf_at_1k_boundary.warc.gz -------------------------------------------------------------------------------- /smart_open/tests/test_data/ssh.cfg: -------------------------------------------------------------------------------- 1 | Host another-host 2 | HostName another-host-domain.com 3 | User another-user 4 | Port 2345 5 | IdentityFile /path/to/key/file 6 | ConnectTimeout 20 7 | Compression yes 8 | GSSAPIAuthentication no 9 | GSSAPIKeyExchange no 10 | GSSAPIDelegateCredentials no 11 | GSSAPITrustDns no 12 | -------------------------------------------------------------------------------- /smart_open/tests/test_hdfs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 7 | # 8 | import gzip 9 | import os 10 | import os.path as P 11 | import subprocess 12 | from unittest import mock 13 | import sys 14 | 15 | import pytest 16 | 17 | import smart_open.hdfs 18 | 19 | CURR_DIR = P.dirname(P.abspath(__file__)) 20 | 21 | if sys.platform.startswith("win"): 22 | pytest.skip("these tests don't work under Windows", allow_module_level=True) 23 | 24 | 25 | # 26 | # We want our mocks to emulate the real implementation as close as possible, 27 | # so we use a Popen call during each test. If we mocked using io.BytesIO, then 28 | # it is possible the mocks would behave differently to what we expect in real 29 | # use. 30 | # 31 | # Since these tests use cat, they will not work in an environment without cat, 32 | # such as Windows. The main line of this test submodule contains a simple 33 | # cat implementation. We need this because Windows' analog, type, does 34 | # weird stuff with line endings (inserts CRLF). Also, I don't know of a way 35 | # to get type to echo standard input. 
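# A sketch of the substitution the tests below perform (names refer to this
# file): instead of the real
#
#     subprocess.Popen(['hdfs', 'dfs', '-cat', uri], stdout=subprocess.PIPE)
#
# they patch subprocess.Popen to return cat(path), which pipes a local
# fixture file through this script's main() below.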
36 | # 37 | def cat(path=None): 38 | command = [sys.executable, P.abspath(__file__)] 39 | if path: 40 | command.append(path) 41 | return subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE) 42 | 43 | 44 | CAP_PATH = P.join(CURR_DIR, 'test_data', 'crime-and-punishment.txt') 45 | with open(CAP_PATH, encoding='utf-8') as fin: 46 | CRIME_AND_PUNISHMENT = fin.read() 47 | 48 | 49 | def test_sanity_read_bytes(): 50 | with open(CAP_PATH, 'rb') as fin: 51 | lines = [line for line in fin] 52 | assert len(lines) == 3 53 | 54 | 55 | def test_sanity_read_text(): 56 | with open(CAP_PATH, 'r', encoding='utf-8') as fin: 57 | text = fin.read() 58 | 59 | expected = 'В начале июля, в чрезвычайно жаркое время' 60 | assert text[:len(expected)] == expected 61 | 62 | 63 | @pytest.mark.parametrize('schema', [('hdfs', ), ('viewfs', )]) 64 | def test_read(schema): 65 | with mock.patch('subprocess.Popen', return_value=cat(CAP_PATH)): 66 | reader = smart_open.hdfs.CliRawInputBase(f'{schema}://dummy/url') 67 | as_bytes = reader.read() 68 | 69 | # 70 | # Not 100% sure why this is necessary on Windows platforms, but the 71 | # tests fail without it. It may be a bug, but I don't have time to 72 | # investigate right now. 73 | # 74 | as_text = as_bytes.decode('utf-8').replace(os.linesep, '\n') 75 | assert as_text == CRIME_AND_PUNISHMENT 76 | 77 | 78 | @pytest.mark.parametrize('schema', [('hdfs', ), ('viewfs', )]) 79 | def test_read_75(schema): 80 | with mock.patch('subprocess.Popen', return_value=cat(CAP_PATH)): 81 | reader = smart_open.hdfs.CliRawInputBase(f'{schema}://dummy/url') 82 | as_bytes = reader.read(75) 83 | 84 | as_text = as_bytes.decode('utf-8').replace(os.linesep, '\n') 85 | assert as_text == CRIME_AND_PUNISHMENT[:len(as_text)] 86 | 87 | 88 | @pytest.mark.parametrize('schema', [('hdfs', ), ('viewfs', )]) 89 | def test_unzip(schema): 90 | with mock.patch('subprocess.Popen', return_value=cat(CAP_PATH + '.gz')): 91 | with gzip.GzipFile(fileobj=smart_open.hdfs.CliRawInputBase(f'{schema}://dummy/url')) as fin: 92 | as_bytes = fin.read() 93 | 94 | as_text = as_bytes.decode('utf-8') 95 | assert as_text == CRIME_AND_PUNISHMENT 96 | 97 | 98 | @pytest.mark.parametrize('schema', [('hdfs', ), ('viewfs', )]) 99 | def test_context_manager(schema): 100 | with mock.patch('subprocess.Popen', return_value=cat(CAP_PATH)): 101 | with smart_open.hdfs.CliRawInputBase(f'{schema}://dummy/url') as fin: 102 | as_bytes = fin.read() 103 | 104 | as_text = as_bytes.decode('utf-8').replace('\r\n', '\n') 105 | assert as_text == CRIME_AND_PUNISHMENT 106 | 107 | 108 | @pytest.mark.parametrize('schema', [('hdfs', ), ('viewfs', )]) 109 | def test_write(schema): 110 | expected = 'мы в ответе за тех, кого приручили' 111 | mocked_cat = cat() 112 | 113 | with mock.patch('subprocess.Popen', return_value=mocked_cat): 114 | with smart_open.hdfs.CliRawOutputBase(f'{schema}://dummy/url') as fout: 115 | fout.write(expected.encode('utf-8')) 116 | 117 | actual = mocked_cat.stdout.read().decode('utf-8') 118 | assert actual == expected 119 | 120 | 121 | @pytest.mark.parametrize('schema', [('hdfs', ), ('viewfs', )]) 122 | def test_write_zip(schema): 123 | expected = 'мы в ответе за тех, кого приручили' 124 | mocked_cat = cat() 125 | 126 | with mock.patch('subprocess.Popen', return_value=mocked_cat): 127 | with smart_open.hdfs.CliRawOutputBase(f'{schema}://dummy/url') as fout: 128 | with gzip.GzipFile(fileobj=fout, mode='wb') as gz_fout: 129 | gz_fout.write(expected.encode('utf-8')) 130 | 131 | with gzip.GzipFile(fileobj=mocked_cat.stdout) 
as fin: 132 | actual = fin.read().decode('utf-8') 133 | 134 | assert actual == expected 135 | 136 | 137 | def main(): 138 | try: 139 | path = sys.argv[1] 140 | except IndexError: 141 | bytez = sys.stdin.buffer.read() 142 | else: 143 | with open(path, 'rb') as fin: 144 | bytez = fin.read() 145 | 146 | sys.stdout.buffer.write(bytez) 147 | sys.stdout.flush() 148 | 149 | 150 | if __name__ == '__main__': 151 | main() 152 | -------------------------------------------------------------------------------- /smart_open/tests/test_http.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 7 | # 8 | import functools 9 | import os 10 | import unittest 11 | 12 | import pytest 13 | import responses 14 | 15 | import smart_open.http 16 | import smart_open.s3 17 | import smart_open.constants 18 | import requests 19 | 20 | BYTES = b'i tried so hard and got so far but in the end it doesn\'t even matter' 21 | URL = 'http://localhost' 22 | HTTPS_URL = 'https://localhost' 23 | HEADERS = { 24 | 'Accept-Ranges': 'bytes', 25 | } 26 | 27 | 28 | def request_callback(request, headers=HEADERS, data=BYTES): 29 | headers = headers.copy() 30 | range_string = request.headers.get('range', 'bytes=0-') 31 | 32 | start, end = range_string.replace('bytes=', '', 1).split('-', 1) 33 | start = int(start) 34 | end = int(end) if end else len(data) 35 | 36 | data = data[start:end] 37 | headers['Content-Length'] = str(len(data)) 38 | 39 | return (200, headers, data) 40 | 41 | 42 | @unittest.skipIf(os.environ.get('TRAVIS'), 'This test does not work on TravisCI for some reason') 43 | class HttpTest(unittest.TestCase): 44 | 45 | @responses.activate 46 | def test_read_all(self): 47 | responses.add(responses.GET, URL, body=BYTES) 48 | reader = smart_open.http.SeekableBufferedInputBase(URL) 49 | read_bytes = reader.read() 50 | self.assertEqual(BYTES, read_bytes) 51 | 52 | @responses.activate 53 | def test_seek_from_start(self): 54 | responses.add_callback(responses.GET, URL, callback=request_callback) 55 | reader = smart_open.http.SeekableBufferedInputBase(URL) 56 | 57 | reader.seek(10) 58 | self.assertEqual(reader.tell(), 10) 59 | read_bytes = reader.read(size=10) 60 | self.assertEqual(reader.tell(), 20) 61 | self.assertEqual(BYTES[10:20], read_bytes) 62 | 63 | reader.seek(20) 64 | read_bytes = reader.read(size=10) 65 | self.assertEqual(BYTES[20:30], read_bytes) 66 | 67 | reader.seek(0) 68 | read_bytes = reader.read(size=10) 69 | self.assertEqual(BYTES[:10], read_bytes) 70 | 71 | @responses.activate 72 | def test_seek_from_current(self): 73 | responses.add_callback(responses.GET, URL, callback=request_callback) 74 | reader = smart_open.http.SeekableBufferedInputBase(URL) 75 | 76 | reader.seek(10) 77 | read_bytes = reader.read(size=10) 78 | self.assertEqual(BYTES[10:20], read_bytes) 79 | 80 | self.assertEqual(reader.tell(), 20) 81 | reader.seek(10, whence=smart_open.constants.WHENCE_CURRENT) 82 | self.assertEqual(reader.tell(), 30) 83 | read_bytes = reader.read(size=10) 84 | self.assertEqual(reader.tell(), 40) 85 | self.assertEqual(BYTES[30:40], read_bytes) 86 | 87 | @responses.activate 88 | def test_seek_from_end(self): 89 | responses.add_callback(responses.GET, URL, callback=request_callback) 90 | reader = smart_open.http.SeekableBufferedInputBase(URL) 91 | 92 | reader.seek(-10, whence=smart_open.constants.WHENCE_END) 93 | 
self.assertEqual(reader.tell(), len(BYTES) - 10) 94 | read_bytes = reader.read(size=10) 95 | self.assertEqual(reader.tell(), len(BYTES)) 96 | self.assertEqual(BYTES[-10:], read_bytes) 97 | 98 | @responses.activate 99 | def test_headers_are_as_assigned(self): 100 | responses.add_callback(responses.GET, URL, callback=request_callback) 101 | 102 | # use default _HEADERS 103 | x = smart_open.http.BufferedInputBase(URL) 104 | # set different ones 105 | x.headers['Accept-Encoding'] = 'compress, gzip' 106 | x.headers['Other-Header'] = 'value' 107 | 108 | # use the defaults again; the module-level defaults shouldn't have been overwritten via x 109 | y = smart_open.http.BufferedInputBase(URL) 110 | # should be default headers 111 | self.assertEqual(y.headers, {'Accept-Encoding': 'identity'}) 112 | # should be assigned headers 113 | self.assertEqual(x.headers, {'Accept-Encoding': 'compress, gzip', 'Other-Header': 'value'}) 114 | 115 | @responses.activate 116 | def test_headers(self): 117 | """Does the top-level http.open function handle headers correctly?""" 118 | responses.add_callback(responses.GET, URL, callback=request_callback) 119 | reader = smart_open.http.open(URL, 'rb', headers={'Foo': 'bar'}) 120 | self.assertEqual(reader.headers['Foo'], 'bar') 121 | 122 | @responses.activate 123 | def test_https_seek_start(self): 124 | """Did the seek start over HTTPS work?""" 125 | responses.add_callback(responses.GET, HTTPS_URL, callback=request_callback) 126 | 127 | with smart_open.open(HTTPS_URL, "rb") as fin: 128 | read_bytes_1 = fin.read(size=10) 129 | fin.seek(0) 130 | read_bytes_2 = fin.read(size=10) 131 | self.assertEqual(read_bytes_1, read_bytes_2) 132 | 133 | @responses.activate 134 | def test_https_seek_forward(self): 135 | """Did the seek forward over HTTPS work?""" 136 | responses.add_callback(responses.GET, HTTPS_URL, callback=request_callback) 137 | 138 | with smart_open.open(HTTPS_URL, "rb") as fin: 139 | fin.seek(10) 140 | read_bytes = fin.read(size=10) 141 | self.assertEqual(BYTES[10:20], read_bytes) 142 | 143 | @responses.activate 144 | def test_https_seek_reverse(self): 145 | """Did the seek in reverse over HTTPS work?""" 146 | responses.add_callback(responses.GET, HTTPS_URL, callback=request_callback) 147 | 148 | with smart_open.open(HTTPS_URL, "rb") as fin: 149 | read_bytes_1 = fin.read(size=10) 150 | fin.seek(-10, whence=smart_open.constants.WHENCE_CURRENT) 151 | read_bytes_2 = fin.read(size=10) 152 | self.assertEqual(read_bytes_1, read_bytes_2) 153 | 154 | @responses.activate 155 | def test_timeout_attribute(self): 156 | timeout = 1 157 | responses.add_callback(responses.GET, URL, callback=request_callback) 158 | reader = smart_open.open(URL, "rb", transport_params={'timeout': timeout}) 159 | assert hasattr(reader, 'timeout') 160 | assert reader.timeout == timeout 161 | 162 | @responses.activate 163 | def test_session_attribute(self): 164 | session = requests.Session() 165 | responses.add_callback(responses.GET, URL, callback=request_callback) 166 | reader = smart_open.open(URL, "rb", transport_params={'session': session}) 167 | assert hasattr(reader, 'session') 168 | assert reader.session == session 169 | assert reader.read() == BYTES 170 | 171 | 172 | @responses.activate 173 | def test_seek_implicitly_enabled(numbytes=10): 174 | """Can we seek even if the server hasn't explicitly allowed it?""" 175 | callback = functools.partial(request_callback, headers={}) 176 | responses.add_callback(responses.GET, HTTPS_URL, callback=callback) 177 | with smart_open.open(HTTPS_URL, 'rb') as fin: 178 | assert fin.seekable() 
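# An aside inferred from this pair of tests rather than from any documented contract: when the server sends no Accept-Ranges header at all, the reader is optimistic, reports seekable() == True and simply issues Range requests anyway; only an explicit 'Accept-Ranges: none' (see test_seek_implicitly_disabled below) switches seeking off.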
179 | first = fin.read(size=numbytes) 180 | fin.seek(-numbytes, whence=smart_open.constants.WHENCE_CURRENT) 181 | second = fin.read(size=numbytes) 182 | assert first == second 183 | 184 | 185 | @responses.activate 186 | def test_seek_implicitly_disabled(): 187 | """Does seeking fail when the server has explicitly disabled it?""" 188 | callback = functools.partial(request_callback, headers={'Accept-Ranges': 'none'}) 189 | responses.add_callback(responses.GET, HTTPS_URL, callback=callback) 190 | with smart_open.open(HTTPS_URL, 'rb') as fin: 191 | assert not fin.seekable() 192 | fin.read() 193 | with pytest.raises(OSError): 194 | fin.seek(0) 195 | -------------------------------------------------------------------------------- /smart_open/tests/test_package.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import unittest 4 | import pytest 5 | 6 | from smart_open import open 7 | 8 | skip_tests = "SMART_OPEN_TEST_MISSING_DEPS" not in os.environ 9 | 10 | 11 | class PackageTests(unittest.TestCase): 12 | 13 | @pytest.mark.skipif(skip_tests, reason="requires missing dependencies") 14 | def test_azure_raises_helpful_error_with_missing_deps(self): 15 | with pytest.raises(ImportError, match=r"pip install smart_open\[azure\]"): 16 | open("azure://foo/bar") 17 | 18 | @pytest.mark.skipif(skip_tests, reason="requires missing dependencies") 19 | def test_aws_raises_helpful_error_with_missing_deps(self): 20 | match = r"pip install smart_open\[s3\]" 21 | with pytest.raises(ImportError, match=match): 22 | open("s3://foo/bar") 23 | 24 | @pytest.mark.skipif(skip_tests, reason="requires missing dependencies") 25 | def test_gcs_raises_helpful_error_with_missing_deps(self): 26 | with pytest.raises(ImportError, match=r"pip install smart_open\[gcs\]"): 27 | open("gs://foo/bar") 28 | -------------------------------------------------------------------------------- /smart_open/tests/test_s3_version.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import functools 3 | import logging 4 | import unittest 5 | import uuid 6 | import time 7 | 8 | import boto3 9 | 10 | # See https://github.com/piskvorky/smart_open/issues/800 11 | # This supports moto 4 & 5 until v4 is no longer used by distros. 12 | try: 13 | from moto import mock_s3 14 | except ImportError: 15 | from moto import mock_aws as mock_s3 16 | 17 | from smart_open import open 18 | 19 | BUCKET_NAME = 'test-smartopen' 20 | KEY_NAME = 'test-key' 21 | 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | _resource = functools.partial(boto3.resource, region_name='us-east-1') 27 | 28 | 29 | def get_versions(bucket, key): 30 | """Return object versions in chronological order.""" 31 | return [ 32 | v.id 33 | for v in sorted( 34 | _resource('s3').Bucket(bucket).object_versions.filter(Prefix=key), 35 | key=lambda version: version.last_modified, 36 | ) 37 | ] 38 | 39 | 40 | @mock_s3 41 | class TestVersionId(unittest.TestCase): 42 | def setUp(self): 43 | # 44 | # Each run of this test reuses the BUCKET_NAME, but works with a 45 | # different key for isolation. 
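# (Bucket versioning is enabled just below, so each put_object() call in this method creates a new object version rather than overwriting the key in place.)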
46 | # 47 | resource = _resource('s3') 48 | resource.create_bucket(Bucket=BUCKET_NAME).wait_until_exists() 49 | resource.BucketVersioning(BUCKET_NAME).enable() 50 | 51 | self.key = 'test-write-key-{}'.format(uuid.uuid4().hex) 52 | self.url = "s3://%s/%s" % (BUCKET_NAME, self.key) 53 | self.test_ver1 = u"String version 1.0".encode('utf8') 54 | self.test_ver2 = u"String version 2.0".encode('utf8') 55 | 56 | bucket = resource.Bucket(BUCKET_NAME) 57 | bucket.put_object(Key=self.key, Body=self.test_ver1) 58 | logger.critical('versions after first write: %r', get_versions(BUCKET_NAME, self.key)) 59 | # sleep so the second write gets a strictly later last_modified timestamp; get_versions() sorts on it 60 | time.sleep(3) 61 | 62 | bucket.put_object(Key=self.key, Body=self.test_ver2) 63 | self.versions = get_versions(BUCKET_NAME, self.key) 64 | logger.critical('versions after second write: %r', get_versions(BUCKET_NAME, self.key)) 65 | 66 | assert len(self.versions) == 2 67 | 68 | def test_good_id(self): 69 | """Does passing the version_id parameter into the s3 submodule work correctly when reading?""" 70 | params = {'version_id': self.versions[0]} 71 | with open(self.url, mode='rb', transport_params=params) as fin: 72 | actual = fin.read() 73 | self.assertEqual(actual, self.test_ver1) 74 | 75 | def test_bad_id(self): 76 | """Does passing an invalid version_id into the s3 submodule get handled correctly?""" 77 | params = {'version_id': 'bad-version-does-not-exist'} 78 | with self.assertRaises(IOError): 79 | open(self.url, 'rb', transport_params=params) 80 | 81 | def test_bad_mode(self): 82 | """Do we correctly handle a non-None version when writing?""" 83 | params = {'version_id': self.versions[0]} 84 | with self.assertRaises(ValueError): 85 | open(self.url, 'wb', transport_params=params) 86 | 87 | def test_no_version(self): 88 | """Passing in no version at all gives the newest version of the file?""" 89 | with open(self.url, 'rb') as fin: 90 | actual = fin.read() 91 | self.assertEqual(actual, self.test_ver2) 92 | 93 | def test_newest_version(self): 94 | """Passing in the newest version explicitly gives the most recent content?""" 95 | params = {'version_id': self.versions[1]} 96 | with open(self.url, mode='rb', transport_params=params) as fin: 97 | actual = fin.read() 98 | self.assertEqual(actual, self.test_ver2) 99 | 100 | def test_oldest_version(self): 101 | """Passing in the oldest version gives the oldest content?""" 102 | params = {'version_id': self.versions[0]} 103 | with open(self.url, mode='rb', transport_params=params) as fin: 104 | actual = fin.read() 105 | self.assertEqual(actual, self.test_ver1) 106 | 107 | def test_version_to_boto3(self): 108 | """Does the boto3 object returned by to_boto3() stay pinned to the requested version?""" 109 | self.versions = get_versions(BUCKET_NAME, self.key) 110 | params = {'version_id': self.versions[0]} 111 | with open(self.url, mode='rb', transport_params=params) as fin: 112 | returned_obj = fin.to_boto3(_resource('s3')) 113 | 114 | boto3_body = returned_obj.get()['Body'].read() 115 | self.assertEqual(boto3_body, self.test_ver1) 116 | 117 | 118 | if __name__ == '__main__': 119 | unittest.main() 120 | -------------------------------------------------------------------------------- /smart_open/tests/test_ssh.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import logging 4 | import os 5 | import unittest 6 | from unittest import mock 7 | 8 | from paramiko import SSHException 9 | 10 | import smart_open.ssh 11 | 12 | _TEST_DATA_PATH = os.path.join(os.path.dirname(__file__), "test_data") 
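# ssh.cfg is an OpenSSH-style client config fixture. Judging by the assertions further down, it presumably maps the alias another-host to another-host-domain.com on port 2345 with user another-user, a key file, a 20-second timeout, compression, and the gss_* options set.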
13 | _CONFIG_PATH = os.path.join(_TEST_DATA_PATH, "ssh.cfg") 14 | 15 | 16 | def mock_ssh(func): 17 | def wrapper(*args, **kwargs): 18 | smart_open.ssh._SSH.clear() 19 | return func(*args, **kwargs) 20 | 21 | return mock.patch("paramiko.SSHClient.get_transport")( 22 | mock.patch("paramiko.SSHClient.connect")(wrapper) 23 | ) 24 | 25 | 26 | class SSHOpen(unittest.TestCase): 27 | def setUp(self): 28 | self._cfg_files = smart_open.ssh._SSH_CONFIG_FILES 29 | smart_open.ssh._SSH_CONFIG_FILES = [_CONFIG_PATH] 30 | 31 | def tearDown(self): 32 | smart_open.ssh._SSH_CONFIG_FILES = self._cfg_files 33 | 34 | @mock_ssh 35 | def test_open(self, mock_connect, get_transp_mock): 36 | smart_open.open("ssh://user:pass@some-host/") 37 | mock_connect.assert_called_with("some-host", 22, username="user", password="pass") 38 | 39 | @mock_ssh 40 | def test_percent_encoding(self, mock_connect, get_transp_mock): 41 | smart_open.open("ssh://user%3a:pass%40@some-host/") 42 | mock_connect.assert_called_with("some-host", 22, username="user:", password="pass@") 43 | 44 | @mock_ssh 45 | def test_open_without_password(self, mock_connect, get_transp_mock): 46 | smart_open.open("ssh://user@some-host/") 47 | mock_connect.assert_called_with("some-host", 22, username="user", password=None) 48 | 49 | @mock_ssh 50 | def test_open_with_transport_params(self, mock_connect, get_transp_mock): 51 | smart_open.open( 52 | "ssh://user:pass@some-host/", 53 | transport_params={"connect_kwargs": {"username": "ubuntu", "password": "pwd"}}, 54 | ) 55 | mock_connect.assert_called_with("some-host", 22, username="ubuntu", password="pwd") 56 | 57 | @mock_ssh 58 | def test_open_with_key_filename(self, mock_connect, get_transp_mock): 59 | smart_open.open( 60 | "ssh://user@some-host/", 61 | transport_params={"connect_kwargs": {"key_filename": "key"}}, 62 | ) 63 | mock_connect.assert_called_with("some-host", 22, username="user", key_filename="key") 64 | 65 | @mock_ssh 66 | def test_reconnect_after_session_timeout(self, mock_connect, get_transp_mock): 67 | mock_sftp = get_transp_mock().open_sftp_client() 68 | get_transp_mock().open_sftp_client.reset_mock() 69 | 70 | def mocked_open_sftp(): 71 | if len(mock_connect.call_args_list) < 2: # simulate timeout until second connect() 72 | yield SSHException('SSH session not active') 73 | while True: 74 | yield mock_sftp 75 | 76 | get_transp_mock().open_sftp_client.side_effect = mocked_open_sftp() 77 | 78 | smart_open.open("ssh://user:pass@some-host/") 79 | mock_connect.assert_called_with("some-host", 22, username="user", password="pass") 80 | mock_sftp.open.assert_called_once() 81 | 82 | @mock_ssh 83 | def test_open_with_openssh_config(self, mock_connect, get_transp_mock): 84 | smart_open.open("ssh://another-host/") 85 | mock_connect.assert_called_with( 86 | "another-host-domain.com", 87 | 2345, 88 | username="another-user", 89 | key_filename=["/path/to/key/file"], 90 | timeout=20., 91 | compress=True, 92 | gss_auth=False, 93 | gss_kex=False, 94 | gss_deleg_creds=False, 95 | gss_trust_dns=False, 96 | ) 97 | 98 | @mock_ssh 99 | def test_open_with_openssh_config_override_port(self, mock_connect, get_transp_mock): 100 | smart_open.open("ssh://another-host:22/") 101 | mock_connect.assert_called_with( 102 | "another-host-domain.com", 103 | 22, 104 | username="another-user", 105 | key_filename=["/path/to/key/file"], 106 | timeout=20., 107 | compress=True, 108 | gss_auth=False, 109 | gss_kex=False, 110 | gss_deleg_creds=False, 111 | gss_trust_dns=False, 112 | ) 113 | 114 | @mock_ssh 115 | def 
test_open_with_openssh_config_override_user(self, mock_connect, get_transp_mock): 116 | smart_open.open("ssh://new-user@another-host/") 117 | mock_connect.assert_called_with( 118 | "another-host-domain.com", 119 | 2345, 120 | username="new-user", 121 | key_filename=["/path/to/key/file"], 122 | timeout=20., 123 | compress=True, 124 | gss_auth=False, 125 | gss_kex=False, 126 | gss_deleg_creds=False, 127 | gss_trust_dns=False, 128 | ) 129 | 130 | @mock_ssh 131 | def test_open_with_prefetch(self, mock_connect, get_transp_mock): 132 | smart_open.open( 133 | "ssh://user:pass@some-host/", 134 | transport_params={"prefetch_kwargs": {"max_concurrent_requests": 3}}, 135 | ) 136 | mock_sftp = get_transp_mock().open_sftp_client() 137 | mock_fobj = mock_sftp.open() 138 | mock_fobj.prefetch.assert_called_with(max_concurrent_requests=3) 139 | 140 | @mock_ssh 141 | def test_open_without_prefetch(self, mock_connect, get_transp_mock): 142 | smart_open.open("ssh://user:pass@some-host/") 143 | mock_sftp = get_transp_mock().open_sftp_client() 144 | mock_fobj = mock_sftp.open() 145 | mock_fobj.prefetch.assert_not_called() 146 | 147 | 148 | if __name__ == "__main__": 149 | logging.basicConfig(format="%(asctime)s : %(levelname)s : %(message)s", level=logging.DEBUG) 150 | unittest.main() 151 | -------------------------------------------------------------------------------- /smart_open/tests/test_transport.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import pytest 3 | import unittest 4 | 5 | from smart_open.transport import register_transport, get_transport 6 | 7 | 8 | class TransportTest(unittest.TestCase): 9 | 10 | def test_registry_requires_declared_schemes(self): 11 | with pytest.raises(ValueError): 12 | register_transport('smart_open.tests.fixtures.no_schemes_transport') 13 | 14 | def test_registry_errors_on_double_register_scheme(self): 15 | register_transport('smart_open.tests.fixtures.good_transport') 16 | with pytest.raises(AssertionError): 17 | register_transport('smart_open.tests.fixtures.good_transport') 18 | 19 | def test_registry_errors_get_transport_for_module_with_missing_deps(self): 20 | register_transport('smart_open.tests.fixtures.missing_deps_transport') 21 | with pytest.raises(ImportError): 22 | get_transport("missing") 23 | -------------------------------------------------------------------------------- /smart_open/tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 
7 | # 8 | import urllib.parse 9 | 10 | import pytest 11 | 12 | import smart_open.utils 13 | 14 | 15 | @pytest.mark.parametrize( 16 | 'value,minval,maxval,expected', 17 | [ 18 | (5, 0, 10, 5), 19 | (11, 0, 10, 10), 20 | (-1, 0, 10, 0), 21 | (10, 0, None, 10), 22 | (-10, 0, None, 0), 23 | ] 24 | ) 25 | def test_clamp(value, minval, maxval, expected): 26 | assert smart_open.utils.clamp(value, minval=minval, maxval=maxval) == expected 27 | 28 | 29 | @pytest.mark.parametrize( 30 | 'value,params,expected', 31 | [ 32 | (10, {}, 10), 33 | (-10, {}, 0), 34 | (-10, {'minval': -5}, -5), 35 | (10, {'maxval': 5}, 5), 36 | ] 37 | ) 38 | def test_clamp_defaults(value, params, expected): 39 | assert smart_open.utils.clamp(value, **params) == expected 40 | 41 | 42 | def test_check_kwargs(): 43 | import smart_open.s3 44 | kallable = smart_open.s3.open 45 | kwargs = {'client': 'foo', 'unsupported': 'bar', 'client_kwargs': 'boaz'} 46 | supported = smart_open.utils.check_kwargs(kallable, kwargs) 47 | assert supported == {'client': 'foo', 'client_kwargs': 'boaz'} 48 | 49 | 50 | @pytest.mark.parametrize( 51 | 'url,expected', 52 | [ 53 | ('s3://bucket/key', ('s3', 'bucket', '/key', '', '')), 54 | ('s3://bucket/key?', ('s3', 'bucket', '/key?', '', '')), 55 | ('s3://bucket/???', ('s3', 'bucket', '/???', '', '')), 56 | ('https://host/path?foo=bar', ('https', 'host', '/path', 'foo=bar', '')), 57 | ] 58 | ) 59 | def test_safe_urlsplit(url, expected): 60 | actual = smart_open.utils.safe_urlsplit(url) 61 | assert actual == urllib.parse.SplitResult(*expected) 62 | -------------------------------------------------------------------------------- /smart_open/transport.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2020 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 7 | # 8 | """Maintains a registry of transport mechanisms. 9 | 10 | The main entrypoint is :func:`get_transport`. See also :file:`extending.md`. 11 | 12 | """ 13 | import importlib 14 | import logging 15 | 16 | import smart_open.local_file 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | NO_SCHEME = '' 21 | 22 | _REGISTRY = {NO_SCHEME: smart_open.local_file} 23 | _ERRORS = {} 24 | _MISSING_DEPS_ERROR = """You are trying to use the %(module)s functionality of smart_open 25 | but you do not have the correct %(module)s dependencies installed. Try: 26 | 27 | pip install smart_open[%(module)s] 28 | 29 | """ 30 | 31 | 32 | def register_transport(submodule): 33 | """Register a submodule as a transport mechanism for ``smart_open``. 34 | 35 | This module **must** have: 36 | 37 | - `SCHEME` attribute (or `SCHEMES`, if the submodule supports multiple schemes) 38 | - `open` function 39 | - `open_uri` function 40 | - `parse_uri` function 41 | 42 | Once registered, you can get the submodule by calling :func:`get_transport`. 
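For example, a third-party package could plug itself in like this (a sketch only: ``mypkg.transport`` is a hypothetical module that defines ``SCHEME = 'myscheme'`` together with ``open``, ``open_uri`` and ``parse_uri``):: import smart_open.transport smart_open.transport.register_transport('mypkg.transport') submodule = smart_open.transport.get_transport('myscheme')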
43 | 44 | """ 45 | module_name = submodule 46 | if isinstance(submodule, str): 47 | try: 48 | submodule = importlib.import_module(submodule) 49 | except ImportError: 50 | return 51 | else: 52 | module_name = submodule.__name__ 53 | # Save only the last module name piece 54 | module_name = module_name.rsplit(".")[-1] 55 | 56 | if hasattr(submodule, "SCHEME"): 57 | schemes = [submodule.SCHEME] 58 | elif hasattr(submodule, "SCHEMES"): 59 | schemes = submodule.SCHEMES 60 | else: 61 | raise ValueError("%r does not have a .SCHEME or .SCHEMES attribute" % submodule) 62 | 63 | for f in ("open", "open_uri", "parse_uri"): 64 | assert hasattr(submodule, f), "%r is missing %r" % (submodule, f) 65 | 66 | for scheme in schemes: 67 | assert scheme not in _REGISTRY 68 | if getattr(submodule, "MISSING_DEPS", False): 69 | _ERRORS[scheme] = module_name 70 | else: 71 | _REGISTRY[scheme] = submodule 72 | 73 | 74 | def get_transport(scheme): 75 | """Get the submodule that handles transport for the specified scheme. 76 | 77 | This submodule must have been previously registered via :func:`register_transport`. 78 | 79 | """ 80 | expected = SUPPORTED_SCHEMES 81 | readme_url = ( 82 | "https://github.com/piskvorky/smart_open/blob/master/README.rst" 83 | ) 84 | message = ( 85 | "Unable to handle scheme %(scheme)r, expected one of %(expected)r. " 86 | "Extra dependencies required by %(scheme)r may be missing. " 87 | "See <%(readme_url)s> for details." % locals() 88 | ) 89 | if scheme in _ERRORS: 90 | raise ImportError(_MISSING_DEPS_ERROR % dict(module=_ERRORS[scheme])) 91 | if scheme in _REGISTRY: 92 | return _REGISTRY[scheme] 93 | raise NotImplementedError(message) 94 | 95 | 96 | register_transport(smart_open.local_file) 97 | register_transport("smart_open.azure") 98 | register_transport("smart_open.ftp") 99 | register_transport("smart_open.gcs") 100 | register_transport("smart_open.hdfs") 101 | register_transport("smart_open.http") 102 | register_transport("smart_open.s3") 103 | register_transport("smart_open.ssh") 104 | register_transport("smart_open.webhdfs") 105 | 106 | SUPPORTED_SCHEMES = tuple(sorted(_REGISTRY.keys())) 107 | """The transport schemes that the local installation of ``smart_open`` supports.""" 108 | -------------------------------------------------------------------------------- /smart_open/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2020 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 7 | # 8 | 9 | """Helper functions for documentation, etc.""" 10 | 11 | import inspect 12 | import io 13 | import logging 14 | import urllib.parse 15 | 16 | import wrapt 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | WORKAROUND_SCHEMES = ['s3', 's3n', 's3u', 's3a', 'gs'] 21 | QUESTION_MARK_PLACEHOLDER = '///smart_open.utils.QUESTION_MARK_PLACEHOLDER///' 22 | 23 | 24 | def inspect_kwargs(kallable): 25 | # 26 | # inspect.getargspec got deprecated in Py3.4, and calling it spews 27 | # deprecation warnings that we'd prefer to avoid. Unfortunately, older 28 | # versions of Python (<3.3) did not have inspect.signature, so we need to 29 | # handle them the old-fashioned getargspec way. 30 | # 31 | try: 32 | signature = inspect.signature(kallable) 33 | except AttributeError: 34 | try: 35 | args, varargs, keywords, defaults = inspect.getargspec(kallable) 36 | except TypeError: 37 | # 38 | # Happens under Py2.7 with mocking. 
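# (getargspec raises TypeError for mock objects because they are callable yet expose no plain-function signature; treating that as "no supported kwargs" is the safe fallback)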
39 | # 40 | return {} 41 | 42 | if not defaults: 43 | return {} 44 | supported_keywords = args[-len(defaults):] 45 | return dict(zip(supported_keywords, defaults)) 46 | else: 47 | return { 48 | name: param.default 49 | for name, param in signature.parameters.items() 50 | if param.default is not inspect.Parameter.empty 51 | } 52 | 53 | 54 | def check_kwargs(kallable, kwargs): 55 | """Check which keyword arguments the callable supports. 56 | 57 | Parameters 58 | ---------- 59 | kallable: callable 60 | A function or method to test 61 | kwargs: dict 62 | The keyword arguments to check. If the callable doesn't support any 63 | of these, a warning is logged. 64 | 65 | Returns 66 | ------- 67 | dict 68 | A dictionary of argument names and values supported by the callable. 69 | """ 70 | supported_keywords = sorted(inspect_kwargs(kallable)) 71 | unsupported_keywords = [k for k in sorted(kwargs) if k not in supported_keywords] 72 | supported_kwargs = {k: v for (k, v) in kwargs.items() if k in supported_keywords} 73 | 74 | if unsupported_keywords: 75 | logger.warning('ignoring unsupported keyword arguments: %r', unsupported_keywords) 76 | 77 | return supported_kwargs 78 | 79 | 80 | def clamp(value, minval=0, maxval=None): 81 | """Clamp a numeric value to a specific range. 82 | 83 | Parameters 84 | ---------- 85 | value: numeric 86 | The value to clamp. 87 | 88 | minval: numeric 89 | The lower bound. 90 | 91 | maxval: numeric 92 | The upper bound. 93 | 94 | Returns 95 | ------- 96 | numeric 97 | The clamped value. It will be in the range ``[minval, maxval]``. 98 | 99 | """ 100 | if maxval is not None: 101 | value = min(value, maxval) 102 | value = max(value, minval) 103 | return value 104 | 105 | 106 | def make_range_string(start=None, stop=None): 107 | """Create a byte range specifier in accordance with RFC-2616. 108 | 109 | Parameters 110 | ---------- 111 | start: int, optional 112 | The start of the byte range. If unspecified, the result is a suffix range: the last ``stop`` bytes of the file. 113 | 114 | stop: int, optional 115 | The end of the byte range. If unspecified, the range extends to EOF. 116 | 117 | Returns 118 | ------- 119 | str 120 | A byte range specifier. 121 | 122 | """ 123 | # 124 | # https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35 125 | # 126 | if start is None and stop is None: 127 | raise ValueError("make_range_string requires either a stop or start value") 128 | start_str = '' if start is None else str(start) 129 | stop_str = '' if stop is None else str(stop) 130 | return 'bytes=%s-%s' % (start_str, stop_str) 131 | 132 | 133 | def parse_content_range(content_range): 134 | """Extract units, start, stop, and length from a content range header like "bytes 0-846981/846982". 135 | 136 | Assumes a properly formatted content-range header from S3. 137 | See werkzeug.http.parse_content_range_header for a more robust version. 138 | 139 | Parameters 140 | ---------- 141 | content_range: str 142 | The content-range header to parse. 143 | 144 | Returns 145 | ------- 146 | tuple (units: str, start: int, stop: int, length: int) 147 | The units and three integers from the content-range header. 148 | 149 | """ 150 | units, numbers = content_range.split(' ', 1) 151 | range_, length = numbers.split('/', 1) 152 | start, stop = range_.split('-', 1) 153 | return units, int(start), int(stop), int(length) 154 | 155 | 156 | def safe_urlsplit(url): 157 | """This is a hack to prevent the regular urlsplit from splitting around question marks. 158 | 159 | A question mark (?) 
in a URL typically indicates the start of a 160 | querystring, and the standard library's urlparse function handles the 161 | querystring separately. Unfortunately, question marks can also appear 162 | _inside_ the actual URL for some schemes, such as S3 and GS. 163 | 164 | Replaces question marks with a special placeholder substring prior to 165 | splitting. This work-around behavior is disabled in the unlikely event the 166 | placeholder is already part of the URL. If this affects you, consider 167 | changing the value of QUESTION_MARK_PLACEHOLDER to something more suitable. 168 | 169 | See Also 170 | -------- 171 | https://bugs.python.org/issue43882 172 | https://github.com/python/cpython/blob/3.7/Lib/urllib/parse.py 173 | https://github.com/piskvorky/smart_open/issues/285 174 | https://github.com/piskvorky/smart_open/issues/458 175 | smart_open/utils.py:QUESTION_MARK_PLACEHOLDER 176 | """ 177 | sr = urllib.parse.urlsplit(url, allow_fragments=False) 178 | 179 | placeholder = None 180 | if sr.scheme in WORKAROUND_SCHEMES and '?' in url and QUESTION_MARK_PLACEHOLDER not in url: 181 | # 182 | # This is safe because people will _almost never_ use the below 183 | # substring in a URL. If they do, then they're asking for trouble, 184 | # and this special handling will simply not happen for them. 185 | # 186 | placeholder = QUESTION_MARK_PLACEHOLDER 187 | url = url.replace('?', placeholder) 188 | sr = urllib.parse.urlsplit(url, allow_fragments=False) 189 | 190 | if placeholder is None: 191 | return sr 192 | 193 | path = sr.path.replace(placeholder, '?') 194 | return urllib.parse.SplitResult(sr.scheme, sr.netloc, path, '', '') 195 | 196 | 197 | class TextIOWrapper(io.TextIOWrapper): 198 | def __exit__(self, exc_type, exc_val, exc_tb): 199 | """Call close on underlying buffer only when there was no exception. 200 | 201 | Without this patch, TextIOWrapper would call self.buffer.close() during 202 | exception handling, which is unwanted for e.g. s3 and azure. They only call 203 | self.close() when there was no exception (self.terminate() otherwise) to avoid 204 | committing unfinished/failed uploads. 205 | """ 206 | if exc_type is None: 207 | self.close() 208 | 209 | 210 | class FileLikeProxy(wrapt.ObjectProxy): 211 | __inner = ... 
# initialized before wrapt disallows __setattr__ on certain objects 212 | 213 | def __init__(self, outer, inner): 214 | super().__init__(outer) 215 | self.__inner = inner 216 | 217 | def __exit__(self, *args, **kwargs): 218 | """Exit inner after exiting outer.""" 219 | try: 220 | return super().__exit__(*args, **kwargs) 221 | finally: 222 | self.__inner.__exit__(*args, **kwargs) 223 | 224 | def __next__(self): 225 | return self.__wrapped__.__next__() 226 | 227 | def close(self): 228 | try: 229 | return self.__wrapped__.close() 230 | finally: 231 | if self.__inner != self.__wrapped__: # Don't close again if inner and wrapped are the same 232 | self.__inner.close() 233 | -------------------------------------------------------------------------------- /smart_open/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '7.2.0.dev0' 2 | 3 | 4 | if __name__ == '__main__': 5 | print(__version__) 6 | -------------------------------------------------------------------------------- /smart_open/webhdfs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Radim Rehurek 4 | # 5 | # This code is distributed under the terms and conditions 6 | # from the MIT License (MIT). 7 | # 8 | 9 | """Implements reading and writing to/from WebHDFS. 10 | 11 | The main entry point is the :func:`~smart_open.webhdfs.open` function. 12 | 13 | """ 14 | 15 | import io 16 | import logging 17 | import urllib.parse 18 | 19 | try: 20 | import requests 21 | except ImportError: 22 | MISSING_DEPS = True 23 | 24 | from smart_open import utils, constants 25 | 26 | import http.client as httplib 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | SCHEME = 'webhdfs' 31 | 32 | URI_EXAMPLES = ( 33 | 'webhdfs://host:port/path/file', 34 | ) 35 | 36 | MIN_PART_SIZE = 50 * 1024**2 # minimum part size for HDFS multipart uploads 37 | 38 | 39 | def parse_uri(uri_as_str): 40 | return dict(scheme=SCHEME, uri=uri_as_str) 41 | 42 | 43 | def open_uri(uri, mode, transport_params): 44 | kwargs = utils.check_kwargs(open, transport_params) 45 | return open(uri, mode, **kwargs) 46 | 47 | 48 | def open(http_uri, mode, min_part_size=MIN_PART_SIZE): 49 | """ 50 | Parameters 51 | ---------- 52 | http_uri: str 53 | webhdfs url converted to http REST url 54 | min_part_size: int, optional 55 | For writing only. 
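Example (a sketch only; the namenode host, port and path are made up):: with open('webhdfs://namenode:50070/user/me/file.txt', 'rb') as fin: data = fin.read()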
56 | 57 | """ 58 | if http_uri.startswith(SCHEME): 59 | http_uri = _convert_to_http_uri(http_uri) 60 | 61 | if mode == constants.READ_BINARY: 62 | fobj = BufferedInputBase(http_uri) 63 | elif mode == constants.WRITE_BINARY: 64 | fobj = BufferedOutputBase(http_uri, min_part_size=min_part_size) 65 | else: 66 | raise NotImplementedError("webhdfs support for mode %r not implemented" % mode) 67 | 68 | fobj.name = http_uri.split('/')[-1] 69 | return fobj 70 | 71 | 72 | def _convert_to_http_uri(webhdfs_url): 73 | """ 74 | Convert webhdfs uri to http url and return it as text 75 | 76 | Parameters 77 | ---------- 78 | webhdfs_url: str 79 | A URL starting with webhdfs:// 80 | """ 81 | split_uri = urllib.parse.urlsplit(webhdfs_url) 82 | netloc = split_uri.hostname 83 | if split_uri.port: 84 | netloc += ":{}".format(split_uri.port) 85 | query = split_uri.query 86 | if split_uri.username: 87 | query += ( 88 | ("&" if query else "") + "user.name=" + urllib.parse.quote(split_uri.username) 89 | ) 90 | 91 | return urllib.parse.urlunsplit( 92 | ("http", netloc, "/webhdfs/v1" + split_uri.path, query, "") 93 | ) 94 | 95 | 96 | # 97 | # For old unit tests. 98 | # 99 | def convert_to_http_uri(parsed_uri): 100 | return _convert_to_http_uri(parsed_uri.uri) 101 | 102 | 103 | class BufferedInputBase(io.BufferedIOBase): 104 | _buf = None # so `closed` property works in case __init__ fails and __del__ is called 105 | 106 | def __init__(self, uri): 107 | self._uri = uri 108 | 109 | payload = {"op": "OPEN", "offset": 0} 110 | self._response = requests.get(self._uri, params=payload, stream=True) 111 | if self._response.status_code != httplib.OK: 112 | raise WebHdfsException.from_response(self._response) 113 | self._buf = b'' 114 | 115 | # 116 | # Override some methods from io.IOBase. 117 | # 118 | def close(self): 119 | """Flush and close this stream.""" 120 | logger.debug("close: called") 121 | if not self.closed: 122 | self._buf = None 123 | 124 | @property 125 | def closed(self): 126 | return self._buf is None 127 | 128 | def readable(self): 129 | """Return True if the stream can be read from.""" 130 | return True 131 | 132 | def seekable(self): 133 | """If False, seek(), tell() and truncate() will raise IOError. 134 | 135 | We offer no seek or truncate support; WebHDFS input is read sequentially.""" 136 | return False 137 | 138 | # 139 | # io.BufferedIOBase methods. 140 | # 141 | def detach(self): 142 | """Unsupported.""" 143 | raise io.UnsupportedOperation 144 | 145 | def read(self, size=None): 146 | if size is None or size < 0:  # read1() passes -1 by default; treat negative sizes as read-everything 147 | self._buf, retval = b'', self._buf + self._response.raw.read() 148 | return retval 149 | elif size < len(self._buf): 150 | self._buf, retval = self._buf[size:], self._buf[:size] 151 | return retval 152 | 153 | try: 154 | buffers = [self._buf] 155 | total_read = 0 156 | while total_read < size: 157 | raw_data = self._response.raw.read(io.DEFAULT_BUFFER_SIZE) 158 | # sometimes read() returns zero-length data without raising a 159 | # StopIteration exception. We break out of the loop if this happens. 
160 | if len(raw_data) == 0: 161 | break 162 | 163 | total_read += len(raw_data) 164 | buffers.append(raw_data) 165 | except StopIteration: 166 | pass 167 | 168 | self._buf = b"".join(buffers) 169 | self._buf, retval = self._buf[size:], self._buf[:size] 170 | return retval 171 | 172 | def read1(self, size=-1): 173 | """This is the same as read().""" 174 | return self.read(size=size) 175 | 176 | def readinto(self, b): 177 | """Read up to len(b) bytes into b, and return the number of bytes 178 | read.""" 179 | data = self.read(len(b)) 180 | if not data: 181 | return 0 182 | b[:len(data)] = data 183 | return len(data) 184 | 185 | def readline(self): 186 | self._buf, retval = b'', self._buf + self._response.raw.readline()  # note: assumes _buf holds no newline yet, i.e. readline() isn't mixed with read(n) calls 187 | return retval 188 | 189 | 190 | class BufferedOutputBase(io.BufferedIOBase): 191 | def __init__(self, uri, min_part_size=MIN_PART_SIZE): 192 | """ 193 | Parameters 194 | ---------- 195 | min_part_size: int, optional 196 | For writing only. 197 | 198 | """ 199 | self._uri = uri 200 | self._closed = False 201 | self.min_part_size = min_part_size 202 | # creating empty file first 203 | payload = {"op": "CREATE", "overwrite": True} 204 | init_response = requests.put(self._uri, params=payload, allow_redirects=False) 205 | if init_response.status_code != httplib.TEMPORARY_REDIRECT: 206 | raise WebHdfsException.from_response(init_response) 207 | uri = init_response.headers['location'] 208 | response = requests.put(uri, data="", headers={'content-type': 'application/octet-stream'}) 209 | if response.status_code != httplib.CREATED: 210 | raise WebHdfsException.from_response(response) 211 | self.lines = [] 212 | self.parts = 0 213 | self.chunk_bytes = 0 214 | self.total_size = 0 215 | 216 | # 217 | # This member is part of the io.BufferedIOBase interface. 218 | # 219 | self.raw = None 220 | 221 | # 222 | # Override some methods from io.IOBase. 223 | # 224 | def writable(self): 225 | """Return True if the stream supports writing.""" 226 | return True 227 | 228 | # 229 | # io.BufferedIOBase methods. 230 | # 231 | def detach(self): 232 | raise io.UnsupportedOperation("detach() not supported") 233 | 234 | def _upload(self, data): 235 | payload = {"op": "APPEND"} 236 | init_response = requests.post(self._uri, params=payload, allow_redirects=False) 237 | if init_response.status_code != httplib.TEMPORARY_REDIRECT: 238 | raise WebHdfsException.from_response(init_response) 239 | uri = init_response.headers['location'] 240 | response = requests.post(uri, data=data, 241 | headers={'content-type': 'application/octet-stream'}) 242 | if response.status_code != httplib.OK: 243 | raise WebHdfsException.from_response(response) 244 | 245 | def write(self, b): 246 | """ 247 | Write the given bytes (binary string) to the WebHDFS file this writer was constructed for. 
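Writes are buffered locally: once at least ``min_part_size`` bytes accumulate, the buffer is appended to the remote file as one part, and close() uploads whatever remains.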
248 | 249 | """ 250 | if self._closed: 251 | raise ValueError("I/O operation on closed file") 252 | 253 | if not isinstance(b, bytes): 254 | raise TypeError("input must be a binary string") 255 | 256 | self.lines.append(b) 257 | self.chunk_bytes += len(b) 258 | self.total_size += len(b) 259 | 260 | if self.chunk_bytes >= self.min_part_size: 261 | buff = b"".join(self.lines) 262 | logger.info( 263 | "uploading part #%i, %i bytes (total %.3fGB)", 264 | self.parts, len(buff), self.total_size / 1024.0 ** 3 265 | ) 266 | self._upload(buff) 267 | logger.debug("upload of part #%i finished", self.parts) 268 | self.parts += 1 269 | self.lines, self.chunk_bytes = [], 0 270 | 271 | def close(self): 272 | buff = b"".join(self.lines) 273 | if buff and not self._closed:  # the guard keeps a repeated close() from re-uploading the remainder 274 | logger.info( 275 | "uploading last part #%i, %i bytes (total %.3fGB)", 276 | self.parts, len(buff), self.total_size / 1024.0 ** 3 277 | ) 278 | self._upload(buff) 279 | logger.debug("upload of last part #%i finished", self.parts) 280 | self._closed = True 281 | 282 | @property 283 | def closed(self): 284 | return self._closed 285 | 286 | 287 | class WebHdfsException(Exception): 288 | def __init__(self, msg="", status_code=None): 289 | self.msg = msg 290 | self.status_code = status_code 291 | super(WebHdfsException, self).__init__(repr(self)) 292 | 293 | def __repr__(self): 294 | return "{}(status_code={}, msg={!r})".format( 295 | self.__class__.__name__, self.status_code, self.msg 296 | ) 297 | 298 | @classmethod 299 | def from_response(cls, response): 300 | return cls(msg=response.text, status_code=response.status_code) 301 | -------------------------------------------------------------------------------- /update_helptext.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Write out help.txt based on the current codebase.""" 3 | 4 | import subprocess 5 | from pathlib import Path 6 | 7 | # get the latest helptext 8 | helptext = subprocess.check_output( 9 | ["/usr/bin/env", "python3", "-c", 'help("smart_open")'], 10 | text=True, 11 | ).strip() 12 | 13 | # remove the user-specific FILE and VERSION section at the bottom to make this script reproducible 14 | lines = helptext.splitlines()[:-5] 15 | 16 | Path("help.txt").write_text("\n".join(line.rstrip() for line in lines)) 17 | --------------------------------------------------------------------------------