├── .github └── workflows │ └── python-package.yml ├── .gitignore ├── LICENSE ├── README.md ├── pyproject.toml ├── requirements-dev.txt ├── setup.cfg ├── setup.py └── src └── splitcopy ├── __init__.py ├── ftp.py ├── get.py ├── paramikoshell.py ├── progress.py ├── put.py ├── shared.py ├── splitcopy.py └── tests ├── __init__.py ├── test_ftp.py ├── test_get.py ├── test_paramikoshell.py ├── test_progress.py ├── test_put.py ├── test_shared.py └── test_splitcopy.py /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ "master" ] 9 | pull_request: 10 | branches: [ "master" ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python-version: ["3.8", "3.9", "3.10"] 20 | 21 | steps: 22 | - uses: actions/checkout@v3 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v3 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m pip install flake8 pytest pytest-cov coverage 31 | python -m pip install paramiko scp 32 | - name: Lint with flake8 33 | run: | 34 | # stop the build if there are Python syntax errors or undefined names 35 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 36 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 37 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 38 | - name: Test with pytest, produce coverage report 39 | run: | 40 | pytest --cov src 41 | coverage html 42 | - name: Archive code coverage html report 43 | uses: actions/upload-artifact@v3 44 | with: 45 | name: code-coverage-report 46 | path: htmlcov/ 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # pycharm 107 | .idea 108 | 109 | # OSX stuff 110 | .DS_Store 111 | 112 | # VScode 113 | .vscode 114 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2019 Juniper Networks, Inc 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Splitcopy 2 | 3 | Improves file transfer rates when copying files to/from JUNOS/EVO/\*nix hosts. 4 | 5 | At a minimum, sshd must be running on the remote host. 6 | On JUNOS/EVO this requires 'system services ssh' configuration. 7 | 8 | If using ftp to copy files then an ftp daemon must be running on the remote host. 9 | On JUNOS this requires 'system services ftp' configuration. 10 | FTP is the default transfer method due to its lower resource usage and its ability to restart transfers. 11 | 12 | Script overheads include authentication, sha hash generation/comparison, disk space check, file split and join. 13 | It can be slower than normal ftp/scp for small files as a result. 14 | 15 | Because it opens a number of simultaneous connections, 16 | if the JUNOS/EVO host has connection/rate limits configured like this: 17 | 18 | ``` 19 | system { 20 | services { 21 | ssh { # or ftp 22 | connection-limit 10; 23 | rate-limit 10; 24 | } 25 | } 26 | } 27 | ``` 28 | or system login retry-options: 29 | 30 | ``` 31 | system { 32 | login { 33 | retry-options { 34 | <..> 35 | } 36 | } 37 | } 38 | ``` 39 | 40 | The script will deactivate these limits so it can proceed, then rollback these changes upon completion. 41 | 42 | ## Arguments 43 | 44 | `source` Mandatory 45 | `target` Mandatory 46 | `--pwd` Optional, password 47 | `--scp` Optional, use scp instead of ftp to transfer files 48 | `--ssh_key` Optional, path to private ssh key (only required if not located in ~/.ssh/) 49 | `--log` Optional, enables additional logging, specify a logging level as argument 50 | `--noverify` Optional, skips sha1 hash comparison of src and dst file 51 | `--split_timeout` Optional, time to wait for remote file split operation to complete, default 120s 52 | `--ssh_port` Optional, ssh port number to connect to 53 | `--nocurses` Optional, disables the use of a curses window to show per-file progress statistics 54 | 55 | The format of source and target arguments match those of the 'scp' cmd. 56 | Both accept either a local path, or a remote path in the format - user@host:path or host@path 57 | 58 | ### To copy from local host to remote host: 59 | splitcopy @: 60 | ### To copy from remote host to local host: 61 | splitcopy @: 62 | 63 | Supports connecting through jumphosts via 'ProxyCommand' entries in ~/.ssh/config. Example: 64 | ``` 65 | Host myserver 66 | ProxyCommand ssh myjumphost.mydomain.com -W %h:%p 67 | ``` 68 | 69 | # INSTALLATION 70 | 71 | Installation requires Python >= 3.6 and associated `pip` tool 72 | 73 | python3 -m pip install splitcopy 74 | 75 | Installing from Git is also supported (OS must have git installed). 76 | 77 | To install the latest MASTER code 78 | python3 -m pip install git+https://github.com/Juniper/splitcopy.git 79 | -or- 80 | To install a specific version, branch, tag, etc. 81 | python3 -m pip install git+https://github.com/Juniper/splitcopy.git@ 82 | 83 | Upgrading has the same requirements as installation and has the same format with the addition of --upgrade 84 | 85 | python3 -m pip install splitcopy --upgrade 86 | 87 | 88 | # Usage Examples 89 | ## FTP transfer (default method) 90 | 91 | ``` 92 | $ splitcopy /var/tmp/jselective-update-ppc-J1.1-14.2R5-S3-J1.1.tgz lab@192.168.1.1:/var/tmp/ 93 | Password: 94 | checking remote port(s) are open... 95 | using FTP for file transfer 96 | checking remote storage... 97 | sha1 not found, generating sha1... 98 | splitting file... 99 | starting transfer... 100 | 100% done 101 | transfer complete 102 | joining files... 103 | deleting remote tmp directory... 104 | generating remote sha hash... 105 | local and remote sha hash match 106 | file has been successfully copied to 192.168.1.1:/var/tmp/jselective-update-ppc-J1.1-14.2R5-S3-J1.1.tgz 107 | data transfer = 0:00:16.831192 108 | total runtime = 0:00:31.520914 109 | ``` 110 | 111 | ## SCP transfer 112 | 113 | ``` 114 | $ splitcopy lab@192.168.1.1/var/log/messages /var/tmp/ --scp 115 | ssh auth succeeded 116 | checking remote storage... 117 | checking local storage... 118 | sha1 not found, generating sha1... 119 | splitting file... 120 | starting transfer... 121 | 100% done 122 | transfer complete 123 | joining files... 124 | deleting remote tmp directory... 125 | generating remote sha hash... 126 | local and remote sha hash match 127 | file has been successfully copied to /var/tmp/messages 128 | data transfer = 0:00:18.768987 129 | total runtime = 0:00:44.891370 130 | ``` 131 | 132 | ## Notes on using FTP 133 | 134 | FTP is the default transfer method. 135 | The processing of each file chunk is performed by a dedicated thread 136 | Each cpu core is allowed up to 5 threads, with a system max of 32 threads used 137 | 138 | Using FTP method will generate the following processes on the remote host: 139 | - for mgmt session: 1x sshd, 1x cli, 1x mgd, 1x csh 140 | - for transfers: up to 40x ftpd processes (depends on Python version and number of cpus as described above) 141 | 142 | In theory, this could result in the per-user maxproc limit of 64 being exceeded: 143 | ``` 144 | May 2 04:46:59 /kernel: maxproc limit exceeded by uid 2001, please see tuning(7) and login.conf(5). 145 | ``` 146 | The script modulates the number of chunks to match the number of threads available 147 | The maximum number of user owned processes that could be created is <= 44 148 | 149 | ## Notes on using SCP 150 | 151 | The processing of each file chunk is performed by a dedicated thread 152 | Each cpu core is allowed up to 5 threads, with a system max of 32 threads used 153 | 154 | Using SCP method will generate the following processes on the remote host: 155 | - for mgmt session: 1x sshd, 1x cli, 1x mgd, 1x csh 156 | - for transfers: depends on Python version, number of cpus (see above), OpenSSH and Junos FreeBSD version (see below) 157 | 158 | In FreeBSD 11 based Junos each scp transfer creates 2 user owned processes: 159 | ``` 160 | lab 30366 0.0 0.0 475056 7688 - Ss 10:39 0:00.03 cli -c scp -t /var/tmp/ 161 | lab 30367 0.0 0.0 61324 4860 - S 10:39 0:00.01 scp -t /var/tmp/ 162 | ``` 163 | In FreeBSD 10 based Junos each scp transfer creates 2 user owned processes 164 | ``` 165 | lab 28639 0.0 0.0 734108 4004 - Is 12:00PM 0:00.01 cli -c scp -t /var/tmp/splitcopy_jinstall-11.4R5.5-domestic-signed.tgz/ 166 | lab 28640 0.0 0.0 24768 3516 - S 12:00PM 0:00.01 scp -t /var/tmp/splitcopy_jinstall-11.4R5.5-domestic-signed.tgz/ 167 | ``` 168 | In FreeBSD 6 based Junos each scp transfer creates 3 user owned processes: 169 | ``` 170 | lab 78625 0.0 0.1 2984 2144 ?? Ss 5:29AM 0:00.01 cli -c scp -t /var/tmp/splitcopy_jinstall-11.4R5.5-domestic-signed.tgz/ 171 | lab 78626 0.0 0.0 2252 1556 ?? S 5:29AM 0:00.00 sh -c scp -t /var/tmp/splitcopy_jinstall-11.4R5.5-domestic-signed.tgz/ 172 | lab 78627 0.0 0.1 3500 1908 ?? S 5:29AM 0:00.01 scp -t /var/tmp/splitcopy_jinstall-11.4R5.5-domestic-signed.tgz/ 173 | ``` 174 | In addition, if OpenSSH version is >= 7.4, an additional user owned process is created: 175 | ``` 176 | lab 2287 2.4 0.1 11912 2348 ?? S 3:49AM 0:00.15 sshd: lab@notty (sshd) 177 | ``` 178 | This could result in the per-user maxproc limit of 64 being exceeded: 179 | ``` 180 | May 2 04:46:59 /kernel: maxproc limit exceeded by uid 2001, please see tuning(7) and login.conf(5). 181 | ``` 182 | To mitigate this, the script modulates the number of chunks to match the maximum number of simultaneous transfers possible (based on OpenSSH, Junos FreeBSD versions and the number of cpu's). 183 | The maximum number of user owned processes that could be created is <= 45 184 | 185 | 186 | 187 | ## LICENSE 188 | 189 | Apache 2.0 190 | 191 | ## CONTRIBUTORS 192 | 193 | Juniper Networks is actively contributing to and maintaining this repo. Please contact jnpr-community-netdev@juniper.net for any queries. 194 | 195 | *Contributors:* 196 | 197 | [Chris Jenn](https://github.com/ipmonk) 198 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61"] 3 | build-backend = "setuptools.build_meta" 4 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | pytest-mock 3 | pytest-cov 4 | black 5 | codecov 6 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = splitcopy 3 | version = 1.7.1 4 | author = Chris Jenn 5 | author_email = jnpr-community-netdev@juniper.net 6 | license = Apache 2.0 7 | license_files = 8 | LICENSE 9 | description = Improves file transfer rates when copying files to/from JUNOS/EVO/*nix hosts 10 | long_description = file: README.md 11 | long_description_content_type = text/markdown 12 | keywords = 13 | ftp 14 | ssh 15 | scp 16 | transfer 17 | url = https://github.com/Juniper/splitcopy 18 | project_urls = 19 | Bug Tracker = https://github.com/Juniper/splitcopy/issues 20 | classifiers = 21 | Development Status :: 5 - Production/Stable 22 | License :: OSI Approved :: Apache Software License 23 | Environment :: Console 24 | Operating System :: OS Independent 25 | Programming Language :: Python :: 3 26 | Topic :: System :: Networking 27 | 28 | [options] 29 | package_dir= 30 | =src 31 | packages = find: 32 | python_requires = >=3.6 33 | install_requires = 34 | paramiko 35 | scp 36 | windows-curses; sys_platform == "win32" 37 | 38 | [options.packages.find] 39 | where = src 40 | 41 | [options.entry_points] 42 | console_scripts = 43 | splitcopy = splitcopy.splitcopy:main 44 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup() 4 | -------------------------------------------------------------------------------- /src/splitcopy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Juniper/splitcopy/bb576cda0d4945809d6aafecd704658f3f045c0f/src/splitcopy/__init__.py -------------------------------------------------------------------------------- /src/splitcopy/ftp.py: -------------------------------------------------------------------------------- 1 | """ Copyright (c) 2018, Juniper Networks, Inc 2 | All rights reserved 3 | This SOFTWARE is licensed under the LICENSE provided in the 4 | ./LICENCE file. By downloading, installing, copying, or otherwise 5 | using the SOFTWARE, you agree to be bound by the terms of that 6 | LICENSE. 7 | """ 8 | 9 | # stdlib 10 | import ftplib 11 | import logging 12 | import os 13 | import sys 14 | 15 | 16 | class FTP(ftplib.FTP): 17 | """FTP utility used to transfer files to and from hosts 18 | mostly ripped from py-junos-eznc (with permission) 19 | """ 20 | 21 | def __init__(self, file_size=None, progress=None, **kwargs): 22 | """initialize the FTP class 23 | :param file_size: size of file to transfer 24 | :type int: 25 | :param progress: Progress (from progress.py) 26 | :type object: 27 | :param kwargs: named arguments 28 | :type dict: 29 | """ 30 | host = kwargs.get("host") 31 | user = kwargs.get("user") 32 | passwd = kwargs.get("passwd") 33 | timeout = kwargs.get("timeout") 34 | if not timeout: 35 | timeout = 30 36 | logger = logging.getLogger(__name__) 37 | if logger.getEffectiveLevel() == 10: 38 | self.set_debuglevel(level=1) 39 | ftplib.FTP.__init__(self, host=host, user=user, passwd=passwd, timeout=timeout) 40 | self.file_size = file_size 41 | self.progress = progress 42 | self.header_bytes = 33 43 | self.sent = 0 44 | 45 | def __enter__(self): 46 | return self 47 | 48 | def __exit__(self, exc_ty, exc_val, exc_tb): 49 | self.quit() 50 | 51 | def put(self, local_file, remote_file, restart_marker): 52 | """copies file from local host to remote host 53 | :param local_file: path to local file 54 | :type string: 55 | :param remote_file: full path on server 56 | :type string: 57 | :return None: 58 | """ 59 | with open(local_file, "rb") as open_local_file: 60 | if restart_marker is not None: 61 | self.sent = restart_marker 62 | open_local_file.seek(restart_marker, 0) 63 | 64 | def callback(data): 65 | size_data = sys.getsizeof(data) - self.header_bytes 66 | self.sent += size_data 67 | self.progress.report_progress( 68 | file_name=os.path.basename(local_file), 69 | file_size=self.file_size, 70 | sent=self.sent, 71 | ) 72 | 73 | self.storbinary( 74 | cmd="STOR " + remote_file, 75 | fp=open_local_file, 76 | callback=callback, 77 | rest=restart_marker, 78 | ) 79 | 80 | def get(self, remote_file, local_file, restart_marker=None): 81 | """copies file from remote host to local host 82 | :param remote_file: full path on server 83 | :type string: 84 | :param local_file: path to local file 85 | :type string: 86 | :return None: 87 | """ 88 | if restart_marker is not None: 89 | self.sent = restart_marker 90 | with open(local_file, "ab") as open_local_file: 91 | 92 | def callback(data): 93 | open_local_file.write(data) 94 | size_data = sys.getsizeof(data) - self.header_bytes 95 | self.sent += size_data 96 | self.progress.report_progress( 97 | file_name=os.path.basename(local_file), 98 | file_size=self.file_size, 99 | sent=self.sent, 100 | ) 101 | 102 | self.retrbinary("RETR " + remote_file, callback, rest=restart_marker) 103 | -------------------------------------------------------------------------------- /src/splitcopy/get.py: -------------------------------------------------------------------------------- 1 | """ Copyright (c) 2018, Juniper Networks, Inc 2 | All rights reserved 3 | This SOFTWARE is licensed under the LICENSE provided in the 4 | ./LICENCE file. By downloading, installing, copying, or otherwise 5 | using the SOFTWARE, you agree to be bound by the terms of that 6 | LICENSE. 7 | """ 8 | 9 | # stdlib 10 | import asyncio 11 | import datetime 12 | import fnmatch 13 | import functools 14 | import glob 15 | import hashlib 16 | import logging 17 | import os 18 | import re 19 | import signal 20 | import sys 21 | import tempfile 22 | import time 23 | import traceback 24 | import warnings 25 | from math import ceil 26 | 27 | # silence this warning 28 | from cryptography.utils import CryptographyDeprecationWarning 29 | 30 | warnings.simplefilter("ignore", CryptographyDeprecationWarning) 31 | 32 | # 3rd party 33 | from paramiko.ssh_exception import SSHException 34 | from scp import SCPClient 35 | 36 | # local modules 37 | from splitcopy.ftp import FTP 38 | from splitcopy.paramikoshell import SSHShell 39 | from splitcopy.progress import Progress 40 | from splitcopy.shared import SplitCopyShared 41 | 42 | logger = logging.getLogger(__name__) 43 | 44 | # use st_blksize 45 | _BUF_SIZE_READ = 1024 * 8 46 | _BUF_SIZE = 1024 47 | 48 | 49 | class SplitCopyGet: 50 | def __init__(self, **kwargs): 51 | """ 52 | Initialise the SplitCopyGet class 53 | """ 54 | self.user = kwargs.get("user") 55 | self.host = kwargs.get("host") 56 | self.passwd = kwargs.get("passwd") 57 | self.ssh_key = kwargs.get("ssh_key") 58 | self.ssh_port = kwargs.get("ssh_port") 59 | self.remote_path = kwargs.get("remote_path") 60 | self.copy_proto = kwargs.get("copy_proto") 61 | self.target = kwargs.get("target") 62 | self.noverify = kwargs.get("noverify") 63 | self.split_timeout = kwargs.get("split_timeout") 64 | self.use_curses = kwargs.get("use_curses") 65 | self.overwrite = kwargs.get("overwrite") 66 | self.sshshell = None 67 | self.scs = SplitCopyShared(**kwargs) 68 | self.mute = False 69 | self.progress = Progress() 70 | self.use_shell = False 71 | 72 | def handlesigint(self, sigint, stack): 73 | """function called when SigInt is received 74 | :param sigint: 75 | :type int: 76 | :param stack: 77 | :type frame: 78 | """ 79 | logger.debug(f"signal {sigint} received, stack:\n{stack}") 80 | self.mute = True 81 | self.progress.stop_progress() 82 | self.scs.close() 83 | 84 | def get(self): 85 | """copies file from remote host to local host 86 | performs file split/transfer/join/verify functions 87 | :returns loop_start: time when transfers started 88 | :type: datetime object 89 | :returns loop_end: time when transfers ended 90 | :type: datetime object 91 | """ 92 | ssh_kwargs = { 93 | "username": self.user, 94 | "hostname": self.host, 95 | "password": self.passwd, 96 | "key_filename": self.ssh_key, 97 | "ssh_port": self.ssh_port, 98 | } 99 | 100 | # handle sigint gracefully on *nix 101 | signal.signal(signal.SIGINT, self.handlesigint) 102 | 103 | # expand local dir path 104 | local_dir = self.expand_local_dir(self.target) 105 | 106 | # verify local dir is writeable 107 | self.verify_local_dir_perms(local_dir) 108 | 109 | # determine local filename 110 | local_file = self.determine_local_filename(self.target, self.remote_path) 111 | 112 | # define absolute local path 113 | local_path = f"{local_dir}{os.path.sep}{local_file}" 114 | 115 | # check if file already exists. delete it? 116 | self.delete_target_local(local_path) 117 | 118 | # connect to host 119 | self.sshshell, ssh_kwargs = self.scs.connect(SSHShell, **ssh_kwargs) 120 | 121 | # is this a juniper cli? 122 | if self.scs.juniper_cli_check(): 123 | self.use_shell = True 124 | self.scs.enter_shell() 125 | 126 | # ensure source path is valid 127 | ( 128 | remote_file, 129 | remote_dir, 130 | remote_path, 131 | filesize_path, 132 | ) = self.validate_remote_path_get(self.remote_path) 133 | 134 | # determine remote file size 135 | file_size = self.remote_filesize(filesize_path) 136 | 137 | # determine the OS 138 | junos, evo, bsd_version, sshd_version = self.scs.which_os() 139 | 140 | # verify which protocol to use 141 | self.copy_proto, self.passwd = self.scs.which_proto(self.copy_proto) 142 | 143 | # check required binaries exist on remote host 144 | self.scs.req_binaries(junos=junos, evo=evo) 145 | 146 | # cleanup previous remote tmp directory if found 147 | self.scs.remote_cleanup( 148 | remote_dir=remote_dir, remote_file=remote_file, silent=True 149 | ) 150 | 151 | # determine optimal size for chunks 152 | split_size, executor = self.scs.file_split_size( 153 | file_size, sshd_version, bsd_version, evo, self.copy_proto 154 | ) 155 | 156 | # confirm remote storage is sufficient 157 | self.scs.storage_check_remote(file_size, split_size, "/var/tmp") 158 | 159 | # confirm local storage is sufficient 160 | self.scs.storage_check_local(file_size) 161 | 162 | if not self.noverify: 163 | # get/create sha hash for remote file 164 | sha_hash = self.remote_sha_get(remote_path) 165 | 166 | # create tmp directory on remote host 167 | remote_tmpdir = self.scs.mkdir_remote("/var/tmp", remote_file) 168 | 169 | # split file into chunks 170 | self.split_file_remote( 171 | SCPClient, file_size, split_size, remote_tmpdir, remote_path, remote_file 172 | ) 173 | 174 | # add chunk names to a list, pass this info to Progress 175 | chunks = self.get_chunk_info(remote_tmpdir, remote_file) 176 | self.progress.add_chunks(file_size, chunks) 177 | 178 | # begin connection/rate limit check and transfer process 179 | command_list = [] 180 | if junos or evo: 181 | command_list = self.scs.limit_check(self.copy_proto) 182 | print("starting transfer...") 183 | self.progress.start_progress(self.use_curses) 184 | with self.scs.tempdir(): 185 | # copy files from remote host 186 | self.scs.hard_close = True 187 | loop_start = datetime.datetime.now() 188 | loop = asyncio.new_event_loop() 189 | tasks = [] 190 | for chunk in chunks: 191 | task = loop.run_in_executor( 192 | executor, 193 | functools.partial( 194 | self.get_files, 195 | FTP, 196 | SSHShell, 197 | SCPClient, 198 | chunk, 199 | remote_tmpdir, 200 | ssh_kwargs, 201 | ), 202 | ) 203 | tasks.append(task) 204 | try: 205 | loop.run_until_complete(asyncio.gather(*tasks)) 206 | except TransferError: 207 | self.progress.stop_progress() 208 | self.scs.close( 209 | err_str="an error occurred while copying the files from the remote host", 210 | ) 211 | finally: 212 | loop.close() 213 | 214 | self.scs.hard_close = False 215 | while self.progress.totals["percent_done"] != 100: 216 | time.sleep(0.1) 217 | self.progress.stop_progress() 218 | 219 | print("\ntransfer complete") 220 | loop_end = datetime.datetime.now() 221 | 222 | # combine chunks 223 | self.join_files_local(local_path, remote_file) 224 | 225 | # remove remote tmp dir 226 | self.scs.remote_cleanup() 227 | 228 | # rollback any config changes made 229 | if command_list: 230 | self.scs.limits_rollback() 231 | 232 | # check local file size is correct 233 | self.compare_file_sizes(file_size, remote_dir, remote_file, local_path) 234 | 235 | if self.noverify: 236 | print(f"file has been successfully copied to {local_path}") 237 | else: 238 | # generate a sha hash for the combined file, compare to hash of src 239 | self.local_sha_get(sha_hash, local_path) 240 | 241 | self.sshshell.close() 242 | return loop_start, loop_end 243 | 244 | def get_chunk_info(self, remote_tmpdir, remote_file): 245 | """obtains the remote chunk file size and names 246 | :param remote_tmpdir: 247 | :type string: 248 | :param remote_file: 249 | :type string: 250 | :return chunks: 251 | :type list: 252 | """ 253 | logger.info("entering get_chunk_info()") 254 | result, stdout = self.scs.ssh_cmd(f"ls -l {remote_tmpdir}/") 255 | if not result: 256 | self.scs.close( 257 | err_str="couldn't get list of files from host", 258 | ) 259 | lines = stdout.splitlines() 260 | chunks = [] 261 | for line in lines: 262 | if fnmatch.fnmatch(line, f"* {remote_file}*"): 263 | chunk = line.split() 264 | chunks.append([chunk[-1], int(chunk[-5])]) 265 | if not chunks: 266 | self.scs.close( 267 | err_str="failed to retreive chunk names and sizes", 268 | ) 269 | logger.debug(chunks) 270 | return chunks 271 | 272 | def validate_remote_path_get(self, remote_path): 273 | """path must be a full path, expand as required 274 | :return remote_file: 275 | :type string: 276 | :return remote_dir: 277 | :type string: 278 | :return remote_path: 279 | :type string: 280 | :return filesize_path: 281 | :type string: 282 | """ 283 | logger.info("entering validate_remote_path_get()") 284 | try: 285 | self.verify_path_exists(remote_path) 286 | self.verify_path_is_not_directory(remote_path) 287 | self.verify_path_is_readable(remote_path) 288 | remote_path = self.expand_remote_path(remote_path) 289 | remote_path = self.path_startswith_tilda(remote_path) 290 | filesize_path = self.check_if_symlink(remote_path) 291 | remote_dir = os.path.dirname(remote_path) 292 | remote_file = os.path.basename(remote_path) 293 | except ValueError as err: 294 | self.scs.close( 295 | err_str=err, 296 | ) 297 | # update SplitCopyShared with these values 298 | self.scs.remote_dir = remote_dir 299 | self.scs.remote_file = remote_file 300 | return remote_file, remote_dir, remote_path, filesize_path 301 | 302 | def expand_local_dir(self, target): 303 | """determines the local dir based on the target arg 304 | :param target: 305 | :type string: 306 | :return local_dir: 307 | :type string: 308 | """ 309 | logger.info("entering expand_local_dir()") 310 | local_dir = None 311 | target = os.path.abspath(os.path.expanduser(target)) 312 | if os.path.isdir(target): 313 | # target is a / 314 | local_dir = target 315 | elif os.path.isdir(os.path.dirname(target)): 316 | # target is a / 317 | local_dir = os.path.dirname(target) 318 | else: 319 | # target is only 320 | local_dir = os.getcwd() 321 | # update SplitCopyShared 322 | self.scs.local_dir = local_dir 323 | logger.debug(f"local dir = {local_dir}") 324 | return local_dir 325 | 326 | def verify_local_dir_perms(self, local_dir): 327 | """ensure local directory is writeable 328 | :param local_dir: 329 | :type string: 330 | :return None: 331 | :raises SystemExit: if directory is not writeable 332 | """ 333 | logger.debug("entering verify_local_dir_perms()") 334 | try: 335 | with tempfile.TemporaryFile(dir=local_dir) as foo: 336 | pass 337 | except PermissionError: 338 | raise SystemExit( 339 | f"Unable to write file to {local_dir} due to directory permissions" 340 | ) 341 | 342 | def determine_local_filename(self, target, remote_path): 343 | """determines the local file based on the target arg 344 | :param remote_file: 345 | :type string: 346 | :return local_file: 347 | :type string: 348 | """ 349 | logger.info("entering determine_local_filename()") 350 | local_file = None 351 | remote_file = os.path.basename(remote_path) 352 | if os.path.isdir(target): 353 | # target is a '/' 354 | local_file = remote_file 355 | else: 356 | # target is a /, or only 357 | if os.path.basename(target) != remote_file: 358 | # have to honour the change of name 359 | local_file = os.path.basename(target) 360 | else: 361 | local_file = remote_file 362 | logger.debug(f"local file = {local_file}") 363 | return local_file 364 | 365 | def expand_remote_path(self, remote_path): 366 | """if only a filename is provided, expands the remote 367 | path to its absolute path 368 | :param remote_path: 369 | :type string: 370 | :return remote_path: 371 | :type string: 372 | :raises ValueError: if remote cmd fails 373 | """ 374 | logger.info("entering expand_remote_path()") 375 | if not re.search(r"\/", remote_path) or re.match(r"\.\/", remote_path): 376 | result, stdout = self.scs.ssh_cmd("pwd") 377 | if result: 378 | if self.use_shell: 379 | pwd = stdout.split("\n")[1].rstrip() 380 | else: 381 | pwd = stdout 382 | remote_path = re.sub(r"^\.\/", "", remote_path) 383 | remote_path = f"{pwd}/{remote_path}" 384 | logger.debug(f"remote_path now = {remote_path}") 385 | else: 386 | raise ValueError( 387 | "Cannot determine the current working directory on the remote host" 388 | ) 389 | return remote_path 390 | 391 | def path_startswith_tilda(self, remote_path): 392 | """expands ~ based path to absolute path 393 | :return None: 394 | :raises ValueError: if remote cmd fails 395 | """ 396 | logger.info("entering path_startswith_tilda()") 397 | if re.match(r"~", remote_path): 398 | result, stdout = self.scs.ssh_cmd(f"ls -d {remote_path}") 399 | if result: 400 | if self.use_shell: 401 | remote_path = stdout.split("\n")[1].rstrip() 402 | else: 403 | remote_path = stdout 404 | logger.debug(f"remote_path now = {remote_path}r") 405 | else: 406 | raise ValueError(f"unable to expand remote path {remote_path}") 407 | return remote_path 408 | 409 | def verify_path_is_not_directory(self, remote_path): 410 | """verifies remote path is not a directory 411 | :param remote_path: 412 | :type string: 413 | :return result: 414 | :raises ValueError: if path is a directory 415 | """ 416 | logger.info("entering verify_path_is_not_directory()") 417 | result, stdout = self.scs.ssh_cmd(f"test -d {remote_path}") 418 | if result: 419 | raise ValueError(f"src path '{remote_path}' is a directory, not a file") 420 | return result 421 | 422 | def verify_path_exists(self, remote_path): 423 | """verifies remote path exists 424 | :param remote_path: 425 | :type string: 426 | :return None: 427 | :raises ValueError: if test fails 428 | """ 429 | logger.info("entering verify_path_exists()") 430 | result, stdout = self.scs.ssh_cmd(f"test -e {remote_path}") 431 | if not result: 432 | raise ValueError(f"'{remote_path}' on remote host doesn't exist") 433 | return result 434 | 435 | def verify_path_is_readable(self, remote_path): 436 | """verifies the remote path is readable 437 | :param remote_path: 438 | :type string: 439 | :return None 440 | :raises ValueError: if test fails 441 | """ 442 | logger.info("entering verify_file_is_readable()") 443 | result, stdout = self.scs.ssh_cmd(f"test -r {remote_path}") 444 | if not result: 445 | raise ValueError(f"'{remote_path}' on remote host is not readable") 446 | return result 447 | 448 | def check_if_symlink(self, remote_path): 449 | """if remote_path is a symlink, determine the link dst path 450 | this is required to correctly determine the files size in remote_filesize() 451 | :param remote_path: 452 | :type string: 453 | :return filesize_path: 454 | :type string: 455 | :raises ValueError: if test fails 456 | """ 457 | logger.info("entering check_if_symlink()") 458 | filesize_path = remote_path 459 | result, stdout = self.scs.ssh_cmd(f"test -L {remote_path}") 460 | if result: 461 | logger.info("file is a symlink") 462 | cmd = f"ls -l {remote_path}" 463 | result, stdout = self.scs.ssh_cmd(cmd) 464 | if result: 465 | if self.use_shell: 466 | linked_path = stdout.split()[-2].rstrip() 467 | else: 468 | linked_path = stdout.split()[-1] 469 | linked_dir = os.path.dirname(linked_path) 470 | linked_file = os.path.basename(linked_path) 471 | else: 472 | raise ValueError(f"file on remote host is a symlink, cmd {cmd} failed") 473 | if not linked_dir: 474 | # symlink is in the same directory as source file use remote_dir 475 | remote_dir = os.path.dirname(remote_path) 476 | filesize_path = f"{remote_dir}/{linked_file}" 477 | else: 478 | filesize_path = f"{linked_dir}/{linked_file}" 479 | logger.debug(f"filesize_path updated from {remote_path} to {filesize_path}") 480 | return filesize_path 481 | 482 | def check_target_exists(self, path): 483 | """checks if the target file already exists 484 | :param path: path to file 485 | :type string: 486 | :return result: 487 | :type bool: 488 | """ 489 | logger.info("entering check_target_exists()") 490 | result = False 491 | if os.path.exists(path) and os.path.isfile(path) or os.path.islink(path): 492 | result = True 493 | return result 494 | 495 | def delete_target_local(self, file_path): 496 | """verifies whether path is a file. 497 | if true and --overwrite flag is specified attempt to delete it 498 | else alert the user and exit 499 | :param file_path: path to file 500 | :type string: 501 | :return None: 502 | :raises SystemExit: if file cannot be deleted 503 | """ 504 | logger.info("entering delete_target_local()") 505 | err = "" 506 | if self.check_target_exists(file_path): 507 | if self.overwrite: 508 | try: 509 | os.remove(file_path) 510 | except PermissionError: 511 | err = ( 512 | f"target file '{file_path}' already exists, cannot " 513 | "be deleted due to a permissions error" 514 | ) 515 | else: 516 | err = ( 517 | f"target file '{file_path}' already exists, " 518 | "use --overwrite arg or delete it manually" 519 | ) 520 | 521 | if err: 522 | if self.sshshell is not None: 523 | self.scs.close( 524 | err_str=err, 525 | ) 526 | else: 527 | raise SystemExit(err) 528 | 529 | def remote_filesize(self, filesize_path): 530 | """determines the remote file size in bytes 531 | :param filesize_path: 532 | :type string: 533 | :return file_size: 534 | :type int: 535 | """ 536 | logger.info("entering remote_filesize()") 537 | result, stdout = self.scs.ssh_cmd(f"ls -l {filesize_path}") 538 | if result: 539 | if self.use_shell: 540 | file_size = int(stdout.split("\n")[1].split()[4]) 541 | else: 542 | file_size = int(stdout.split()[4]) 543 | else: 544 | err = "cannot determine remote file size" 545 | self.scs.close( 546 | err_str=err, 547 | ) 548 | logger.info(f"src file size is {file_size}") 549 | if not file_size: 550 | err = "remote file size is 0 bytes, nothing to copy" 551 | self.scs.close(err_str=err) 552 | 553 | return file_size 554 | 555 | def remote_sha_get(self, remote_path): 556 | """checks for existence of a sha hash file 557 | if none found, generates a sha hash for the remote file to be copied 558 | :param remote_path: 559 | :type string: 560 | :return sha_bin: 561 | :type string: 562 | :return sha_len: 563 | :type int: 564 | :return sha_hash: 565 | """ 566 | logger.info("entering remote_sha_get()") 567 | sha_hash = {} 568 | result, stdout = self.find_existing_sha_files(remote_path) 569 | if result: 570 | sha_hash = self.process_existing_sha_files(stdout) 571 | if not sha_hash: 572 | sha_hash[1] = True 573 | sha_bin, sha_len = self.scs.req_sha_binaries(sha_hash) 574 | print("generating remote sha hash...") 575 | if sha_bin == "shasum": 576 | result, stdout = self.scs.ssh_cmd( 577 | f"{sha_bin} -a {sha_len} {remote_path}", timeout=120 578 | ) 579 | else: 580 | result, stdout = self.scs.ssh_cmd( 581 | f"{sha_bin} {remote_path}", timeout=120 582 | ) 583 | if not result: 584 | self.scs.close( 585 | err_str="failed to generate remote sha1", 586 | ) 587 | for line in stdout.splitlines(): 588 | try: 589 | sha_hash[1] = re.search(r"([0-9a-f]{40})", line).group(1) 590 | break 591 | except AttributeError: 592 | pass 593 | if not isinstance(sha_hash[1], str): 594 | self.scs.close( 595 | err_str="failed to obtain remote sha1", 596 | ) 597 | 598 | logger.info(f"remote sha hashes = {sha_hash}") 599 | return sha_hash 600 | 601 | def find_existing_sha_files(self, remote_path): 602 | """checks for presence of existing sha* files 603 | :param remote_path: 604 | :type string: 605 | :return result: 606 | :type bool: 607 | :return stdout: 608 | :type string: 609 | """ 610 | logger.info("entering find_existing_sha_files()") 611 | result, stdout = self.scs.ssh_cmd(f"ls -1 {remote_path}.sha*") 612 | return result, stdout 613 | 614 | def process_existing_sha_files(self, output): 615 | """reads existing sha files, puts the hash and sha length info a dict() 616 | :param output: 617 | :type string: 618 | :returns sha_hash: 619 | :type dict: 620 | """ 621 | logger.info("entering process_existing_sha_files()") 622 | sha_hash = {} 623 | for line in output.splitlines(): 624 | line = line.rstrip() 625 | match = re.search(r"\.sha([0-9]+)$", line) 626 | try: 627 | sha_num = int(match.group(1)) 628 | except AttributeError: 629 | continue 630 | logger.info(f"{line} file found") 631 | result, stdout = self.scs.ssh_cmd(f"cat {line}") 632 | if result: 633 | if self.use_shell: 634 | sha_hash[sha_num] = stdout.split("\n")[1].split()[0].rstrip() 635 | else: 636 | sha_hash[sha_num] = stdout.split()[0] 637 | logger.info(f"sha_hash[{sha_num}] added") 638 | else: 639 | logger.info(f"unable to read remote sha file {line}") 640 | return sha_hash 641 | 642 | def split_file_remote( 643 | self, scp_lib, file_size, split_size, remote_tmpdir, remote_path, remote_file 644 | ): 645 | """writes a script into a file, copies it to the remote host then executes it. 646 | the source file is split into multiple smaller chunks ready to be copied 647 | :param scp_lib: 648 | :type class: 649 | :param file_size: 650 | :type int: 651 | :param split_size: 652 | :type int: 653 | :param remote_tmpdir: 654 | :type string: 655 | :param remote_path: 656 | :type string: 657 | :param remote_file: 658 | :type string: 659 | :return None: 660 | """ 661 | logger.info("entering split_file_remote()") 662 | result = False 663 | total_blocks = ceil(file_size / _BUF_SIZE) 664 | block_size = ceil(split_size / _BUF_SIZE) 665 | logger.info(f"total_blocks = {total_blocks}, block_size = {block_size}") 666 | cmd = ( 667 | f"size_b={block_size}; size_tb={total_blocks}; i=0; o=00; " 668 | "while [ $i -lt $size_tb ]; do " 669 | f"dd if={remote_path} of={remote_tmpdir}/{remote_file}_$o " 670 | f"bs={_BUF_SIZE} count=$size_b skip=$i 2>&1; " 671 | "result=`echo $?`; if [ $result -gt 0 ]; then exit 1; fi;" 672 | "i=`expr $i + $size_b`; o=`expr $o + 1`; " 673 | "if [ $o -lt 10 ]; then o=0$o; fi; done" 674 | ) 675 | 676 | # switched to file copy as the '> ' in 'echo cmd > file' 677 | # would sometimes be interpreted as shell prompt 678 | with self.scs.tempdir(): 679 | with open("split.sh", "w") as fd: 680 | fd.write(cmd) 681 | transport = self.sshshell._transport 682 | with scp_lib(transport) as scpclient: 683 | scpclient.put("split.sh", f"{remote_tmpdir}/split.sh") 684 | print("splitting remote file...") 685 | result, stdout = self.scs.ssh_cmd( 686 | f"sh {remote_tmpdir}/split.sh", 687 | timeout=self.split_timeout, 688 | ) 689 | if not result: 690 | err = f"failed to split file on remote host, due to error:\n{stdout}" 691 | self.scs.close(err_str=err) 692 | 693 | def get_files(self, ftp_lib, ssh_lib, scp_lib, chunk, remote_tmpdir, ssh_kwargs): 694 | """copies files from remote host via ftp or scp 695 | :param ftp_lib: 696 | :type class: 697 | :param ssh_lib: 698 | :type class: 699 | :param scp_lib: 700 | :type class: 701 | :param chunk: name and size of the file to copy 702 | :type: list 703 | :param remote_tmpdir: path of the tmp directory on remote host 704 | :type: str 705 | :param ssh_kwargs: keyword arguments 706 | :type dict: 707 | :raises TransferError: if file transfer fails 3 times 708 | :returns None: 709 | """ 710 | logger.info("entering get_files()") 711 | err_count = 0 712 | file_name = chunk[0] 713 | file_size = chunk[1] 714 | srcpath = f"{remote_tmpdir}/{file_name}" 715 | logger.info(f"{file_name}, size {file_size}") 716 | while err_count < 3: 717 | try: 718 | if self.copy_proto == "ftp": 719 | with ftp_lib( 720 | file_size=file_size, 721 | progress=self.progress, 722 | host=self.host, 723 | user=self.user, 724 | passwd=self.passwd, 725 | ) as ftp: 726 | restart_marker = None 727 | if err_count: 728 | try: 729 | restart_marker = os.stat(file_name).st_size 730 | except FileNotFoundError: 731 | pass 732 | self.progress.zero_file_stats(file_name) 733 | if restart_marker is not None: 734 | self.progress.print_error( 735 | f"resuming {file_name} from byte {restart_marker}" 736 | ) 737 | ftp.get(srcpath, file_name, restart_marker) 738 | break 739 | else: 740 | with ssh_lib(**ssh_kwargs) as ssh: 741 | ssh.socket_open() 742 | ssh.transport_open() 743 | if not ssh.worker_thread_auth(): 744 | ssh.close() 745 | raise SSHException("authentication failed") 746 | with scp_lib( 747 | ssh._transport, progress=self.progress.report_progress 748 | ) as scpclient: 749 | if err_count: 750 | self.progress.zero_file_stats(file_name) 751 | scpclient.get(srcpath, file_name) 752 | # hack. at times, a FIN wasn't being sent resulting in sshd (notty) 753 | # processes being left in ESTABLISHED state on server. 754 | # adding sleep here appears to prevent this 755 | time.sleep(1) 756 | break 757 | except Exception as err: 758 | err_count += 1 759 | logger.debug("".join(traceback.format_exception(*sys.exc_info()))) 760 | if not self.mute: 761 | if err_count < 3: 762 | self.progress.print_error( 763 | f"chunk {file_name} transfer failed due to " 764 | f"{err.__class__.__name__} {str(err)}, retrying" 765 | ) 766 | else: 767 | self.progress.print_error( 768 | f"chunk {file_name} transfer failed due to " 769 | f"{err.__class__.__name__} {str(err)}" 770 | ) 771 | time.sleep(err_count) 772 | 773 | if err_count == 3: 774 | self.mute = True 775 | raise TransferError 776 | 777 | def join_files_local(self, local_path, remote_file): 778 | """concatenates the file chunks into one file on local host 779 | :param local_path: 780 | :type string: 781 | :param remote_file: 782 | :type string: 783 | :returns None: 784 | """ 785 | logger.info("entering join_files_local()") 786 | print("joining chunks...") 787 | local_tmpdir = self.scs.return_tmpdir() 788 | src_files = glob.glob(f"{local_tmpdir}{os.path.sep}{remote_file}*") 789 | with open(local_path, "wb") as dst: 790 | for src in sorted(src_files): 791 | with open(src, "rb") as chunk: 792 | data = chunk.read(_BUF_SIZE_READ) 793 | while data: 794 | dst.write(data) 795 | data = chunk.read(_BUF_SIZE_READ) 796 | if not os.path.isfile(local_path): 797 | err = f"recombined file {local_path} isn't found, exiting" 798 | self.scs.close( 799 | err_str=err, 800 | ) 801 | 802 | def compare_file_sizes(self, file_size, remote_dir, remote_file, local_path): 803 | """obtains the newly combined file size, compares it to the source files size 804 | :param file_size: 805 | :type int: 806 | :param remote_dir: 807 | :type string: 808 | :param remote_file: 809 | :type string: 810 | :param local_path: 811 | :type string: 812 | :return None: 813 | """ 814 | logger.info("entering compare_file_sizes()") 815 | combined_file_size = os.path.getsize(local_path) 816 | if combined_file_size != file_size: 817 | self.scs.close( 818 | err_str=( 819 | f"combined file size is {combined_file_size}, file " 820 | f"{self.host}:{remote_dir}/{remote_file} size " 821 | f"is {file_size}. Unexpected mismatch in file size. Please retry" 822 | ), 823 | config_rollback=False, 824 | ) 825 | print("local and remote file sizes match") 826 | 827 | def local_sha_get(self, sha_hash, local_path): 828 | """generates a sha hash for the combined file on the local host 829 | :param sha_hash: 830 | :type dict: 831 | :param local_path: 832 | :type string: 833 | :returns None: 834 | """ 835 | logger.info("entering local_sha_get()") 836 | print("generating local sha hash...") 837 | if sha_hash.get(512): 838 | sha_idx = 512 839 | sha = hashlib.sha512() 840 | elif sha_hash.get(384): 841 | sha_idx = 384 842 | sha = hashlib.sha384() 843 | elif sha_hash.get(256): 844 | sha_idx = 256 845 | sha = hashlib.sha256() 846 | elif sha_hash.get(224): 847 | sha_idx = 224 848 | sha = hashlib.sha224() 849 | else: 850 | sha_idx = 1 851 | sha = hashlib.sha1() 852 | with open(local_path, "rb") as dst: 853 | data = dst.read(_BUF_SIZE_READ) 854 | while data: 855 | sha.update(data) 856 | data = dst.read(_BUF_SIZE_READ) 857 | local_sha = sha.hexdigest() 858 | logger.info(f"local sha = {local_sha}") 859 | if local_sha == sha_hash.get(sha_idx): 860 | print( 861 | "local and remote sha hash match\nfile has been " 862 | f"successfully copied to {local_path}" 863 | ) 864 | else: 865 | err = ( 866 | f"file has been copied to {local_path}, " 867 | "but the local and remote sha hash do not match - please retry" 868 | ) 869 | self.scs.close( 870 | err_str=err, 871 | config_rollback=False, 872 | ) 873 | 874 | 875 | class TransferError(Exception): 876 | """ 877 | custom exception to indicate problem with file transfer 878 | """ 879 | 880 | pass 881 | -------------------------------------------------------------------------------- /src/splitcopy/paramikoshell.py: -------------------------------------------------------------------------------- 1 | """ Copyright (c) 2018, Juniper Networks, Inc 2 | All rights reserved 3 | This SOFTWARE is licensed under the LICENSE provided in the 4 | ./LICENCE file. By downloading, installing, copying, or otherwise 5 | using the SOFTWARE, you agree to be bound by the terms of that 6 | LICENSE. 7 | """ 8 | 9 | # stdlib 10 | import datetime 11 | import getpass 12 | import logging 13 | import os 14 | import re 15 | import select 16 | import socket 17 | import sys 18 | import traceback 19 | import warnings 20 | 21 | # 3rd party exceptions 22 | from cryptography.utils import CryptographyDeprecationWarning 23 | from paramiko.ssh_exception import ( 24 | AuthenticationException, 25 | BadAuthenticationType, 26 | PasswordRequiredException, 27 | SSHException, 28 | ) 29 | 30 | # 3rd party 31 | warnings.simplefilter("ignore", CryptographyDeprecationWarning) 32 | import paramiko 33 | 34 | logging.getLogger("paramiko").setLevel(logging.CRITICAL) 35 | 36 | _SHELL_PROMPT = re.compile(r"(% |# |\$ |> |%\t)$") 37 | _SELECT_WAIT = 0.1 38 | _RECVSZ = 1024 39 | _EXIT_CODE = re.compile(r"\r\n0\r\n", re.MULTILINE) 40 | 41 | logger = logging.getLogger(__name__) 42 | 43 | 44 | class SSHShell: 45 | """class providing ssh connectivity using paramiko lib""" 46 | 47 | def __init__(self, **kwargs): 48 | """Initialise the SSHShell class""" 49 | self.kwargs = kwargs 50 | logger.debug(self.kwargs) 51 | self.hostname = self.kwargs.get("hostname") 52 | self.username = self.kwargs.get("username") 53 | self.ssh_port = self.kwargs.get("ssh_port") 54 | self._chan = None 55 | self._transport = None 56 | self.use_shell = False 57 | 58 | def __enter__(self): 59 | self.socket_open() 60 | self.transport_open() 61 | return self 62 | 63 | def __exit__(self, exc_ty, exc_val, exc_tb): 64 | if self._transport is not None: 65 | # closes the transport and underlying socket 66 | self.close_transport() 67 | 68 | def socket_open(self): 69 | """wrapper around proxy or direct methods 70 | :return None: 71 | """ 72 | logger.info("entering socket_open()") 73 | self.socket = self.socket_proxy() 74 | if not self.socket: 75 | self.socket = self.socket_direct() 76 | 77 | def socket_proxy(self): 78 | """checks the .ssh/config file for any proxy commands to reach host 79 | :return sock: 80 | :type subprocess: 81 | """ 82 | logger.info("entering socket_proxy()") 83 | sock = None 84 | ssh_config = os.path.expanduser("~/.ssh/config") 85 | if os.path.isfile(ssh_config): 86 | config = paramiko.SSHConfig() 87 | with open(ssh_config) as open_ssh_config: 88 | config.parse(open_ssh_config) 89 | host_config = config.lookup(self.hostname) 90 | if host_config.get("proxycommand"): 91 | sock = paramiko.proxy.ProxyCommand(host_config.get("proxycommand")) 92 | return sock 93 | 94 | def socket_direct(self): 95 | """open a socket to remote host 96 | :return sock: 97 | :type socket object: 98 | """ 99 | logger.info("entering socket_direct()") 100 | sock = None 101 | try: 102 | sock = socket.create_connection((self.hostname, self.ssh_port), 10) 103 | except (socket.gaierror, socket.herror): 104 | raise ConnectionError("address or hostname not reachable") 105 | except (socket.timeout, ConnectionRefusedError, IOError, OSError): 106 | raise ConnectionError( 107 | f"error connecting to remote host on port {self.ssh_port}" 108 | ) 109 | return sock 110 | 111 | def get_pkey_from_file(self, pkey_type, pkey_path): 112 | """attempt to decode the private key 113 | :param pkey_type: key algorithm 114 | :type string: 115 | :param pkey_path: path to key file 116 | :type string: 117 | :return Pkey object: 118 | :raises PasswordRequiredException: if key cannot be decoded 119 | """ 120 | pkey = None 121 | try: 122 | if pkey_type == "RSA": 123 | pkey = paramiko.RSAKey.from_private_key_file(filename=pkey_path) 124 | elif pkey_type == "DSA": 125 | pkey = paramiko.DSSKey.from_private_key_file(filename=pkey_path) 126 | elif pkey_type == "EC": 127 | pkey = paramiko.ECDSAKey.from_private_key_file(filename=pkey_path) 128 | elif pkey_type == "OPENSSH": 129 | pkey = paramiko.Ed25519Key.from_private_key_file(filename=pkey_path) 130 | except PasswordRequiredException: 131 | raise 132 | except AttributeError: 133 | logger.debug("".join(traceback.format_exception(*sys.exc_info()))) 134 | print( 135 | f"{pkey_type} key found, this paramiko version is missing support " 136 | f"for {pkey_type} keys" 137 | ) 138 | return pkey 139 | 140 | def transport_open(self): 141 | """opens a transport to the host 142 | :return: None 143 | """ 144 | self._transport = paramiko.Transport(self.socket) 145 | self._transport.start_client() 146 | 147 | def worker_thread_auth(self): 148 | """authentication has succeeded previously, simplify nth time around 149 | :return result: 150 | :type bool: 151 | """ 152 | result = False 153 | auth_method = self.kwargs.get("auth_method") 154 | if auth_method == "agent": 155 | self.auth_using_agent() 156 | elif auth_method == "publickey": 157 | self.auth_using_provided_keyfile() 158 | elif auth_method == "keyboard-interactive": 159 | self.auth_using_keyb() 160 | else: 161 | self.password_auth() 162 | 163 | if self.is_authenticated(): 164 | result = True 165 | return result 166 | 167 | def main_thread_auth(self): 168 | """determines what authentication methods the server supports 169 | attempts the available authentication methods in order: 170 | * publickey auth 171 | * keyboard-interactive auth 172 | * password auth 173 | :return result: 174 | :type bool: 175 | """ 176 | logger.info("entering main_thread_auth()") 177 | allowed_types = None 178 | result = False 179 | try: 180 | self._transport.auth_none(self.kwargs["username"]) 181 | except BadAuthenticationType as e: 182 | allowed_types = e.allowed_types 183 | if allowed_types is None: 184 | raise SSHException("no authentication methods possible") 185 | logger.info(allowed_types) 186 | 187 | for auth_type in allowed_types: 188 | logger.info(f"trying auth method {auth_type}") 189 | if auth_type == "publickey" and self.kwargs["key_filename"] is None: 190 | if self.auth_using_agent(): 191 | self.kwargs["auth_method"] = "agent" 192 | break 193 | if self.auth_using_keyfiles(): 194 | self.kwargs["auth_method"] = "publickey" 195 | break 196 | elif auth_type == "publickey" and self.kwargs["key_filename"]: 197 | if self.auth_using_provided_keyfile(): 198 | self.kwargs["auth_method"] = "publickey" 199 | break 200 | elif auth_type == "keyboard-interactive" and self.auth_using_keyb(): 201 | self.kwargs["auth_method"] = "keyboard-interactive" 202 | break 203 | elif auth_type == "password" and self.password_auth(): 204 | self.kwargs["auth_method"] = "password" 205 | break 206 | 207 | if self.is_authenticated(): 208 | print("ssh authentication succeeded") 209 | result = True 210 | return result 211 | 212 | def ask_password(self): 213 | """obtains the password for PasswordAuthentication 214 | :return password: 215 | :type string: 216 | """ 217 | logger.info("entering ask_password()") 218 | password = getpass.getpass( 219 | prompt=f"{self.username}@{self.hostname}'s password: ", 220 | stream=None, 221 | ) 222 | return password 223 | 224 | def password_auth(self): 225 | """attempts Password Authentication 226 | :raises AuthenticationException: if auth fails 227 | :return result: 228 | :type bool: 229 | """ 230 | logger.info("entering password_auth()") 231 | result = False 232 | if not self.kwargs["password"]: 233 | self.kwargs["password"] = self.ask_password() 234 | try: 235 | self._transport.auth_password( 236 | username=self.kwargs["username"], password=self.kwargs["password"] 237 | ) 238 | result = True 239 | except AuthenticationException: 240 | logger.info("password authentication failed") 241 | return result 242 | 243 | def auth_using_keyb(self): 244 | """attempts keyboard-interactive authentication 245 | :return result: 246 | :type bool: 247 | """ 248 | logger.info("entering auth_using_keyb()") 249 | result = False 250 | if not self.kwargs["password"]: 251 | self.kwargs["password"] = self.ask_password() 252 | 253 | def handler(title, instructions, fields): 254 | logger.debug(fields) 255 | if len(fields) > 1: 256 | raise SSHException("keyboard-interactive authentication failed.") 257 | if len(fields) == 0: 258 | return [] 259 | return [self.kwargs["password"]] 260 | 261 | try: 262 | username = self.kwargs["username"] 263 | self._transport.auth_interactive(username, handler) 264 | result = True 265 | except (SSHException, AuthenticationException): 266 | logger.debug("".join(traceback.format_exception(*sys.exc_info()))) 267 | logger.info("keyboard-interactive authentication failed") 268 | return result 269 | 270 | def auth_using_agent(self): 271 | """attempts publickey authentication using keys held by ssh-agent 272 | :return result: 273 | :type bool: 274 | """ 275 | logger.info("entering auth_using_agent()") 276 | agent = paramiko.Agent() 277 | agent_keys = agent.get_keys() 278 | logger.info(f"ssh agent has {len(agent_keys)} keys") 279 | result = False 280 | for pkey in agent_keys: 281 | pkey_type = pkey.get_name() 282 | logger.info(f"ssh agent has key type {pkey_type}") 283 | try: 284 | self._transport.auth_publickey(self.kwargs["username"], pkey) 285 | result = True 286 | except SSHException as err: 287 | logger.debug("".join(traceback.format_exception(*sys.exc_info()))) 288 | logger.info(f"{pkey_type} key authentication failed with error: {err}") 289 | return result 290 | 291 | def auth_using_keyfiles(self): 292 | """attempts publickey authentication using keys found in ~/.ssh 293 | Iterates over any keys found 294 | :return result: 295 | :type bool: 296 | """ 297 | logger.info("entering auth_using_keyfiles()") 298 | pkey_types = { 299 | "RSA": "id_rsa", 300 | "DSA": "id_dsa", 301 | "EC": "id_ecdsa", 302 | "OPENSSH": "id_ed25519", 303 | } 304 | pkey_files = [] 305 | result = False 306 | for pkey_type in pkey_types: 307 | path = os.path.expanduser(f"~/.ssh/{pkey_types[pkey_type]}") 308 | if os.path.isfile(path): 309 | pkey_files.append((pkey_type, path)) 310 | logger.debug(f"key files found: {pkey_files}") 311 | for pkey_file in pkey_files: 312 | pkey_type, pkey_path = pkey_file[0], pkey_file[1] 313 | try: 314 | if self.key_auth_common(pkey_type, pkey_path): 315 | self.kwargs.update({"key_filename": pkey_path}) 316 | result = True 317 | break 318 | except PasswordRequiredException: 319 | continue 320 | return result 321 | 322 | def key_auth_common(self, pkey_type, pkey_path): 323 | """attempts authentication using specified key and type 324 | :param pkey_type: key algorithm 325 | :type string: 326 | :param pkey_path: path to key file 327 | :type string: 328 | :return result: 329 | :type bool: 330 | """ 331 | result = False 332 | pkey = None 333 | try: 334 | pkey = self.get_pkey_from_file(pkey_type, pkey_path) 335 | except PasswordRequiredException: 336 | logger.info(f"key {pkey_path} has a passphrase") 337 | raise 338 | if pkey is not None: 339 | try: 340 | self._transport.auth_publickey(self.kwargs["username"], pkey) 341 | result = True 342 | except SSHException as err: 343 | self.kwargs.update({"key_filename": None}) 344 | logger.debug("".join(traceback.format_exception(*sys.exc_info()))) 345 | print(f"{pkey_type} key authentication failed with error: {err}") 346 | return result 347 | 348 | def auth_using_provided_keyfile(self): 349 | """as key type is unknown, attempt publickey authentication 350 | using provided keyfile by looping through supported types 351 | :return result: 352 | :type bool: 353 | """ 354 | logger.info("entering auth_using_provided_keyfile()") 355 | pkey_path = self.kwargs["key_filename"] 356 | pkey_types = ["RSA", "DSA", "EC", "OPENSSH"] 357 | result = False 358 | for pkey_type in pkey_types: 359 | try: 360 | if self.key_auth_common(pkey_type, pkey_path): 361 | result = True 362 | break 363 | except PasswordRequiredException: 364 | break 365 | return result 366 | 367 | def is_authenticated(self): 368 | """verifies if authentication was successful 369 | :return result: 370 | :type bool: 371 | """ 372 | result = False 373 | logger.info("entering is_authenticated()") 374 | if self._transport.is_authenticated(): 375 | result = True 376 | return result 377 | 378 | def channel_open(self): 379 | """opens a channel of type 'session' over existing transport 380 | :return None: 381 | """ 382 | logger.info("entering channel_open()") 383 | self._chan = self._transport.open_session() 384 | 385 | def invoke_shell(self): 386 | """request a pty and interactive shell on the channel 387 | :return None: 388 | """ 389 | logger.info("entering invoke_shell()") 390 | self.use_shell = True 391 | self._chan.get_pty() 392 | self._chan.invoke_shell() 393 | 394 | def stdout_read(self, timeout): 395 | """reads data off the socket 396 | :param timeout: amount of time before timeout is raised 397 | :type int: 398 | :returns output: stdout from the cmd 399 | :type string: 400 | """ 401 | chan = self._chan 402 | now = datetime.datetime.now() 403 | timeout_time = now + datetime.timedelta(seconds=timeout) 404 | output = "" 405 | while not _SHELL_PROMPT.search(output): 406 | rd, wr, err = select.select([chan], [], [], _SELECT_WAIT) 407 | if rd: 408 | data = chan.recv(_RECVSZ) 409 | output += data.decode("ascii", "ignore") 410 | if datetime.datetime.now() > timeout_time: 411 | raise TimeoutError 412 | return output 413 | 414 | def set_transport_keepalive(self): 415 | """ensures session stays up if inactive for long period 416 | not suitable for scp, will terminate session with BadUseError if enabled 417 | :return None: 418 | """ 419 | self._transport.set_keepalive(60) 420 | 421 | def write(self, cmd): 422 | """sends a cmd + newline char over the channel 423 | :param cmd: cmd to be sent over the channel 424 | :type string: 425 | :return None: 426 | """ 427 | self._chan.send(f"{cmd}\n") 428 | logger.info(f"sent '{cmd}'") 429 | 430 | def close(self): 431 | """terminates both the channel (if present) and the underlying transport 432 | :return None: 433 | """ 434 | self.close_channel() 435 | self.close_transport() 436 | 437 | def close_channel(self): 438 | """terminates the channel 439 | :return None: 440 | """ 441 | try: 442 | self._chan.close() 443 | except AttributeError: 444 | pass 445 | except EOFError: 446 | pass 447 | 448 | def close_transport(self): 449 | """terminates the underlying transport 450 | :return None: 451 | """ 452 | try: 453 | self._transport.close() 454 | except AttributeError: 455 | pass 456 | 457 | def shell_cmd(self, cmd, timeout, exitcode): 458 | """sends a cmd to remote host over the existing channel and shell 459 | if exitcode is True will check its exit status 460 | if a timeout occurs, will attempt to close the existing channel 461 | then request a new channel, pty and interactive shell 462 | :param cmd: cmd to run on remote host 463 | :type string: 464 | :param timeout: amount of time before timeout is raised 465 | :type float: 466 | :param exitcode: toggles whether to check for exit status or not 467 | :type bool: 468 | :return result: whether successful or not 469 | :type bool: 470 | :return stdout: the output of the command 471 | :type string: 472 | """ 473 | result = False 474 | stdout = "" 475 | self.write(cmd) 476 | stdout = self.stdout_read(timeout) 477 | logger.debug(stdout) 478 | if exitcode: 479 | self.write("echo $?") 480 | rc = self.stdout_read(timeout) 481 | if re.search(_EXIT_CODE, rc): 482 | result = True 483 | 484 | return result, stdout 485 | 486 | def exec_cmd(self, cmd, timeout, combine): 487 | """execute a command on the remote host. 488 | a new channel is opened prior to the command being executed. 489 | the channel is closed once the cmds exit status has been received 490 | :param cmd: cmd to run on remote host 491 | :type string: 492 | :param timeout: amount of time before timeout is raised 493 | :type float: 494 | :param combine: whether stderr should be combined into stdout 495 | :type bool: 496 | :return result: whether successful or not 497 | :type bool: 498 | :return stdout: the output of the command 499 | :type string: 500 | """ 501 | result = False 502 | exit_code = None 503 | stdout = "" 504 | stdout_bytes = [] 505 | chan = self._transport.open_session(timeout=30) 506 | chan.settimeout(timeout) 507 | chan.exec_command(cmd) 508 | if combine: 509 | chan.set_combine_stderr(True) 510 | out_bytes = chan.recv(_RECVSZ) 511 | while out_bytes: 512 | stdout_bytes.append(out_bytes) 513 | out_bytes = chan.recv(_RECVSZ) 514 | stdout = b"".join(stdout_bytes).rstrip().decode() 515 | while exit_code is None: 516 | if chan.exit_status_ready(): 517 | exit_code = chan.recv_exit_status() 518 | chan.close() 519 | if exit_code == 0: 520 | result = True 521 | return result, stdout 522 | -------------------------------------------------------------------------------- /src/splitcopy/progress.py: -------------------------------------------------------------------------------- 1 | """ Copyright (c) 2018, Juniper Networks, Inc 2 | All rights reserved 3 | This SOFTWARE is licensed under the LICENSE provided in the 4 | ./LICENCE file. By downloading, installing, copying, or otherwise 5 | using the SOFTWARE, you agree to be bound by the terms of that 6 | LICENSE. 7 | """ 8 | 9 | # stdlib on *nix, 3rd party on win32 10 | import curses 11 | 12 | # stdlib 13 | import logging 14 | import time 15 | from shutil import get_terminal_size 16 | from threading import Thread 17 | 18 | # local modules 19 | from splitcopy.shared import pad_string 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | def percent_val(total_amount, partial_amount): 25 | """returns a percentage 26 | :param total_amount: 27 | :type int: 28 | :param partial_amount: 29 | :type int: 30 | :return int: 31 | """ 32 | return int(round(100 / total_amount * partial_amount, 2)) 33 | 34 | 35 | def progress_bar(percent_done): 36 | """returns a graphical progress bar as a string 37 | :param percent_done: 38 | :type int: 39 | :return string: 40 | """ 41 | return f"[{'#' * int(percent_done/2)}{(50 - int(percent_done/2)) * ' '}]" 42 | 43 | 44 | def bytes_display(num_bytes): 45 | """Function that returns a string identifying the size of the number 46 | :param num_bytes: 47 | :type int: 48 | :return amount: 49 | :type float: 50 | :return unit: 51 | :type string: 52 | """ 53 | amount = 0.0 54 | unit = "" 55 | if num_bytes < 1024**2: 56 | amount = num_bytes / 1024 57 | unit = "KB" 58 | elif num_bytes < 1024**3: 59 | amount = num_bytes / 1024**2 60 | unit = "MB" 61 | elif num_bytes < 1024**4: 62 | amount = num_bytes / 1024**3 63 | unit = "GB" 64 | return amount, unit 65 | 66 | 67 | def prepare_curses(): 68 | """Function to do some prep work to use curses. 69 | :return stdscr: 70 | :type _curses.window object: 71 | """ 72 | stdscr = curses.initscr() 73 | curses.noecho() 74 | curses.cbreak() 75 | return stdscr 76 | 77 | 78 | def abandon_curses(): 79 | """Function to exit curses and restore terminal to prior state. 80 | :return None: 81 | """ 82 | try: 83 | curses.nocbreak() 84 | curses.echo() 85 | curses.endwin() 86 | except (curses.error, AttributeError): 87 | pass 88 | 89 | 90 | class Progress: 91 | """class which both FTP and SCPClient calls back to. 92 | provides a progress meter to the user 93 | """ 94 | 95 | def __init__(self): 96 | """Initialize the class""" 97 | self.chunks = [] 98 | self.chunk_size = "" 99 | self.totals = {} 100 | self.error_list = ["", "", ""] 101 | self.files = {} 102 | self.curses = False 103 | self.stdscr = None 104 | self.timer = None 105 | self.stop_timer = False 106 | 107 | def add_chunks(self, total_file_size, chunks): 108 | """Function that creates required data structures 109 | :param chunks: 110 | :type list: 111 | :return None: 112 | """ 113 | self.chunks = chunks 114 | self.chunk_size = str(chunks[0][1]) 115 | for chunk in chunks: 116 | file_name = chunk[0] 117 | self.files[file_name] = {} 118 | self.files[file_name]["sent_bytes"] = 0 119 | self.files[file_name]["last_sent_bytes"] = 0 120 | self.files[file_name]["bytes_per_sec"] = 0.0 121 | self.files[file_name]["percent_done"] = 0 122 | self.files[file_name]["complete"] = 0 123 | self.totals["sum_bytes_sent"] = 0 124 | self.totals["sum_completed"] = 0 125 | self.totals["sum_bytes_per_sec"] = 0.0 126 | self.totals["percent_done"] = 0 127 | self.totals["total_file_size"] = total_file_size 128 | 129 | def check_term_size(self, result): 130 | """function that checks whether curses can be supported or not 131 | preferable to do this prior to initiating a curses window 132 | :param result: 133 | :type bool: 134 | :return result: 135 | :type bool: 136 | """ 137 | if result: 138 | term_width, term_height = get_terminal_size() 139 | req_height = len(self.chunks) + 4 140 | if term_height < req_height: 141 | result = False 142 | if not result: 143 | print("terminal window is too small to display per-chunk statistics") 144 | return result 145 | 146 | def initiate_timer_thread(self): 147 | """Function that starts a single thread 148 | :return None: 149 | """ 150 | self.timer = Thread( 151 | name="refresh_timer", 152 | target=self.refresh_timer, 153 | args=(1, lambda: self.stop_timer), 154 | ) 155 | self.timer.start() 156 | 157 | def refresh_timer(self, thread_id, stop): 158 | """Function that calls other functions to update data that is then displayed 159 | to the user once a second 160 | :param thread_id: 161 | :type int: # required for Thread(), otherwise unused 162 | :param stop: 163 | :type function: # allows loop to be exited gracefully 164 | :return None: 165 | """ 166 | while True: 167 | if stop(): 168 | break 169 | self.rates_update() 170 | self.totals_update() 171 | if self.curses: 172 | self.update_screen_contents() 173 | # add a newline to the end of the error list 174 | # pushing older errors out of the curses display 175 | self.print_error("") 176 | # remove the 1st element from the error_list 177 | # effectively making it a circular buffer 178 | del self.error_list[0] 179 | else: 180 | self.disp_total_progress() 181 | time.sleep(1) 182 | 183 | def stop_timer_thread(self): 184 | """function that causes the timer thread to exit 185 | :returns None: 186 | """ 187 | self.stop_timer = True 188 | try: 189 | self.timer.join() 190 | except (AttributeError, RuntimeError): 191 | pass 192 | 193 | def report_progress(self, file_name, file_size, sent): 194 | """For every % of data transferred, notifies the user 195 | :param file_name: name of file 196 | :type string: (from FTP lib) or bytes(from SCP lib) 197 | :param file_size: file size in bytes 198 | :type int: 199 | :param sent: bytes transferred 200 | :type int: 201 | :return None: 202 | """ 203 | try: 204 | file_name = file_name.decode() 205 | except AttributeError: 206 | # FTP lib uses string already 207 | pass 208 | if file_size == sent: 209 | self.files[file_name]["complete"] = 1 210 | else: 211 | self.files[file_name]["complete"] = 0 212 | self.files[file_name]["sent_bytes"] = sent 213 | self.file_percentage_update(file_name, file_size, sent) 214 | 215 | def disp_total_progress(self): 216 | """Function that outputs progress string when curses is not used 217 | :return None: 218 | """ 219 | print(f"\r{self.total_progress_str()}", end="") 220 | 221 | def file_percentage_update(self, file_name, file_size, sent): 222 | """Function to update the percent complete for a given file 223 | :param file_name: 224 | :type string: 225 | :param file_size: 226 | :type int: 227 | :param sent: 228 | :type int: 229 | :return None: 230 | """ 231 | percent_done = percent_val(file_size, sent) 232 | if self.files[file_name]["percent_done"] != percent_done: 233 | self.files[file_name]["percent_done"] = percent_done 234 | 235 | def totals_update(self): 236 | """Function that determines the total number of bytes sent, 237 | the total percentage of bytes transferred and how many of 238 | the chunks are completed 239 | :return None: 240 | """ 241 | sum_bytes_sent = 0 242 | sum_completed = 0 243 | total_file_size = self.totals["total_file_size"] 244 | for file in self.files.values(): 245 | sum_bytes_sent += file["sent_bytes"] 246 | sum_completed += file["complete"] 247 | self.totals["sum_bytes_sent"] = sum_bytes_sent 248 | self.totals["sum_completed"] = sum_completed 249 | percent_done = percent_val(total_file_size, sum_bytes_sent) 250 | self.totals["percent_done"] = percent_done 251 | logger.debug(self.totals) 252 | 253 | def total_progress_str(self): 254 | """returns a single line with progress info such as: 255 | % done, number of bytes transferred etc 256 | :return output: 257 | :type string: 258 | """ 259 | percent_done = self.totals["percent_done"] 260 | sum_completed = self.totals["sum_completed"] 261 | sum_bytes_per_sec = self.totals["sum_bytes_per_sec"] 262 | sum_bytes_sent = self.totals["sum_bytes_sent"] 263 | total_file_size = self.totals["total_file_size"] 264 | sum_bytes, sum_bytes_unit = bytes_display(sum_bytes_sent) 265 | total_bytes, total_bytes_unit = bytes_display(total_file_size) 266 | rate_per_sec, rate_unit = bytes_display(sum_bytes_per_sec) 267 | output = ( 268 | f"{str(percent_done)}% done {sum_bytes:.1f}{sum_bytes_unit}" 269 | f"/{total_bytes:.1f}{total_bytes_unit} " 270 | f"{rate_per_sec:>6.1f}{rate_unit}/s " 271 | f"({sum_completed}/{len(self.chunks)} chunks completed)" 272 | ) 273 | return output 274 | 275 | def rates_update(self): 276 | """updates the transfer rates per chunk and total. Called on a 1sec periodic 277 | :return None: 278 | """ 279 | sum_bytes_per_sec = 0.0 280 | for file in self.chunks: 281 | file_name = file[0] 282 | sent_bytes = self.files[file_name]["sent_bytes"] 283 | last_sent_bytes = self.files[file_name]["last_sent_bytes"] 284 | bytes_per_sec = sent_bytes - last_sent_bytes 285 | self.files[file_name]["last_sent_bytes"] = sent_bytes 286 | self.files[file_name]["bytes_per_sec"] = bytes_per_sec 287 | sum_bytes_per_sec += bytes_per_sec 288 | self.totals["sum_bytes_per_sec"] = sum_bytes_per_sec 289 | 290 | def zero_file_stats(self, file_name): 291 | """Function that resets a files stats if transfer is restarted 292 | :param file_name: 293 | :type string: 294 | :return None: 295 | """ 296 | self.files[file_name]["last_sent_bytes"] = 0 297 | self.files[file_name]["bytes_per_sec"] = 0 298 | self.files[file_name]["percent_done"] = 0 299 | self.files[file_name]["sent_bytes"] = 0 300 | self.files[file_name]["complete"] = 0 301 | 302 | def update_screen_contents(self): 303 | """Function collates the information to be drawn by curses 304 | :return None: 305 | """ 306 | txt_lines = [] 307 | for file in self.chunks: 308 | file_name = file[0] 309 | if len(file_name) > 10: 310 | file_name_str = f"{file_name[0:6]}..{file_name[-2:]}" 311 | else: 312 | file_name_str = file_name 313 | sent_bytes, sent_bytes_unit = bytes_display( 314 | self.files[file_name]["sent_bytes"] 315 | ) 316 | bytes_per_sec, bytes_per_sec_unit = bytes_display( 317 | self.files[file_name]["bytes_per_sec"] 318 | ) 319 | percent_done = self.files[file_name]["percent_done"] 320 | txt_lines.append( 321 | f"{file_name_str} {progress_bar(percent_done)} " 322 | f"{percent_done:>3}% {sent_bytes:>6.1f}{sent_bytes_unit} " 323 | f"{bytes_per_sec:>6.1f}{bytes_per_sec_unit}/s" 324 | ) 325 | txt_lines.append(pad_string("")) 326 | txt_lines.append(f"{pad_string(self.total_progress_str())}") 327 | # display the three most recent error strings 328 | err_idx = -3 329 | while err_idx < 0: 330 | txt_lines.append(self.error_list[err_idx]) 331 | err_idx += 1 332 | try: 333 | self.redraw_screen(txt_lines) 334 | except curses.error: 335 | abandon_curses() 336 | self.curses = False 337 | 338 | def print_error(self, error): 339 | """correctly output errors when curses window is active or not 340 | :param error: 341 | :type string: 342 | :return None: 343 | """ 344 | if not self.curses: 345 | print(f"\n{error}") 346 | else: 347 | # when using curses window, \n results in the following line starting 348 | # at the column the previous line ended at. This quickly becomes 349 | # illegible. Idea here is to put any error logs in a list 350 | # and only display the most recent additions in update_screen_contents() 351 | padded_string = pad_string(error) 352 | self.error_list.append(f"{padded_string}") 353 | 354 | def start_progress(self, use_curses): 355 | """Function that starts the timer thread (and thus progress output 356 | Initiates the curses window (if applicable) 357 | :param use_curses: 358 | :type bool: 359 | :return None: 360 | """ 361 | self.curses = self.check_term_size(use_curses) 362 | if self.curses: 363 | self.stdscr = prepare_curses() 364 | self.initiate_timer_thread() 365 | 366 | def stop_progress(self): 367 | """Function that stops the timer thread (and thus progress output) 368 | Then shuts down the curses window (if applicable) 369 | :return None: 370 | """ 371 | self.stop_timer_thread() 372 | abandon_curses() 373 | 374 | def redraw_screen(self, txt_lines): 375 | """Method to redraw the screen using curses library. 376 | :param txt_lines: 377 | :type list: 378 | :return None: 379 | """ 380 | lines = len(txt_lines) 381 | for line in range(lines): 382 | # using format 'y-axis, x-axis, string' 383 | self.stdscr.addstr(line, 0, txt_lines[line]) 384 | self.stdscr.refresh() 385 | -------------------------------------------------------------------------------- /src/splitcopy/put.py: -------------------------------------------------------------------------------- 1 | """ Copyright (c) 2018, Juniper Networks, Inc 2 | All rights reserved 3 | This SOFTWARE is licensed under the LICENSE provided in the 4 | ./LICENCE file. By downloading, installing, copying, or otherwise 5 | using the SOFTWARE, you agree to be bound by the terms of that 6 | LICENSE. 7 | """ 8 | 9 | # stdlib 10 | import asyncio 11 | import datetime 12 | import fnmatch 13 | import functools 14 | import hashlib 15 | import logging 16 | import os 17 | import re 18 | import signal 19 | import sys 20 | import time 21 | import traceback 22 | import warnings 23 | from ftplib import error_perm, error_proto, error_reply, error_temp 24 | 25 | # silence this warning 26 | from cryptography.utils import CryptographyDeprecationWarning 27 | 28 | warnings.simplefilter("ignore", CryptographyDeprecationWarning) 29 | 30 | # 3rd party 31 | from paramiko.ssh_exception import SSHException 32 | from scp import SCPClient 33 | 34 | # local modules 35 | from splitcopy.ftp import FTP 36 | from splitcopy.paramikoshell import SSHShell 37 | from splitcopy.progress import Progress 38 | from splitcopy.shared import SplitCopyShared 39 | 40 | logger = logging.getLogger(__name__) 41 | 42 | # use st_blksize 43 | _BUF_SIZE_READ = 1024 * 8 44 | 45 | 46 | class SplitCopyPut: 47 | def __init__(self, **kwargs): 48 | """ 49 | Initialize the SplitCopyPut class 50 | """ 51 | self.user = kwargs.get("user") 52 | self.host = kwargs.get("host") 53 | self.passwd = kwargs.get("passwd") 54 | self.ssh_key = kwargs.get("ssh_key") 55 | self.ssh_port = kwargs.get("ssh_port") 56 | self.remote_path = kwargs.get("remote_path") 57 | self.local_dir = kwargs.get("local_dir") 58 | self.local_file = kwargs.get("local_file") 59 | self.local_path = kwargs.get("local_path") 60 | self.copy_proto = kwargs.get("copy_proto") 61 | self.noverify = kwargs.get("noverify") 62 | self.use_curses = kwargs.get("use_curses") 63 | self.overwrite = kwargs.get("overwrite") 64 | self.scs = SplitCopyShared(**kwargs) 65 | self.mute = False 66 | self.sshshell = None 67 | self.progress = Progress() 68 | self.use_shell = False 69 | 70 | def handlesigint(self, sigint, stack): 71 | """called when SigInt is received 72 | :param sigint: 73 | :type int: 74 | :param stack: 75 | :type frame: 76 | """ 77 | logger.debug(f"signal {sigint} received, stack:\n{stack}") 78 | self.mute = True 79 | self.progress.stop_progress() 80 | self.scs.close() 81 | 82 | def put(self): 83 | """copies file from local host to remote host 84 | performs file split/transfer/join/verify functions 85 | :returns loop_start: time when transfers started 86 | :type: datetime object 87 | :returns loop_end: time when transfers ended 88 | :type: datetime object 89 | """ 90 | ssh_kwargs = { 91 | "username": self.user, 92 | "hostname": self.host, 93 | "password": self.passwd, 94 | "key_filename": self.ssh_key, 95 | "ssh_port": self.ssh_port, 96 | } 97 | 98 | # handle sigint gracefully on *nix 99 | signal.signal(signal.SIGINT, self.handlesigint) 100 | 101 | # determine local file size 102 | file_size = self.determine_local_filesize() 103 | 104 | # confirm local storage is sufficient 105 | self.scs.storage_check_local(file_size) 106 | 107 | # connect to host, open ssh transport 108 | self.sshshell, ssh_kwargs = self.scs.connect(SSHShell, **ssh_kwargs) 109 | 110 | # is this a juniper cli? 111 | if self.scs.juniper_cli_check(): 112 | self.use_shell = True 113 | self.scs.enter_shell() 114 | 115 | # ensure dest path is valid 116 | remote_file, remote_dir = self.validate_remote_path_put() 117 | 118 | # if target file exists, delete it? 119 | self.delete_target_remote(remote_dir, remote_file) 120 | 121 | # determine the OS 122 | junos, evo, bsd_version, sshd_version = self.scs.which_os() 123 | 124 | # verify which protocol to use 125 | self.copy_proto, self.passwd = self.scs.which_proto(self.copy_proto) 126 | 127 | # check required binaries exist on remote host 128 | self.scs.req_binaries(junos=junos, evo=evo) 129 | 130 | # cleanup previous remote tmp directory if found 131 | self.scs.remote_cleanup( 132 | remote_dir=remote_dir, remote_file=remote_file, silent=True 133 | ) 134 | 135 | # determine optimal size for chunks 136 | split_size, executor = self.scs.file_split_size( 137 | file_size, sshd_version, bsd_version, evo, self.copy_proto 138 | ) 139 | 140 | # confirm remote storage is sufficient 141 | self.scs.storage_check_remote(file_size, split_size, remote_dir) 142 | 143 | if not self.noverify: 144 | # get/create sha for local file 145 | sha_bin, sha_len, sha_hash = self.local_sha_put() 146 | 147 | with self.scs.tempdir(): 148 | # split file into chunks 149 | self.split_file_local(file_size, split_size) 150 | 151 | # add chunk names to a list, pass this info to Progress 152 | chunks = self.get_chunk_info() 153 | self.progress.add_chunks(file_size, chunks) 154 | 155 | # create tmp directory 156 | remote_tmpdir = self.scs.mkdir_remote(remote_dir, remote_file) 157 | 158 | # begin connection/rate limit check and transfer process 159 | command_list = [] 160 | if junos or evo: 161 | command_list = self.scs.limit_check(self.copy_proto) 162 | print("starting transfer...") 163 | self.progress.start_progress(self.use_curses) 164 | # copy files to remote host 165 | self.scs.hard_close = True 166 | loop_start = datetime.datetime.now() 167 | loop = asyncio.new_event_loop() 168 | tasks = [] 169 | for chunk in chunks: 170 | task = loop.run_in_executor( 171 | executor, 172 | functools.partial( 173 | self.put_files, 174 | FTP, 175 | SSHShell, 176 | SCPClient, 177 | chunk, 178 | remote_tmpdir, 179 | ssh_kwargs, 180 | ), 181 | ) 182 | tasks.append(task) 183 | try: 184 | loop.run_until_complete(asyncio.gather(*tasks)) 185 | except TransferError: 186 | self.progress.stop_progress() 187 | self.scs.close( 188 | err_str="\nan error occurred while copying the files to the remote host", 189 | ) 190 | finally: 191 | loop.close() 192 | self.scs.hard_close = False 193 | while self.progress.totals["percent_done"] != 100: 194 | time.sleep(0.1) 195 | self.progress.stop_progress() 196 | 197 | print("\ntransfer complete") 198 | loop_end = datetime.datetime.now() 199 | 200 | # combine chunks 201 | self.join_files_remote( 202 | SCPClient, chunks, remote_dir, remote_file, remote_tmpdir 203 | ) 204 | 205 | # remove remote tmp dir 206 | self.scs.remote_cleanup() 207 | 208 | # rollback any config changes made 209 | if command_list: 210 | self.scs.limits_rollback() 211 | 212 | # check remote file size is correct 213 | self.compare_file_sizes(file_size, remote_dir, remote_file) 214 | 215 | if self.noverify: 216 | print( 217 | f"file has been successfully copied to {self.host}:" 218 | f"{remote_dir}/{remote_file}" 219 | ) 220 | else: 221 | # generate a sha hash for the combined file, compare to hash of src 222 | self.remote_sha_put(sha_bin, sha_len, sha_hash, remote_dir, remote_file) 223 | 224 | self.sshshell.close() 225 | return loop_start, loop_end 226 | 227 | def get_chunk_info(self): 228 | """obtains the remote chunk file size and names 229 | :return chunks: 230 | :type list: 231 | """ 232 | logger.info("entering get_chunk_info()") 233 | chunks = [] 234 | for chunk in os.listdir("."): 235 | if fnmatch.fnmatch(chunk, f"{self.local_file}*"): 236 | chunks.append([chunk, os.stat(chunk).st_size]) 237 | if not chunks: 238 | self.scs.close(err_str="file split operation failed") 239 | # sort alphabetically 240 | chunks = sorted(chunks) 241 | logger.info(f"# of chunks = {len(chunks)}") 242 | logger.debug(chunks) 243 | return chunks 244 | 245 | def validate_remote_path_put(self): 246 | """path provided can be a full path to a file, or just a directory 247 | :return: None 248 | """ 249 | logger.info("entering validate_remote_path_put()") 250 | remote_path = self.remote_path 251 | try: 252 | remote_path = self.expand_remote_path(remote_path) 253 | remote_path = self.path_startswith_tilda(remote_path) 254 | except ValueError as err: 255 | self.scs.close( 256 | err_str=err, 257 | ) 258 | 259 | if self.scs.ssh_cmd(f"test -d {remote_path}")[0]: 260 | # target path provided is a directory 261 | remote_file = self.local_file 262 | remote_dir = remote_path.rstrip("/") 263 | elif self.scs.ssh_cmd(f"test -d {os.path.dirname(remote_path)}")[0]: 264 | if os.path.basename(remote_path) != self.local_file: 265 | # target path provided was a full path, file name does not match src 266 | # honour the change of file name 267 | remote_file = os.path.basename(remote_path) 268 | else: 269 | # target path provided was a full path, file name matches src 270 | remote_file = self.local_file 271 | remote_dir = os.path.dirname(remote_path) 272 | else: 273 | self.scs.close( 274 | err_str=f"target path {remote_path} on remote host isn't valid", 275 | ) 276 | logger.debug(f"remote_dir now = {remote_dir}, remote_file now = {remote_file}") 277 | # update SplitCopyShared with these values 278 | self.scs.remote_dir = remote_dir 279 | self.scs.remote_file = remote_file 280 | return remote_file, remote_dir 281 | 282 | def expand_remote_path(self, remote_path): 283 | """if only a filename is provided, expands the remote 284 | path to its absolute path 285 | :param remote_path: 286 | :type string: 287 | :return remote_path: 288 | :type string: 289 | :raises ValueError: if remote cmd fails 290 | """ 291 | logger.info("entering expand_remote_path()") 292 | if ( 293 | not re.search(r"\/", remote_path) 294 | or re.match(r"\.\/", remote_path) 295 | or not remote_path 296 | ): 297 | result, stdout = self.scs.ssh_cmd("pwd") 298 | if result: 299 | if self.use_shell: 300 | pwd = stdout.split("\n")[1].rstrip() 301 | else: 302 | pwd = stdout 303 | remote_path = re.sub(r"^\.\/", "", remote_path) 304 | remote_path = f"{pwd}/{remote_path}" 305 | logger.debug(f"remote_path now = {remote_path}") 306 | else: 307 | raise ValueError( 308 | "Cannot determine the current working directory on the remote host" 309 | ) 310 | return remote_path 311 | 312 | def path_startswith_tilda(self, remote_path): 313 | """expands ~ based path to absolute path 314 | :param remote_path: 315 | :type string: 316 | :return remote_path: 317 | :type string: 318 | :raises ValueError: if remote cmd fails 319 | """ 320 | logger.info("entering path_startswith_tilda()") 321 | if re.match(r"~", remote_path): 322 | result, stdout = self.scs.ssh_cmd(f"ls -d {remote_path}") 323 | if result: 324 | if self.use_shell: 325 | remote_path = stdout.split("\n")[1].rstrip() 326 | else: 327 | remote_path = stdout 328 | logger.debug(f"remote_path now = {remote_path}") 329 | else: 330 | raise ValueError(f"unable to expand remote path {remote_path}") 331 | return remote_path 332 | 333 | def check_target_exists(self, remote_dir, remote_file): 334 | """checks if the target file already exists 335 | :param remote_dir: 336 | :type string: 337 | :param remote_file: 338 | :type string: 339 | :return result: 340 | :type bool: 341 | """ 342 | logger.info("entering check_target_exists()") 343 | result, stdout = self.scs.ssh_cmd(f"test -e {remote_dir}/{remote_file}") 344 | return result 345 | 346 | def delete_target_remote(self, remote_dir, remote_file): 347 | """verifies whether path is a file. 348 | if true and --overwrite flag is specified attempt to delete it 349 | else alert the user and exit 350 | :param remote_dir: 351 | :type string: 352 | :param remote_file: 353 | :type string: 354 | :return None: 355 | """ 356 | logger.info("entering delete_target_remote()") 357 | if self.check_target_exists(remote_dir, remote_file): 358 | if self.overwrite: 359 | result, stdout = self.scs.ssh_cmd(f"rm -f {remote_dir}/{remote_file}") 360 | if not result: 361 | err = "remote file already exists, and could not be deleted" 362 | self.scs.close(err_str=err) 363 | else: 364 | err = "remote file already exists. use --overwrite arg or delete it manually" 365 | self.scs.close(err_str=err) 366 | 367 | def determine_local_filesize(self): 368 | """determines the local files size in bytes 369 | :return file_size: 370 | :type int: 371 | """ 372 | logger.info("entering determine_local_filesize()") 373 | file_size = os.path.getsize(self.local_path) 374 | logger.info(f"src file size is {file_size}") 375 | if not file_size: 376 | err = "local file size is 0 bytes, nothing to copy" 377 | self.scs.close(err_str=err) 378 | return file_size 379 | 380 | def local_sha_put(self): 381 | """checks whether a sha hash already exists for the file 382 | if not creates one 383 | :returns None: 384 | """ 385 | file_path = self.local_path 386 | sha_hash = {} 387 | logger.info("entering local_sha_put()") 388 | if os.path.isfile(f"{file_path}.sha512"): 389 | with open(f"{file_path}.sha512", "r") as shafile: 390 | local_sha = shafile.read().split()[0].rstrip() 391 | sha_hash[512] = local_sha 392 | if os.path.isfile(f"{file_path}.sha384"): 393 | with open(f"{file_path}.sha384", "r") as shafile: 394 | local_sha = shafile.read().split()[0].rstrip() 395 | sha_hash[384] = local_sha 396 | if os.path.isfile(f"{file_path}.sha256"): 397 | with open(f"{file_path}.sha256", "r") as shafile: 398 | local_sha = shafile.read().split()[0].rstrip() 399 | sha_hash[256] = local_sha 400 | if os.path.isfile(f"{file_path}.sha224"): 401 | with open(f"{file_path}.sha224", "r") as shafile: 402 | local_sha = shafile.read().split()[0].rstrip() 403 | sha_hash[224] = local_sha 404 | if os.path.isfile(f"{file_path}.sha1"): 405 | with open(f"{file_path}.sha1", "r") as shafile: 406 | local_sha = shafile.read().split()[0].rstrip() 407 | sha_hash[1] = local_sha 408 | if not sha_hash: 409 | print("sha1 not found, generating sha1...") 410 | sha1 = hashlib.sha1() 411 | with open(file_path, "rb") as original_file: 412 | data = original_file.read(_BUF_SIZE_READ) 413 | while data: 414 | sha1.update(data) 415 | data = original_file.read(_BUF_SIZE_READ) 416 | local_sha = sha1.hexdigest() 417 | sha_hash[1] = local_sha 418 | logger.info(f"local sha hashes = {sha_hash}") 419 | sha_bin, sha_len = self.scs.req_sha_binaries(sha_hash) 420 | return sha_bin, sha_len, sha_hash 421 | 422 | def split_file_local(self, file_size, split_size): 423 | """splits file into chunks of size already determined in file_split_size() 424 | This function emulates GNU split. 425 | :param file_size: 426 | :type int: 427 | :param split_size: 428 | :type int: 429 | :returns None: 430 | """ 431 | print("splitting file...") 432 | try: 433 | total_bytes = 0 434 | with open(self.local_path, "rb") as src: 435 | sfx_1 = "a" 436 | sfx_2 = "a" 437 | while total_bytes < file_size: 438 | with open(f"{self.local_file}{sfx_1}{sfx_2}", "wb") as chunk: 439 | logger.info(f"writing data to {self.local_file}{sfx_1}{sfx_2}") 440 | src.seek(total_bytes) 441 | data = src.read(split_size) 442 | chunk.write(data) 443 | total_bytes += split_size 444 | if sfx_2 == "z": 445 | sfx_1 = "b" 446 | sfx_2 = "a" 447 | else: 448 | sfx_2 = chr(ord(sfx_2) + 1) 449 | except Exception as error: 450 | err = f"an error occurred while splitting the file, the error was:\n{error}" 451 | self.scs.close(err_str=err) 452 | 453 | def join_files_remote( 454 | self, scp_lib, chunks, remote_dir, remote_file, remote_tmpdir 455 | ): 456 | """concatenates the files chunks into one file on remote host 457 | :param scp_lib: 458 | :type class: 459 | :param chunks: 460 | :type list: 461 | :param remote_dir: 462 | :type string: 463 | :param remote_file: 464 | :type string: 465 | :param remote_tmpdir: 466 | :type string: 467 | :returns None: 468 | """ 469 | logger.info("entering join_files_remote()") 470 | print("joining chunks...") 471 | result = False 472 | cmd = "" 473 | try: 474 | for chunk in chunks: 475 | cmd += ( 476 | f"cat {remote_tmpdir}/{chunk[0]} " 477 | f">>{remote_dir}/{remote_file}\n" 478 | "if [ $? -gt 0 ]; then exit 1; fi\n" 479 | f"rm {remote_tmpdir}/{chunk[0]}\n" 480 | "if [ $? -gt 0 ]; then exit 1; fi\n" 481 | ) 482 | with self.scs.tempdir(): 483 | with open("join.sh", "w", newline="\n") as fd: 484 | fd.write(cmd) 485 | transport = self.sshshell._transport 486 | with scp_lib(transport) as scpclient: 487 | scpclient.put("join.sh", f"{remote_tmpdir}/join.sh") 488 | result, stdout = self.scs.ssh_cmd( 489 | f"sh {remote_tmpdir}/join.sh", 490 | timeout=600, 491 | retry=False, 492 | ) 493 | except Exception as err: 494 | logger.debug("".join(traceback.format_exception(*sys.exc_info()))) 495 | self.scs.close( 496 | err_str=( 497 | f"{err.__class__.__name__} while combining file chunks on " 498 | f"remote host: {str(err)}" 499 | ), 500 | ) 501 | 502 | if not result: 503 | self.scs.close( 504 | err_str=( 505 | "failed to combine chunks on remote host. " f"error was:\n{stdout}" 506 | ), 507 | ) 508 | return result 509 | 510 | def compare_file_sizes(self, file_size, remote_dir, remote_file): 511 | """obtains the newly combined file size, compares it to the source files size 512 | :param file_size: 513 | :type int: 514 | :param remote_dir: 515 | :type string: 516 | :param remote_file: 517 | :type string: 518 | :return None: 519 | """ 520 | logger.info("entering compare_file_sizes()") 521 | result, stdout = self.scs.ssh_cmd(f"ls -l {remote_dir}/{remote_file}") 522 | if not result: 523 | self.scs.close( 524 | err_str=( 525 | f"file {self.host}:{remote_dir}/{remote_file} not found! please retry" 526 | ), 527 | config_rollback=False, 528 | ) 529 | if self.use_shell: 530 | combined_file_size = int(stdout.split("\r\n")[1].split()[4]) 531 | else: 532 | combined_file_size = int(stdout.split()[4]) 533 | if combined_file_size != file_size: 534 | self.scs.close( 535 | err_str=( 536 | f"combined file size is {combined_file_size}, file " 537 | f"{self.host}:{remote_dir}/{remote_file} size " 538 | f"is {file_size}. Unexpected mismatch in file size. Please retry" 539 | ), 540 | config_rollback=False, 541 | ) 542 | print("local and remote file sizes match") 543 | 544 | def remote_sha_put(self, sha_bin, sha_len, sha_hash, remote_dir, remote_file): 545 | """creates a sha hash for the newly combined file 546 | on the remote host compares against local sha 547 | :param sha_bin: 548 | :type string: 549 | :param sha_len: 550 | :type int: 551 | :param sha_hash: 552 | :type hash: 553 | :param remote_dir: 554 | :type string: 555 | :param remote_file: 556 | :type string: 557 | :returns None: 558 | """ 559 | print("generating remote sha hash...") 560 | remote_sha = "" 561 | if sha_bin == "shasum": 562 | cmd = f"shasum -a {sha_len}" 563 | else: 564 | cmd = f"{sha_bin}" 565 | 566 | result, stdout = self.scs.ssh_cmd( 567 | f"{cmd} {remote_dir}/{remote_file}", timeout=300 568 | ) 569 | if not result: 570 | print( 571 | "remote sha hash generation failed or timed out, " 572 | f'manually check the output of "{cmd} {remote_dir}/{remote_file}" and ' 573 | f"compare against {sha_hash[sha_len]}" 574 | ) 575 | return 576 | for line in stdout.splitlines(): 577 | try: 578 | remote_sha = re.search(r"([0-9a-f]{40,})", line).group(1) 579 | break 580 | except AttributeError: 581 | pass 582 | if not remote_sha: 583 | self.scs.close( 584 | err_str="failed to obtain remote sha hash to compare against", 585 | config_rollback=False, 586 | ) 587 | logger.info(f"remote sha = {remote_sha}") 588 | if sha_hash[sha_len] == remote_sha: 589 | print( 590 | f"local and remote sha hash match\nfile has been " 591 | f"successfully copied to {self.host}:{remote_dir}/{remote_file}" 592 | ) 593 | else: 594 | self.scs.close( 595 | err_str=( 596 | f"file has been copied to {self.host}:{remote_dir}/{remote_file}" 597 | ", but the local and remote sha do not match - please retry" 598 | ), 599 | config_rollback=False, 600 | ) 601 | 602 | def put_files(self, ftp_lib, ssh_lib, scp_lib, chunk, remote_tmpdir, ssh_kwargs): 603 | """copies files to remote host via ftp or scp 604 | :param ftp_lib: 605 | :type class: 606 | :param ssh_lib: 607 | :type class: 608 | :param scp_lib: 609 | :type class: 610 | :param chunk: name and size of the file to copy 611 | :type: list 612 | :param remote_tmpdir: path of the tmp directory on remote host 613 | :type: str 614 | :param ssh_kwargs: keyword arguments 615 | :type dict: 616 | :raises TransferError: if file transfer fails 3 times 617 | :returns None: 618 | """ 619 | err_count = 0 620 | file_name = chunk[0] 621 | file_size = chunk[1] 622 | dstpath = f"{remote_tmpdir}/{file_name}" 623 | logger.info(f"{file_name}, size {file_size}") 624 | while err_count < 3: 625 | try: 626 | if self.copy_proto == "ftp": 627 | with ftp_lib( 628 | file_size=file_size, 629 | progress=self.progress, 630 | host=self.host, 631 | user=self.user, 632 | passwd=self.passwd, 633 | ) as ftp: 634 | restart_marker = None 635 | if err_count: 636 | try: 637 | ftp.sendcmd("TYPE I") 638 | restart_marker = ftp.size(dstpath) 639 | except (error_perm, error_proto, error_reply, error_temp): 640 | pass 641 | if restart_marker is not None: 642 | self.progress.print_error( 643 | f"resuming {file_name} from byte {restart_marker}" 644 | ) 645 | else: 646 | self.progress.zero_file_stats(file_name) 647 | ftp.put(file_name, dstpath, restart_marker) 648 | break 649 | else: 650 | with ssh_lib(**ssh_kwargs) as ssh: 651 | ssh.socket_open() 652 | ssh.transport_open() 653 | if not ssh.worker_thread_auth(): 654 | ssh.close() 655 | raise SSHException("authentication failed") 656 | with scp_lib( 657 | ssh._transport, progress=self.progress.report_progress 658 | ) as scpclient: 659 | if err_count: 660 | self.progress.zero_file_stats(file_name) 661 | scpclient.put(file_name, dstpath) 662 | break 663 | except Exception as err: 664 | err_count += 1 665 | logger.debug("".join(traceback.format_exception(*sys.exc_info()))) 666 | if not self.mute: 667 | if err_count < 3: 668 | self.progress.print_error( 669 | f"chunk {file_name} transfer failed due to " 670 | f"{err.__class__.__name__} {str(err)}, retrying" 671 | ) 672 | else: 673 | self.progress.print_error( 674 | f"chunk {file_name} transfer failed due to " 675 | f"{err.__class__.__name__} {str(err)}" 676 | ) 677 | time.sleep(err_count) 678 | 679 | if err_count == 3: 680 | self.mute = True 681 | raise TransferError 682 | 683 | 684 | class TransferError(Exception): 685 | """Custom exception to indicate problem with file transfer""" 686 | 687 | pass 688 | -------------------------------------------------------------------------------- /src/splitcopy/shared.py: -------------------------------------------------------------------------------- 1 | """ Copyright (c) 2018, Juniper Networks, Inc 2 | All rights reserved 3 | This SOFTWARE is licensed under the LICENSE provided in the 4 | ./LICENCE file. By downloading, installing, copying, or otherwise 5 | using the SOFTWARE, you agree to be bound by the terms of that 6 | LICENSE. 7 | """ 8 | 9 | # stdlib 10 | import concurrent.futures 11 | import datetime 12 | import getpass 13 | import logging 14 | import os 15 | import re 16 | import shutil 17 | import socket 18 | import sys 19 | import tempfile 20 | import traceback 21 | from contextlib import contextmanager 22 | from ftplib import error_perm, error_proto, error_reply, error_temp 23 | from math import ceil 24 | from socket import timeout as socket_timeout 25 | 26 | # 3rd party 27 | from paramiko.ssh_exception import SSHException 28 | 29 | # local modules 30 | from splitcopy.ftp import FTP 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | 35 | def pad_string(text): 36 | """pads a given string to the terminal width 37 | :param text: 38 | :type string: 39 | :return padded_string: 40 | :type string 41 | """ 42 | term_width = shutil.get_terminal_size()[0] 43 | padding = " " * (term_width - len(text)) 44 | padded_string = f"{text}{padding}" 45 | return padded_string 46 | 47 | 48 | class SplitCopyShared: 49 | """class containing functions used by both SplitCopyGet 50 | and SplitCopyPut classes 51 | """ 52 | 53 | def __init__(self, **kwargs): 54 | """Initialise the class""" 55 | self.user = kwargs.get("user") 56 | self.host = kwargs.get("host") 57 | self.passwd = kwargs.get("passwd") 58 | self.ssh_key = kwargs.get("ssh_key") 59 | self.ssh_port = kwargs.get("ssh_port") 60 | self.local_dir = kwargs.get("local_dir") 61 | self.copy_op = kwargs.get("copy_op") 62 | self.remote_dir = "" 63 | self.remote_file = "" 64 | self.command_list = [] 65 | self.rm_remote_tmp = False 66 | self.local_tmpdir = None 67 | self.remote_tmpdir = None 68 | self.sshshell = None 69 | self.use_shell = False 70 | self.hard_close = False 71 | 72 | def connect(self, ssh_lib, **ssh_kwargs): 73 | """open an ssh session to a remote host 74 | :param ssh_lib: 75 | :type class: 76 | :param ssh_kwargs: 77 | :type dict: 78 | :return self.sshshell: 79 | :type paramiko.SSHShell object: 80 | :returm ssh_kwargs: 81 | :type dict: 82 | """ 83 | logger.info("entering connect()") 84 | try: 85 | self.sshshell = ssh_lib(**ssh_kwargs) 86 | self.sshshell.socket_open() 87 | self.sshshell.transport_open() 88 | self.sshshell.set_transport_keepalive() 89 | if self.sshshell.main_thread_auth(): 90 | ssh_kwargs = self.sshshell.kwargs 91 | logger.debug(f"ssh_kwargs returned are: {ssh_kwargs}") 92 | else: 93 | raise SSHException("authentication failed") 94 | except Exception as err: 95 | logger.debug("".join(traceback.format_exception(*sys.exc_info()))) 96 | if self.sshshell is not None: 97 | self.sshshell.close() 98 | raise SystemExit( 99 | f"{err.__class__.__name__} returned while connecting via ssh: {str(err)}" 100 | ) 101 | return self.sshshell, ssh_kwargs 102 | 103 | def which_proto(self, copy_proto): 104 | """determines which protocol will be used for the transfer. 105 | If FTP is selected as protocol, verify that authentication works 106 | :param copy_proto: 107 | :type string: 108 | :return copy_proto: 109 | :type string: 110 | :return passwd: 111 | :type string: 112 | """ 113 | logger.info("entering which_proto()") 114 | passwd = self.sshshell.kwargs["password"] 115 | result = None 116 | if copy_proto == "ftp" and self.ftp_port_check(): 117 | if passwd is None: 118 | passwd = getpass.getpass( 119 | prompt=f"{self.user}'s password: ", stream=None 120 | ) 121 | try: 122 | result = self.ftp_login_check(passwd) 123 | except (error_reply, error_temp, error_perm, error_proto) as err: 124 | print( 125 | f"ftp login failed, switching to scp for transfer. Error was: {err}" 126 | ) 127 | except socket_timeout: 128 | print("ftp auth timed out, switching to scp for transfer") 129 | 130 | if not result: 131 | copy_proto = "scp" 132 | else: 133 | copy_proto = "scp" 134 | 135 | logger.info(f"copy_proto == {copy_proto}") 136 | return copy_proto, passwd 137 | 138 | def ftp_port_check(self, socket_lib=socket): 139 | """checks whether the ftp port is open 140 | :return result: 141 | :type bool: 142 | """ 143 | logger.info("entering ftp_port_check()") 144 | result = False 145 | print("attempting FTP authentication...") 146 | try: 147 | socket_lib.create_connection((self.host, 21), 10) 148 | logger.info("ftp port is open") 149 | result = True 150 | except socket_timeout: 151 | print("ftp socket timed out, switching to scp for transfer") 152 | except ConnectionRefusedError: 153 | print("ftp connection refused, switching to scp for transfer") 154 | 155 | return result 156 | 157 | def ftp_login_check(self, passwd, ftp_lib=FTP): 158 | """verifies ftp authentication on remote host 159 | :param passwd: 160 | :type string: 161 | :return result: 162 | :type bool: 163 | """ 164 | logger.info("entering ftp_login_check()") 165 | result = False 166 | kwargs = { 167 | "host": self.host, 168 | "user": self.user, 169 | "passwd": passwd, 170 | "timeout": 10, 171 | } 172 | with ftp_lib(**kwargs) as ftp: 173 | result = True 174 | return result 175 | 176 | def juniper_cli_check(self): 177 | """determines whether exec cmd is run on a juniper cli 178 | :returns bool: 179 | """ 180 | logger.info("entering juniper_cli_check()") 181 | result, stdout = self.ssh_cmd("uname") 182 | if result and stdout == "\nerror: unknown command: uname": 183 | # this is junos or evo CLI. exit code is always 0. 184 | self.use_shell = True 185 | elif result: 186 | pass 187 | else: 188 | err = "cmd 'uname' failed on remote host, it must be *nix based" 189 | self.close(err_str=err) 190 | return self.use_shell 191 | 192 | def which_os(self): 193 | """determines if host is JUNOS or EVO 194 | no support for remote Windows OS running OpenSSH 195 | :return junos: 196 | :type bool: 197 | :return evo: 198 | :type bool: 199 | :return bsd_version: 200 | :type float: 201 | :return sshd_version: 202 | :type float: 203 | """ 204 | logger.info("entering which_os()") 205 | junos = False 206 | evo = False 207 | bsd_version = float() 208 | sshd_version = float() 209 | result, stdout = self.ssh_cmd("uname") 210 | if not result: 211 | err = "cmd 'uname' failed on remote host, it must be *nix based" 212 | self.close(err_str=err) 213 | if self.use_shell: 214 | host_os = stdout.split("\n")[1].rstrip() 215 | else: 216 | host_os = stdout 217 | if host_os == "Linux" and self.evo_os(): 218 | evo = True 219 | elif host_os == "JUNOS": 220 | junos = True 221 | bsd_version = 6.3 222 | sshd_version = self.which_sshd() 223 | elif host_os == "FreeBSD" and self.junos_os(): 224 | junos = True 225 | bsd_version = self.which_bsd() 226 | sshd_version = self.which_sshd() 227 | logger.info( 228 | f"evo = {evo}, " 229 | f"junos = {junos}, " 230 | f"bsd_version = {bsd_version}, " 231 | f"sshd_version = {sshd_version}" 232 | ) 233 | return junos, evo, bsd_version, sshd_version 234 | 235 | def evo_os(self): 236 | """determines if host is running EVO 237 | :return result: 238 | :type bool: 239 | """ 240 | logger.info("entering evo_os()") 241 | result, stdout = self.ssh_cmd("test -e /usr/sbin/evo-pfemand") 242 | return result 243 | 244 | def junos_os(self): 245 | """determines if host is running JUNOS 246 | :return result: 247 | :type bool: 248 | """ 249 | logger.info("entering junos_os()") 250 | result, stdout = self.ssh_cmd("uname -i | egrep 'JUNIPER|JNPR'") 251 | return result 252 | 253 | def which_bsd(self): 254 | """determines the BSD version of JUNOS 255 | :return bsd_version: 256 | :type float: 257 | """ 258 | logger.info("entering which_bsd()") 259 | result, stdout = self.ssh_cmd("uname -r") 260 | if not result: 261 | self.close(err_str="failed to determine remote bsd version") 262 | if self.use_shell: 263 | uname = stdout.split("\n")[1] 264 | else: 265 | uname = stdout 266 | bsd_version = float(uname.split("-")[1]) 267 | return bsd_version 268 | 269 | def which_sshd(self): 270 | """determines the OpenSSH daemon version 271 | :return sshd_version: 272 | :type float 273 | """ 274 | logger.info("entering which_sshd()") 275 | result, stdout = self.ssh_cmd("sshd -v", exitcode=False, combine=True) 276 | if self.use_shell: 277 | if not re.search(r"OpenSSH_", stdout): 278 | self.close(err_str="failed to determine remote openssh version") 279 | output = stdout.split("\n")[2] 280 | else: 281 | if not re.search(r"OpenSSH_", stdout): 282 | self.close(err_str="failed to determine remote openssh version") 283 | output = stdout.split("\n")[1] 284 | version = re.sub(r"OpenSSH_", "", output) 285 | sshd_version = float(version[0:3]) 286 | return sshd_version 287 | 288 | def req_binaries(self, junos=False, evo=False): 289 | """ensures required binaries exist on remote host 290 | :param junos: 291 | :type bool: 292 | :param evo: 293 | :type bool: 294 | :returns None: 295 | """ 296 | logger.info("entering req_binaries()") 297 | if not junos and not evo: 298 | if self.copy_op == "get": 299 | req_bins = "dd ls df rm" 300 | else: 301 | req_bins = "cat ls df rm" 302 | result, stdout = self.ssh_cmd(f"which {req_bins}") 303 | if not result: 304 | self.close( 305 | err_str=( 306 | f"one or more required binaries [{req_bins}] is missing from remote host" 307 | ) 308 | ) 309 | 310 | def req_sha_binaries(self, sha_hash): 311 | """ensures required binaries for sha hash creation exist on remote host 312 | :param sha_hash: 313 | :type hash: 314 | :return sha_bin: 315 | :type string: 316 | :return sha_len: 317 | :type int: 318 | """ 319 | logger.info("entering req_sha_binaries()") 320 | sha_bins = [] 321 | sha_bin = "" 322 | sha_len = 0 323 | if sha_hash.get(512): 324 | bins = [("sha512sum", 512), ("sha512", 512), ("shasum", 512)] 325 | sha_bins.extend(bins) 326 | if sha_hash.get(384): 327 | bins = [("sha384sum", 384), ("sha384", 384), ("shasum", 384)] 328 | sha_bins.extend(bins) 329 | if sha_hash.get(256): 330 | bins = [("sha256sum", 256), ("sha256", 256), ("shasum", 256)] 331 | sha_bins.extend(bins) 332 | if sha_hash.get(224): 333 | bins = [("sha224sum", 224), ("sha224", 224), ("shasum", 224)] 334 | sha_bins.extend(bins) 335 | if sha_hash.get(1): 336 | bins = [("sha1sum", 1), ("sha1", 1), ("shasum", 1)] 337 | sha_bins.extend(bins) 338 | 339 | sha_bins = sorted(set(sha_bins), reverse=True, key=lambda x: (x[1], x[0])) 340 | logger.info(sha_bins) 341 | 342 | for req_bin in sha_bins: 343 | result, stdout = self.ssh_cmd(f"which {req_bin[0]}") 344 | if result: 345 | sha_bin = req_bin[0] 346 | sha_len = req_bin[1] 347 | break 348 | if not sha_bin: 349 | self.close( 350 | err_str=( 351 | "required binary used to generate a sha " 352 | "hash on the remote host isn't found" 353 | ) 354 | ) 355 | return sha_bin, sha_len 356 | 357 | def close(self, err_str=None, config_rollback=True): 358 | """called when we want to exit the script 359 | attempts to delete the remote temp directory and close the TCP session 360 | If hard_close == False, contextmanager will rm the local temp dir 361 | If not, we must delete it manually 362 | :param err_str: 363 | :type: string: 364 | :param config_rollback: 365 | :type bool: 366 | :raises SystemExit: terminates the script gracefully 367 | :raises os._exit: terminates the script immediately (even asyncio loop) 368 | """ 369 | logger.info("entering close()") 370 | if err_str: 371 | print(err_str) 372 | if ( 373 | self.use_shell 374 | and self.sshshell._chan is not None 375 | and not self.sshshell._chan.closed 376 | or not self.use_shell 377 | and self.sshshell._transport is not None 378 | and self.sshshell._transport.active 379 | ): 380 | if self.rm_remote_tmp: 381 | self.remote_cleanup() 382 | if config_rollback and self.command_list: 383 | self.limits_rollback() 384 | print(f"\r{pad_string('closing device connection')}") 385 | self.sshshell.close() 386 | if self.hard_close: 387 | try: 388 | shutil.rmtree(self.local_tmpdir) 389 | except PermissionError: 390 | # windows can throw this error, silence it for now 391 | print( 392 | f"{self.local_tmpdir} may still exist, please delete manually if so" 393 | ) 394 | raise os._exit(1) 395 | else: 396 | raise SystemExit(1) 397 | 398 | def file_split_size(self, file_size, sshd_version, bsd_version, evo, copy_proto): 399 | """determines the optimal chunk size. This depends on the python 400 | version, cpu count, the protocol used and the FreeBSD/OpenSSH versions 401 | :returns split_size: 402 | :type int: 403 | :returns executor: 404 | :type concurrent.futures object: 405 | """ 406 | logger.info("entering file_split_size()") 407 | 408 | cpu_count = 1 409 | try: 410 | cpu_count = os.cpu_count() 411 | except NotImplementedError: 412 | pass 413 | max_workers = min(32, cpu_count * 5) 414 | 415 | # each uid can have max of 64 processes 416 | # modulate worker count to consume no more than 40 pids 417 | if copy_proto == "ftp": 418 | # ftp creates 1 user process per chunk, no modulation required 419 | split_size = ceil(file_size / max_workers) 420 | elif max_workers == 5: 421 | # 1 cpu core, 5 workers will create <= 20 pids 422 | # no modulation required 423 | split_size = ceil(file_size / max_workers) 424 | else: 425 | # scp to FreeBSD 6 based junos creates 3 user processes per chunk 426 | # scp to FreeBSD 10+ based junos creates 2 user processes per chunk 427 | # +1 user process if openssh version is >= 7.4 428 | pid_count = 0 429 | max_pids = 40 430 | if sshd_version >= 7.4 and bsd_version == 6.3: 431 | pid_count = 4 432 | elif sshd_version >= 7.4 and bsd_version >= 10.0: 433 | pid_count = 3 434 | elif bsd_version == 6.3: 435 | pid_count = 3 436 | elif bsd_version >= 10.0: 437 | pid_count = 2 438 | elif evo: 439 | pid_count = 3 440 | 441 | if pid_count: 442 | max_workers = round(max_pids / pid_count) 443 | else: 444 | # sshd config defaults 445 | # Maxsessions = 10, MaxStartups = 10:30:100 446 | # value here should not hit these limits 447 | max_workers = 5 448 | 449 | split_size = ceil(file_size / max_workers) 450 | 451 | # concurrent.futures.ThreadPoolExecutor can be a limiting factor 452 | # if using python < 3.5.3 the default max_workers is 5. 453 | # see https://github.com/python/cpython/blob/v3.5.2/Lib/asyncio/base_events.py 454 | # hence defining a custom executor to normalize max_workers across versions 455 | executor = concurrent.futures.ThreadPoolExecutor( 456 | max_workers=max_workers, thread_name_prefix="ThreadPoolWorker" 457 | ) 458 | logger.info( 459 | f"max_workers = {max_workers}, cpu_count = {cpu_count}, split_size = {split_size}" 460 | ) 461 | return split_size, executor 462 | 463 | def mkdir_remote(self, remote_dir, remote_file): 464 | """creates a tmp directory on the remote host 465 | :param remote_dir: 466 | :type string: 467 | :param remote_file: 468 | :type string: 469 | :returns remote_tmpdir: 470 | :type string: 471 | """ 472 | logger.info("entering mkdir_remote()") 473 | time_stamp = datetime.datetime.strftime(datetime.datetime.now(), "%y%m%d%H%M%S") 474 | self.remote_tmpdir = f"{remote_dir}/splitcopy_{remote_file}.{time_stamp}" 475 | result, stdout = self.ssh_cmd(f"mkdir -p {self.remote_tmpdir}") 476 | if not result: 477 | err = f"unable to create the tmp directory {self.remote_tmpdir} on remote host" 478 | self.close(err_str=err) 479 | self.rm_remote_tmp = True 480 | return self.remote_tmpdir 481 | 482 | def storage_check_remote(self, file_size, split_size, remote_dir): 483 | """checks whether there is enough storage space on remote node 484 | :param file_size: 485 | :type int: 486 | :param split_size: 487 | :type int: 488 | :param remote_dir: 489 | :type string: 490 | :returns None: 491 | """ 492 | logger.info("entering storage_check_remote()") 493 | avail_blocks = 0 494 | print("checking remote storage...") 495 | result, stdout = self.ssh_cmd(f"df -k {remote_dir}") 496 | if not result: 497 | self.close(err_str="failed to determine remote disk space available") 498 | try: 499 | fs_blocks = re.search(r" ([0-9]+) +([0-9]+) +(-?[0-9]+) +", stdout) 500 | total_blocks = int(fs_blocks.group(1)) 501 | used_blocks = int(fs_blocks.group(2)) 502 | avail_blocks = int(fs_blocks.group(3)) 503 | except AttributeError: 504 | err = "unable to determine available storage on remote host" 505 | self.close(err_str=err) 506 | 507 | if avail_blocks < 0: 508 | reserved_blocks_percent = 100 - round( 509 | 100 / total_blocks * (used_blocks + avail_blocks) 510 | ) 511 | reserved_blocks_threshold = used_blocks + avail_blocks 512 | reserved_blocks_count = total_blocks - (used_blocks + avail_blocks) 513 | err = ( 514 | f"not enough available storage on remote host in {remote_dir}\n" 515 | f"{reserved_blocks_count} / {reserved_blocks_percent}% of 1024-byte blocks " 516 | "are reserved and may only be allocated by privileged processes\n" 517 | f"used blocks: {used_blocks} is > than the threshold for reserved blocks: {reserved_blocks_threshold}" 518 | ) 519 | self.close(err_str=err) 520 | 521 | avail_bytes = avail_blocks * 1024 522 | logger.info(f"remote filesystem available bytes is {avail_bytes}") 523 | if self.copy_op == "get": 524 | if file_size > avail_bytes: 525 | err = ( 526 | f"not enough storage on remote host in {remote_dir}\n" 527 | f"available bytes ({avail_bytes}) must be >= the original file size " 528 | f"({file_size}) because it has to store the file chunks" 529 | ) 530 | self.close(err_str=err) 531 | else: 532 | if file_size + split_size > avail_bytes: 533 | err = ( 534 | f"not enough storage on remote host in {remote_dir}\n" 535 | f"available bytes ({avail_bytes}) must be > " 536 | f"the original file size ({file_size}) + largest chunk size " 537 | f"({split_size})" 538 | ) 539 | self.close(err_str=err) 540 | 541 | def storage_check_local(self, file_size): 542 | """checks whether there is enough storage space on local node 543 | :param file_size: 544 | :type int: 545 | :return None: 546 | """ 547 | logger.info("entering storage_check_local()") 548 | print("checking local storage...") 549 | local_tmpdir = tempfile.gettempdir() 550 | avail_bytes = shutil.disk_usage(local_tmpdir)[2] 551 | logger.info(f"local filesystem {local_tmpdir} available bytes is {avail_bytes}") 552 | if file_size > avail_bytes: 553 | err = ( 554 | f"not enough storage on local host in temp dir {local_tmpdir}.\n" 555 | f"Available bytes ({avail_bytes}) must be > the original file size " 556 | f"({file_size}) because it has to store the file chunks" 557 | ) 558 | self.close(err_str=err) 559 | 560 | if self.copy_op == "get": 561 | avail_bytes = shutil.disk_usage(self.local_dir)[2] 562 | logger.info( 563 | f"local filesystem {self.local_dir} available bytes is {avail_bytes}" 564 | ) 565 | if file_size > avail_bytes: 566 | err = ( 567 | f"not enough storage on local host in {self.local_dir}.\n" 568 | f"Available bytes ({avail_bytes}) must be > the " 569 | f"original file size ({file_size}) because it has to " 570 | "recombine the file chunks into a whole file" 571 | ) 572 | self.close(err_str=err) 573 | 574 | @contextmanager 575 | def change_dir(self, cleanup=lambda: True): 576 | """cds into temp directory. 577 | Upon script exit, changes back to original directory 578 | and calls cleanup() to delete the temp directory 579 | :param cleanup: 580 | :type function: 581 | :returns None: 582 | """ 583 | prevdir = os.getcwd() 584 | os.chdir(os.path.expanduser(self.local_tmpdir)) 585 | try: 586 | yield 587 | finally: 588 | os.chdir(prevdir) 589 | cleanup() 590 | 591 | @contextmanager 592 | def tempdir(self): 593 | """creates a temp directory, defines how to delete directory upon script exit 594 | :returns None: 595 | """ 596 | self.local_tmpdir = tempfile.mkdtemp() 597 | logger.info(self.local_tmpdir) 598 | 599 | def cleanup(): 600 | """deletes temp dir""" 601 | shutil.rmtree(self.local_tmpdir) 602 | 603 | with self.change_dir(cleanup): 604 | yield self.local_tmpdir 605 | 606 | def return_tmpdir(self): 607 | """Function to return class variable 608 | :return self.local_tmpdir: 609 | :type string: 610 | """ 611 | return self.local_tmpdir 612 | 613 | def find_configured_limits(self, config_stanzas, limits): 614 | """Function that retrieves any configuration stazas that implement 615 | rate/connection limits. 616 | It is faster to perform grep on the router, than to transfer 617 | potentially huge amounts of text and do it locally 618 | :param config_stanzas: 619 | :type list: 620 | :param limits: 621 | :type list: 622 | :return cli_config: 623 | :type string: 624 | """ 625 | logger.info("entering find_configured_limits()") 626 | cli_config = "" 627 | limits_str = "|".join(limits) 628 | for stanza in config_stanzas: 629 | result, stdout = self.ssh_cmd( 630 | f"cli -c 'show configuration {stanza} | display set | " 631 | f'grep "{limits_str}" | no-more\'', 632 | ) 633 | cli_config += stdout 634 | return cli_config 635 | 636 | def limit_check(self, copy_proto): 637 | """Function that checks the remote junos/evo hosts configuration to 638 | determine whether there are any ftp or ssh connection/rate limits defined. 639 | If found, these configuration lines will be deactivated 640 | :param copy_proto: 641 | :type string: 642 | :return self.command_list: 643 | :type list: 644 | """ 645 | logger.info("entering limit_check()") 646 | config_stanzas = ["groups", "system services", "system login"] 647 | retry_options = "system login retry-options" 648 | limits = ["services ssh connection-limit", "services ssh rate-limit"] 649 | if copy_proto == "ftp": 650 | limits.append("services ftp connection-limit") 651 | limits.append("services ftp rate-limit") 652 | print("checking router configuration... ") 653 | cli_config = self.find_configured_limits(config_stanzas, limits) 654 | 655 | for limit in limits: 656 | re_limit_multiline = re.compile(rf"^set .*{limit} [0-9]", re.MULTILINE) 657 | conf_list_limits = re.findall(re_limit_multiline, cli_config) 658 | for conf_statement in conf_list_limits: 659 | conf_line = re.sub(" [0-9]+$", "", conf_statement) 660 | conf_line = re.sub(r"^set", "deactivate", conf_line) 661 | self.command_list.append(f"{conf_line};") 662 | re_retry_multiline = re.compile(rf"^set .*{retry_options}", re.MULTILINE) 663 | conf_list_retry_options = re.findall(re_retry_multiline, cli_config) 664 | for conf_statement in conf_list_retry_options: 665 | conf_line = re.match(rf"(set .*{retry_options})", conf_statement).group(1) 666 | conf_line = re.sub(r"^set", "deactivate", conf_line) 667 | self.command_list.append(f"{conf_line};") 668 | 669 | # if limits were configured, deactivate them 670 | if self.command_list: 671 | print("rate-limit/connection-limit/login retry-options configuration found") 672 | logger.info(self.command_list) 673 | result, stdout = self.ssh_cmd( 674 | f'cli -c "edit;{"".join(self.command_list)}commit and-quit"', 675 | exitcode=False, 676 | timeout=60, 677 | ) 678 | # cli always returns true so can't use exitcode 679 | if re.search(r"commit complete\r\nExiting configuration mode", stdout): 680 | print("configuration has been modified. deactivated the relevant lines") 681 | self.ssh_cmd( 682 | "logger 'splitcopy has made the following config changes: " 683 | f"{''.join(self.command_list)}'", 684 | exitcode=False, 685 | ) 686 | else: 687 | err = ( 688 | "Error: failed to deactivate connection-limit/rate-limit/login retry-options" 689 | f"configuration. output was:\n{stdout}" 690 | ) 691 | self.close(err_str=err) 692 | return self.command_list 693 | 694 | def limits_rollback(self): 695 | """Function to revert config changes made to remote host 696 | :returns None: 697 | """ 698 | logger.info("entering limits_rollback()") 699 | rollback_cmds = "".join(self.command_list) 700 | rollback_cmds = re.sub("deactivate", "activate", rollback_cmds) 701 | result, stdout = self.ssh_cmd( 702 | f'cli -c "edit;{rollback_cmds}commit and-quit"', 703 | exitcode=False, 704 | timeout=60, 705 | ) 706 | # cli always returns true so can't use exitcode 707 | if re.search(r"commit complete\r\nExiting configuration mode", stdout): 708 | print("configuration changes made have been reverted") 709 | self.ssh_cmd( 710 | "logger 'splitcopy has reverted config changes'", 711 | exitcode=False, 712 | ) 713 | else: 714 | print( 715 | "Error: failed to revert the configuration changes. " 716 | f"output was:\n{stdout}" 717 | ) 718 | 719 | def remote_cleanup(self, remote_dir=None, remote_file=None, silent=False): 720 | """Function that deletes the tmp directory on remote host 721 | :param remote_dir: 722 | :type string: 723 | :param remote_file: 724 | :type string: 725 | :param silent: determines whether we announce the dir deletion 726 | :type: bool 727 | :return None: 728 | """ 729 | logger.info("entering remote_cleanup()") 730 | result = False 731 | if remote_dir: 732 | self.remote_dir = remote_dir 733 | if remote_file: 734 | self.remote_file = remote_file 735 | if not silent: 736 | print(f"\r{pad_string('deleting remote tmp directory...')}") 737 | if self.remote_tmpdir is None: 738 | if self.copy_op == "get": 739 | result, stdout = self.ssh_cmd( 740 | f"rm -rf /var/tmp/splitcopy_{self.remote_file}.*" 741 | ) 742 | else: 743 | result, stdout = self.ssh_cmd( 744 | f"rm -rf {self.remote_dir}/splitcopy_{self.remote_file}.*" 745 | ) 746 | else: 747 | result, stdout = self.ssh_cmd(f"rm -rf {self.remote_tmpdir}") 748 | if not result and not silent: 749 | print( 750 | f"unable to delete the tmp directory {self.remote_tmpdir} on remote host, " 751 | "delete it manually" 752 | ) 753 | self.rm_remote_tmp = False 754 | return result 755 | 756 | def enter_shell(self): 757 | """in order to drop into shell from cli mode, a pty and interactive shell are required 758 | :return result: 759 | :type bool: 760 | """ 761 | try: 762 | # request a channel 763 | self.sshshell.channel_open() 764 | # request a pty and an interactive shell session 765 | self.sshshell.invoke_shell() 766 | # remove the welcome message from the socket 767 | self.sshshell.stdout_read(timeout=30) 768 | except SSHException as err: 769 | self.close(err_str=err) 770 | # enter shell mode 771 | result, stdout = self.ssh_cmd("start shell", exitcode=False) 772 | return result 773 | 774 | def ssh_cmd( 775 | self, cmd, timeout=30, exitcode=True, combine=False, retry=True, count=0 776 | ): 777 | """wrapper around functions that send a cmd to a remote host. 778 | which function gets called depends on whether an interactive shell is in use. 779 | if exitcode is True will check its exit status 780 | :param cmd: cmd to run on remote host 781 | :type string: 782 | :param timeout: amount of time before timeout is raised 783 | :type float: 784 | :param exitcode: toggles whether to check for exit status or not 785 | :type bool: 786 | :return result: whether successful or not 787 | :type bool: 788 | :return stdout: the output of the command 789 | :type string: 790 | """ 791 | result = False 792 | stdout = "" 793 | logger.debug(cmd) 794 | try: 795 | if self.use_shell: 796 | result, stdout = self.sshshell.shell_cmd(cmd, timeout, exitcode) 797 | else: 798 | result, stdout = self.sshshell.exec_cmd(cmd, timeout, combine) 799 | logger.debug(result) 800 | except TimeoutError: 801 | count += 1 802 | timeout = timeout * 2 803 | if self.use_shell: 804 | # channel is now unusable, close it and open a new channel 805 | self.sshshell.close_channel() 806 | self.enter_shell() 807 | if retry: 808 | if count == 3: 809 | self.close(err_str="cmd timed out") 810 | print( 811 | f"cmd '{cmd}' timed out, retrying with a timeout of {timeout} secs" 812 | ) 813 | result, stdout = self.ssh_cmd( 814 | cmd, 815 | timeout=timeout, 816 | exitcode=exitcode, 817 | combine=combine, 818 | count=count, 819 | ) 820 | except SSHException as err: 821 | self.close(err_str=f"ssh exception '{err}' raised while running '{cmd}'") 822 | except OSError as err: 823 | self.close(err_str=f"OSError exception raised while running '{cmd}'") 824 | return result, stdout 825 | -------------------------------------------------------------------------------- /src/splitcopy/splitcopy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ Copyright (c) 2018, Juniper Networks, Inc 3 | All rights reserved 4 | This SOFTWARE is licensed under the LICENSE provided in the 5 | ./LICENCE file. By downloading, installing, copying, or otherwise 6 | using the SOFTWARE, you agree to be bound by the terms of that 7 | LICENSE. 8 | """ 9 | 10 | # stdlib 11 | import argparse 12 | import datetime 13 | import getpass 14 | import logging 15 | import os 16 | import re 17 | import signal 18 | import socket 19 | 20 | from splitcopy.get import SplitCopyGet 21 | from splitcopy.put import SplitCopyPut 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def parse_args(): 27 | """parses arguments 28 | :return args: 29 | :type Namespace: 30 | """ 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument( 33 | "source", help="either or user@: or :" 34 | ) 35 | parser.add_argument( 36 | "target", help="either or user@: or :" 37 | ) 38 | parser.add_argument( 39 | "--pwd", nargs=1, help="password to authenticate on remote host" 40 | ) 41 | parser.add_argument( 42 | "--ssh_key", 43 | nargs=1, 44 | help="path to ssh private key (only if in non-default location)", 45 | ) 46 | parser.add_argument( 47 | "--scp", action="store_true", help="use scp to copy files instead of ftp" 48 | ) 49 | parser.add_argument( 50 | "--noverify", 51 | action="store_true", 52 | help="skip sha hash comparison of src and dst file", 53 | ) 54 | parser.add_argument( 55 | "--split_timeout", 56 | nargs=1, 57 | help="time to wait for remote file split operation to complete, default 120s", 58 | ) 59 | parser.add_argument( 60 | "--ssh_port", 61 | nargs=1, 62 | help="ssh port number to connect to", 63 | ) 64 | parser.add_argument( 65 | "--overwrite", 66 | action="store_true", 67 | help="if target file already exists, delete it prior to transfer", 68 | ) 69 | parser.add_argument("--nocurses", action="store_true", help="disable curses output") 70 | parser.add_argument("--log", nargs=1, help="log level, eg DEBUG") 71 | args = parser.parse_args() 72 | return args 73 | 74 | 75 | def windows_path(arg_name): 76 | """determine if argument is a windows/UNC path 77 | :param arg_name: 78 | :type string: 79 | :return bool: 80 | """ 81 | result = False 82 | if os.path.splitdrive(arg_name)[0]: 83 | # arg is a windows drive/UNC sharepoint 84 | # splitdrive() only returns a value at index 0 on windows systems 85 | logger.debug(f"'{arg_name}' is a windows path") 86 | result = True 87 | return result 88 | 89 | 90 | def parse_src_arg_as_local(source): 91 | """attempts to open path provided on local filesystem 92 | :param arg_name: 93 | :type string: 94 | :return local_file: 95 | :type string: 96 | :return local_dir: 97 | :type string: 98 | :return local_path: 99 | :type string: 100 | """ 101 | local_file = "" 102 | local_dir = "" 103 | local_path = os.path.abspath(os.path.expanduser(source)) 104 | with open(local_path, "rb"): 105 | local_file = os.path.basename(local_path) 106 | local_dir = os.path.dirname(local_path) 107 | return local_file, local_dir, local_path 108 | 109 | 110 | def parse_arg_as_remote(arg_name): 111 | """parses argument to determine if it's on a remote host. 112 | If successful, returns the user, host and remote path. 113 | :param arg_name: 114 | :type string: 115 | :return user: 116 | :type string: 117 | :return host: 118 | :type string: 119 | :return remote_path: 120 | :type string: 121 | """ 122 | user = "" 123 | host = "" 124 | remote_path = "" 125 | # greedy match for @ and : 126 | user_at_host = re.compile(r"(.+)@(.+):(.+)*") 127 | # greedy match for : 128 | host_only = re.compile(r"(.+):(.+)*") 129 | # not impossible that remote_path could contain ':' or '@' char 130 | # or username could contain '@' char 131 | # above is too simplistic to deal with these edge use cases 132 | if re.match(user_at_host, arg_name): 133 | # usernames ought to consist of [a-z_][a-z0-9_-]*[$]? 134 | # according to useradd man page, but the use of chars such as '@' 135 | # are not enforced which affects the pattern match 136 | # hostname does thankfully enforce '@' and ':' as invalid 137 | regex = re.match(user_at_host, arg_name) 138 | user = regex.group(1) 139 | host = regex.group(2) 140 | remote_path = regex.group(3) or "" 141 | elif re.match(host_only, arg_name): 142 | user = getpass.getuser() 143 | regex = re.match(host_only, arg_name) 144 | host = regex.group(1) 145 | remote_path = regex.group(2) or "" 146 | else: 147 | raise ValueError( 148 | f"'{arg_name}' is not in the correct format " 149 | "@: or :" 150 | ) 151 | return user, host, remote_path 152 | 153 | 154 | def open_ssh_keyfile(path): 155 | """tests whether the provided ssh keyfile can be read 156 | :param path: 157 | :type string: 158 | :return bool: 159 | """ 160 | result = False 161 | ssh_key = os.path.abspath(os.path.expanduser(path)) 162 | with open(ssh_key, "r") as key: 163 | result = True 164 | return result 165 | 166 | 167 | def handlesigint(sigint, stack): 168 | """called upon sigINT, effectively suppresses KeyboardInterrupt 169 | :param sigint: 170 | :type int: 171 | :param stack: 172 | :type frame object: 173 | :raises SystemExit: 174 | :return None: 175 | """ 176 | raise SystemExit 177 | 178 | 179 | def process_args(source, target): 180 | """determines the copy operation to perform, paths, username and host 181 | :param source: 182 | :type string: 183 | :param target: 184 | :type string: 185 | :returns result: 186 | :type dict: 187 | :raises SystemExit: 188 | """ 189 | user = "" 190 | host = "" 191 | remote_path = "" 192 | local_dir = "" 193 | local_file = "" 194 | local_path = "" 195 | copy_op = "" 196 | source_in_remote_format = False 197 | target_in_remote_format = False 198 | 199 | try: 200 | local_file, local_dir, local_path = parse_src_arg_as_local(source) 201 | except FileNotFoundError: 202 | # expected if this is a remote path 203 | pass 204 | except PermissionError: 205 | raise SystemExit( 206 | f"'{source}' exists, but file cannot be read due to a permissions error" 207 | ) 208 | except IsADirectoryError: 209 | raise SystemExit(f"'{source}' is a directory, not a file") 210 | 211 | try: 212 | user, host, remote_path = parse_arg_as_remote(source) 213 | if not windows_path(source): 214 | source_in_remote_format = True 215 | except ValueError as err: 216 | pass 217 | 218 | try: 219 | user, host, remote_path = parse_arg_as_remote(target) 220 | if not windows_path(target): 221 | target_in_remote_format = True 222 | except ValueError as err: 223 | pass 224 | 225 | if source_in_remote_format and target_in_remote_format: 226 | raise SystemExit( 227 | f"both '{source}' and '{target}' are remote paths - " 228 | "one path must be local, the other remote" 229 | ) 230 | elif local_file and target_in_remote_format: 231 | copy_op = "put" 232 | elif local_file and not target_in_remote_format: 233 | raise SystemExit( 234 | f"file '{source}' found, remote path '{target}' is not in the correct format [user@]host:path" 235 | ) 236 | elif not local_file and target_in_remote_format: 237 | raise SystemExit(f"'{source}' file not found") 238 | elif not local_file and not source_in_remote_format: 239 | raise SystemExit(f"'{source}' file not found") 240 | elif not local_file and source_in_remote_format and not remote_path: 241 | raise SystemExit(f"'{source}' does not specify a filepath") 242 | elif not local_file and source_in_remote_format and not target_in_remote_format: 243 | copy_op = "get" 244 | 245 | try: 246 | host = socket.gethostbyname(host) 247 | except socket.gaierror as exc: 248 | raise SystemExit( 249 | f"Could not resolve hostname '{host}', resolution failed" 250 | ) from exc 251 | 252 | result = { 253 | "user": user, 254 | "host": host, 255 | "remote_path": remote_path, 256 | "local_dir": local_dir, 257 | "local_file": local_file, 258 | "local_path": local_path, 259 | "copy_op": copy_op, 260 | "target": target, 261 | } 262 | return result 263 | 264 | 265 | def main(get_class=SplitCopyGet, put_class=SplitCopyPut): 266 | """body of script 267 | :param get_class: 268 | :type class: 269 | :param put_class: 270 | :type class: 271 | :return bool: 272 | """ 273 | signal.signal(signal.SIGINT, handlesigint) 274 | start_time = datetime.datetime.now() 275 | 276 | args = parse_args() 277 | if not args.log: 278 | loglevel = "WARNING" 279 | else: 280 | loglevel = args.log[0] 281 | 282 | numeric_level = getattr(logging, loglevel.upper(), None) 283 | if not isinstance(numeric_level, int): 284 | raise ValueError(f"Invalid log level: {loglevel}") 285 | logging.basicConfig( 286 | format="%(asctime)s %(name)s %(lineno)s %(funcName)s %(levelname)s:%(message)s", 287 | level=numeric_level, 288 | ) 289 | 290 | passwd = "" 291 | ssh_key = "" 292 | ssh_port = 22 293 | copy_proto = "" 294 | noverify = args.noverify 295 | use_curses = True 296 | overwrite = args.overwrite 297 | 298 | if args.nocurses: 299 | use_curses = False 300 | 301 | if args.pwd is not None: 302 | passwd = args.pwd[0] 303 | 304 | if not args.scp: 305 | copy_proto = "ftp" 306 | else: 307 | copy_proto = "scp" 308 | 309 | if args.ssh_key is not None: 310 | ssh_key = args.ssh_key[0] 311 | try: 312 | open_ssh_keyfile(ssh_key) 313 | except FileNotFoundError as exc: 314 | raise SystemExit(f"'{ssh_key}' file does not exist") from exc 315 | except PermissionError as exc: 316 | raise SystemExit( 317 | f"'{ssh_key}' exists, but file cannot be read due to a permissions error" 318 | ) from exc 319 | except IsADirectoryError as exc: 320 | raise SystemExit(f"'{ssh_key}' is a directory, not a file") from exc 321 | 322 | if args.ssh_port is not None: 323 | try: 324 | ssh_port = int(args.ssh_port[0]) 325 | except ValueError as exc: 326 | raise SystemExit("ssh_port must be an integer") from exc 327 | 328 | split_timeout = 120 329 | if args.split_timeout is not None: 330 | try: 331 | split_timeout = int(args.split_timeout[0]) 332 | except ValueError as exc: 333 | raise SystemExit("split_timeout must be an integer") from exc 334 | 335 | kwargs = process_args(args.source, args.target) 336 | kwargs["passwd"] = passwd 337 | kwargs["ssh_key"] = ssh_key 338 | kwargs["ssh_port"] = ssh_port 339 | kwargs["copy_proto"] = copy_proto 340 | kwargs["noverify"] = noverify 341 | kwargs["split_timeout"] = split_timeout 342 | kwargs["use_curses"] = use_curses 343 | kwargs["overwrite"] = overwrite 344 | logger.info(kwargs) 345 | 346 | if kwargs["copy_op"] == "get": 347 | splitcopyget = get_class(**kwargs) 348 | loop_start, loop_end = splitcopyget.get() 349 | else: 350 | splitcopyput = put_class(**kwargs) 351 | loop_start, loop_end = splitcopyput.put() 352 | 353 | # and we are done... 354 | end_time = datetime.datetime.now() 355 | time_delta = end_time - start_time 356 | transfer_delta = loop_end - loop_start 357 | print(f"data transfer = {transfer_delta}\ntotal runtime = {time_delta}") 358 | return True 359 | -------------------------------------------------------------------------------- /src/splitcopy/tests/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.insert(0, "../src/splitcopy") 4 | -------------------------------------------------------------------------------- /src/splitcopy/tests/test_ftp.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from tempfile import NamedTemporaryFile 3 | 4 | from pytest import MonkeyPatch 5 | from splitcopy.ftp import FTP 6 | 7 | 8 | class MockLogger: 9 | def __init__(self): 10 | self.level = 10 11 | 12 | def getEffectiveLevel(self): 13 | return self.level 14 | 15 | def removeHandler(self, hdlr): 16 | pass 17 | 18 | def addHandler(self, hdlr): 19 | pass 20 | 21 | 22 | class mockFTP: 23 | def __init__(self, **kwargs): 24 | return None 25 | 26 | 27 | def mockgetlogger(name=None): 28 | return MockLogger() 29 | 30 | 31 | def init_ftp(file_size=None, progress=None, **kwargs): 32 | return FTP(file_size, progress, **kwargs) 33 | 34 | 35 | class TestFTP: 36 | def test_context_manager(self, monkeypatch: MonkeyPatch): 37 | def quit(self): 38 | pass 39 | 40 | monkeypatch.setattr(logging, "getLogger", mockgetlogger) 41 | monkeypatch.setattr("logging.Logger", MockLogger) 42 | monkeypatch.setattr("ftplib.FTP", mockFTP) 43 | monkeypatch.setattr(FTP, "quit", quit) 44 | ftp = init_ftp() 45 | with ftp as foo: 46 | result = True 47 | expected = True 48 | assert expected == result 49 | 50 | def test_put(self, monkeypatch: MonkeyPatch): 51 | def storbinary(cmd, fp, callback, rest): 52 | callback(b"foobar" * 10) 53 | 54 | class MockProgress: 55 | def __init__(self): 56 | pass 57 | 58 | def report_progress(self, file_name, file_size, sent): 59 | pass 60 | 61 | monkeypatch.setattr("ftplib.FTP", mockFTP) 62 | mockprog = MockProgress() 63 | ftp = init_ftp(file_size=100, progress=mockprog) 64 | monkeypatch.setattr(ftp, "storbinary", storbinary) 65 | remote_file = "/var/tmp/remote" 66 | restart_marker = 10 67 | with NamedTemporaryFile() as tmpfile: 68 | local_file = tmpfile.name 69 | result = ftp.put(local_file, remote_file, restart_marker) 70 | expected = None 71 | assert expected == result 72 | 73 | def test_get(self, monkeypatch: MonkeyPatch): 74 | def retrbinary(cmd, callback, rest): 75 | callback(b"foobar" * 10) 76 | 77 | class MockProgress: 78 | def __init__(self): 79 | pass 80 | 81 | def report_progress(self, file_name, file_size, sent): 82 | pass 83 | 84 | monkeypatch.setattr("ftplib.FTP", mockFTP) 85 | mockprog = MockProgress() 86 | ftp = init_ftp(file_size=100, progress=mockprog) 87 | monkeypatch.setattr(ftp, "retrbinary", retrbinary) 88 | remote_file = "/var/tmp/remote" 89 | restart_marker = 10 90 | with NamedTemporaryFile() as tmpfile: 91 | local_file = tmpfile.name 92 | result = ftp.get(remote_file, local_file, restart_marker) 93 | expected = None 94 | assert expected == result 95 | -------------------------------------------------------------------------------- /src/splitcopy/tests/test_progress.py: -------------------------------------------------------------------------------- 1 | from curses import error as curses_error 2 | from threading import Thread 3 | 4 | from pytest import MonkeyPatch 5 | from splitcopy.progress import ( 6 | Progress, 7 | abandon_curses, 8 | bytes_display, 9 | percent_val, 10 | prepare_curses, 11 | progress_bar, 12 | ) 13 | 14 | 15 | def get_chunk_data(): 16 | chunks = [ 17 | ["chunk0", 1024], 18 | ["chunk1", 1024], 19 | ["chunk2", 1024], 20 | ["chunk3", 1024], 21 | ["chunk4", 1024], 22 | ["chunk5", 1024], 23 | ["chunk6", 1024], 24 | ["chunk7", 1024], 25 | ["chunk8", 1024], 26 | ["chunk9", 1024], 27 | ["chunk10", 1024], 28 | ["chunk11", 1024], 29 | ["chunk12", 1024], 30 | ["chunk13", 1024], 31 | ["chunk14", 1024], 32 | ["chunk15", 1024], 33 | ["chunk16", 1024], 34 | ["chunk17", 1024], 35 | ["chunk18", 1024], 36 | ["chunk19", 1024], 37 | ["chunk20", 1024], 38 | ] 39 | total_file_size = 21504 40 | return total_file_size, chunks 41 | 42 | 43 | class Test_Progress: 44 | def test_percent_val(self): 45 | expected = 10 46 | result = percent_val(100, 10) 47 | assert result == expected 48 | 49 | def test_progress_bar(self): 50 | expected = "[#####" + " " * 45 + "]" 51 | result = progress_bar(10) 52 | assert result == expected 53 | 54 | def test_bytes_display_kb(self): 55 | expected = (0.48828125, "KB") 56 | result = bytes_display(500) 57 | assert result == expected 58 | 59 | def test_bytes_display_mb(self): 60 | expected = (4.76837158203125, "MB") 61 | result = bytes_display(5000000) 62 | assert result == expected 63 | 64 | def test_bytes_display_gb(self): 65 | expected = (4.656612873077393, "GB") 66 | result = bytes_display(5000000000) 67 | assert result == expected 68 | 69 | def test_prepare_curses(self, monkeypatch: MonkeyPatch): 70 | expected = True 71 | monkeypatch.setattr("curses.initscr", lambda: True) 72 | monkeypatch.setattr("curses.noecho", lambda: True) 73 | monkeypatch.setattr("curses.cbreak", lambda: True) 74 | result = prepare_curses() 75 | assert result == expected 76 | 77 | def test_abandon_curses_fail(self, monkeypatch: MonkeyPatch): 78 | def nocbreak(): 79 | raise AttributeError 80 | 81 | monkeypatch.setattr("curses.nocbreak", nocbreak) 82 | result = abandon_curses() 83 | expected = None 84 | assert result == expected 85 | 86 | def test_abandon_curses(self, monkeypatch: MonkeyPatch): 87 | expected = None 88 | monkeypatch.setattr("curses.nocbreak", lambda: True) 89 | monkeypatch.setattr("curses.echo", lambda: True) 90 | monkeypatch.setattr("curses.endwin", lambda: True) 91 | result = abandon_curses() 92 | assert result == expected 93 | 94 | def test_check_term_size_nocurses(self): 95 | expected = False 96 | progress = Progress() 97 | result = progress.check_term_size(False) 98 | assert result == expected 99 | 100 | def test_check_term_size_is_too_big(self): 101 | expected = False 102 | progress = Progress() 103 | total_file_size, chunks = get_chunk_data() 104 | progress.add_chunks(total_file_size, chunks) 105 | # shutil.get_terminal_size() will by default return 80,24 106 | # as self.chunks + 4 > 24 it should fail 107 | result = progress.check_term_size(True) 108 | assert result == expected 109 | 110 | def test_check_term_size_is_ok(self): 111 | expected = True 112 | # make len(chunks) == 20 113 | total_file_size, chunks = get_chunk_data() 114 | del chunks[0] 115 | progress = Progress() 116 | progress.add_chunks(total_file_size, chunks) 117 | # shutil.get_terminal_size() will by default return 80,24 118 | # as self.chunks + 4 <= 24 it should return True 119 | result = progress.check_term_size(True) 120 | assert result == expected 121 | 122 | def test_initiate_timer_thread(self): 123 | expected = None 124 | progress = Progress() 125 | total_file_size, chunks = get_chunk_data() 126 | progress.add_chunks(total_file_size, chunks) 127 | progress.stop_timer = True 128 | result = progress.initiate_timer_thread() 129 | assert result == expected 130 | 131 | def test_refresh_timer_stop(self): 132 | expected = True 133 | progress = Progress() 134 | total_file_size, chunks = get_chunk_data() 135 | progress.add_chunks(total_file_size, chunks) 136 | self.stop_timer = True 137 | timer = Thread( 138 | name="testing_refresh_timer", 139 | target=progress.refresh_timer, 140 | args=(1, lambda: self.stop_timer), 141 | ) 142 | timer.start() 143 | timer.join() # loop must have exited or this wouldn't sucessfully stop the thread 144 | result = timer._is_stopped 145 | assert result == expected 146 | 147 | def test_refresh_timer_nocurses(self, monkeypatch: MonkeyPatch): 148 | expected = True 149 | 150 | def rates_update(self): 151 | pass 152 | 153 | def totals_update(self): 154 | pass 155 | 156 | monkeypatch.setattr(Progress, "rates_update", rates_update) 157 | monkeypatch.setattr(Progress, "totals_update", totals_update) 158 | progress = Progress() 159 | total_file_size, chunks = get_chunk_data() 160 | progress.add_chunks(total_file_size, chunks) 161 | self.stop_timer = False 162 | timer = Thread( 163 | name="testing_refresh_timer", 164 | target=progress.refresh_timer, 165 | args=(1, lambda: self.stop_timer), 166 | ) 167 | timer.start() 168 | self.stop_timer = True 169 | timer.join() 170 | result = timer._is_stopped 171 | assert result == expected 172 | 173 | def test_refresh_timer_curses(self, monkeypatch: MonkeyPatch): 174 | expected = True 175 | 176 | def rates_update(self): 177 | pass 178 | 179 | def totals_update(self): 180 | pass 181 | 182 | def update_screen_contents(self): 183 | pass 184 | 185 | def print_error(self, error): 186 | pass 187 | 188 | monkeypatch.setattr(Progress, "rates_update", rates_update) 189 | monkeypatch.setattr(Progress, "totals_update", totals_update) 190 | monkeypatch.setattr(Progress, "update_screen_contents", update_screen_contents) 191 | monkeypatch.setattr(Progress, "print_error", print_error) 192 | progress = Progress() 193 | total_file_size, chunks = get_chunk_data() 194 | progress.add_chunks(total_file_size, chunks) 195 | progress.curses = True 196 | self.stop_timer = False 197 | timer = Thread( 198 | name="testing_refresh_timer", 199 | target=progress.refresh_timer, 200 | args=(1, lambda: self.stop_timer), 201 | ) 202 | timer.start() 203 | self.stop_timer = True 204 | timer.join() 205 | result = timer._is_stopped 206 | assert result == expected 207 | 208 | def test_stop_timer_thread_fail(self): 209 | class timer: 210 | def join(): 211 | raise AttributeError 212 | 213 | progress = Progress() 214 | total_file_size, chunks = get_chunk_data() 215 | progress.add_chunks(total_file_size, chunks) 216 | progress.timer = timer 217 | result = progress.stop_timer_thread() 218 | expected = None 219 | assert result == expected 220 | 221 | def test_stop_timer_thread(self): 222 | expected = True 223 | 224 | def loop_thread(self, stop): 225 | while True: 226 | if stop(): 227 | break 228 | 229 | progress = Progress() 230 | total_file_size, chunks = get_chunk_data() 231 | progress.add_chunks(total_file_size, chunks) 232 | progress.timer = Thread( 233 | name="testing_stop_timer_thread", 234 | target=loop_thread, 235 | args=(1, lambda: progress.stop_timer), 236 | ) 237 | progress.timer.start() 238 | progress.stop_timer_thread() 239 | result = progress.timer._is_stopped 240 | assert result == expected 241 | 242 | def test_report_progress_complete(self, monkeypatch: MonkeyPatch): 243 | expected = True 244 | 245 | def file_percentage_update(self, file_name, file_size, sent): 246 | pass 247 | 248 | file_name = "chunk20" 249 | file_size = 20 250 | sent = 20 251 | monkeypatch.setattr(Progress, "file_percentage_update", file_percentage_update) 252 | progress = Progress() 253 | total_file_size, chunks = get_chunk_data() 254 | progress.add_chunks(total_file_size, chunks) 255 | progress.report_progress(file_name, file_size, sent) 256 | result = progress.files[file_name]["complete"] 257 | assert result == expected 258 | 259 | def test_report_progress_incomplete(self, monkeypatch: MonkeyPatch): 260 | expected = False 261 | 262 | def file_percentage_update(self, file_name, file_size, sent): 263 | pass 264 | 265 | file_name = "chunk20" 266 | file_size = 20 267 | sent = 19 268 | monkeypatch.setattr(Progress, "file_percentage_update", file_percentage_update) 269 | progress = Progress() 270 | total_file_size, chunks = get_chunk_data() 271 | progress.add_chunks(total_file_size, chunks) 272 | progress.report_progress(file_name, file_size, sent) 273 | result = progress.files[file_name]["complete"] 274 | assert result == expected 275 | 276 | def test_disp_total_progress(self, capsys, monkeypatch: MonkeyPatch): 277 | expected = "\rfoo" 278 | 279 | def total_progress_str(self): 280 | return "foo" 281 | 282 | monkeypatch.setattr(Progress, "total_progress_str", total_progress_str) 283 | progress = Progress() 284 | total_file_size, chunks = get_chunk_data() 285 | progress.add_chunks(total_file_size, chunks) 286 | progress.disp_total_progress() 287 | captured = capsys.readouterr() 288 | result = captured.out 289 | assert result == expected 290 | 291 | def test_file_percentage_update(self, monkeypatch: MonkeyPatch): 292 | expected = 10 293 | file_name = "chunk20" 294 | file_size = 1024 295 | sent = 102 296 | 297 | def percent_val(file_size, sent): 298 | return 10 299 | 300 | monkeypatch.setattr("splitcopy.progress.percent_val", percent_val) 301 | progress = Progress() 302 | total_file_size, chunks = get_chunk_data() 303 | progress.add_chunks(total_file_size, chunks) 304 | progress.file_percentage_update(file_name, file_size, sent) 305 | result = progress.files[file_name]["percent_done"] 306 | assert result == expected 307 | 308 | def test_totals_update(self, monkeypatch: MonkeyPatch): 309 | expected = (1024, 1, 4) 310 | 311 | def percent_val(file_size, sent): 312 | return 4 313 | 314 | monkeypatch.setattr("splitcopy.progress.percent_val", percent_val) 315 | progress = Progress() 316 | total_file_size, chunks = get_chunk_data() 317 | progress.add_chunks(total_file_size, chunks) 318 | progress.files["chunk0"]["sent_bytes"] = 1024 319 | progress.files["chunk0"]["complete"] = 1 320 | progress.totals_update() 321 | sum_bytes_sent = progress.totals["sum_bytes_sent"] 322 | sum_completed = progress.totals["sum_completed"] 323 | percent_done = progress.totals["percent_done"] 324 | result = (sum_bytes_sent, sum_completed, percent_done) 325 | assert result == expected 326 | 327 | def test_total_progress_str(self, monkeypatch: MonkeyPatch): 328 | expected = "0% done 0.0KB/0.0KB 0.0KB/s (0/21 chunks completed)" 329 | 330 | def bytes_display(byte_val): 331 | return 0.0, "KB" 332 | 333 | progress = Progress() 334 | total_file_size, chunks = get_chunk_data() 335 | progress.add_chunks(total_file_size, chunks) 336 | monkeypatch.setattr("splitcopy.progress.bytes_display", bytes_display) 337 | result = progress.total_progress_str() 338 | assert result == expected 339 | 340 | def test_rates_update(self): 341 | expected = 1024 342 | progress = Progress() 343 | total_file_size, chunks = get_chunk_data() 344 | progress.add_chunks(total_file_size, chunks) 345 | progress.files["chunk0"]["sent_bytes"] = 1024 346 | progress.rates_update() 347 | result = progress.totals["sum_bytes_per_sec"] 348 | assert result == expected 349 | 350 | def test_zero_file_stats(self): 351 | expected = 0 352 | progress = Progress() 353 | total_file_size, chunks = get_chunk_data() 354 | progress.add_chunks(total_file_size, chunks) 355 | progress.files["chunk0"]["sent_bytes"] = 1024 356 | progress.zero_file_stats("chunk0") 357 | result = progress.files["chunk0"]["sent_bytes"] 358 | assert result == expected 359 | 360 | def test_update_screen_contents(self, capsys, monkeypatch: MonkeyPatch): 361 | expected = ( 362 | f"['chunk0 [{' ' * 50}] 0% 0.0KB 0.0KB/s', " 363 | f"'chunk1 [{' ' * 50}] 0% 0.0KB 0.0KB/s', " 364 | f"'chunk2 [{' ' * 50}] 0% 0.0KB 0.0KB/s', " 365 | "'', " 366 | "'0% done 0.0KB/0.0KB 0.0KB/s (0/3 chunks completed)', " 367 | "'', " 368 | "'', " 369 | "'']\n" 370 | ) 371 | 372 | def progress_bar(percent_done): 373 | return f"[{' ' * 50}]" 374 | 375 | def bytes_display(num_bytes): 376 | return 0.0, "KB" 377 | 378 | def pad_string(foo): 379 | return foo 380 | 381 | def total_progress_str(self): 382 | return "0% done 0.0KB/0.0KB 0.0KB/s (0/3 chunks completed)" 383 | 384 | def redraw_screen(self, txt_lines): 385 | print(txt_lines) 386 | 387 | monkeypatch.setattr("splitcopy.progress.progress_bar", progress_bar) 388 | monkeypatch.setattr("splitcopy.progress.bytes_display", bytes_display) 389 | monkeypatch.setattr("splitcopy.progress.pad_string", pad_string) 390 | monkeypatch.setattr(Progress, "total_progress_str", total_progress_str) 391 | monkeypatch.setattr(Progress, "redraw_screen", redraw_screen) 392 | progress = Progress() 393 | total_file_size, chunks = get_chunk_data() 394 | del chunks[3:21] 395 | progress.add_chunks(total_file_size, chunks) 396 | progress.update_screen_contents() 397 | captured = capsys.readouterr() 398 | result = captured.out 399 | assert result == expected 400 | 401 | def test_update_screen_contents_fail(self, monkeypatch: MonkeyPatch): 402 | def progress_bar(percent_done): 403 | return f"[{' ' * 50}]" 404 | 405 | def bytes_display(num_bytes): 406 | return 0.0, "KB" 407 | 408 | def pad_string(foo): 409 | return foo 410 | 411 | def total_progress_str(self): 412 | return "0% done 0.0KB/0.0KB 0.0KB/s (0/3 chunks completed)" 413 | 414 | def redraw_screen(self, txt_lines): 415 | raise curses_error 416 | 417 | monkeypatch.setattr("splitcopy.progress.progress_bar", progress_bar) 418 | monkeypatch.setattr("splitcopy.progress.bytes_display", bytes_display) 419 | monkeypatch.setattr("splitcopy.progress.pad_string", pad_string) 420 | monkeypatch.setattr(Progress, "total_progress_str", total_progress_str) 421 | monkeypatch.setattr(Progress, "redraw_screen", redraw_screen) 422 | progress = Progress() 423 | total_file_size, chunks = get_chunk_data() 424 | progress.add_chunks(total_file_size, chunks) 425 | progress.update_screen_contents() 426 | result = progress.curses 427 | expected = False 428 | assert result == expected 429 | 430 | def test_update_screen_contents_longfilename( 431 | self, capsys, monkeypatch: MonkeyPatch 432 | ): 433 | expected = ( 434 | f"['somelo..e0 [{' ' * 50}] 0% 0.0KB 0.0KB/s', " 435 | f"'somelo..e1 [{' ' * 50}] 0% 0.0KB 0.0KB/s', " 436 | f"'somelo..e2 [{' ' * 50}] 0% 0.0KB 0.0KB/s', " 437 | "'', " 438 | "'0% done 0.0KB/0.0KB 0.0KB/s (0/3 chunks completed)', " 439 | "'', " 440 | "'', " 441 | "'']\n" 442 | ) 443 | 444 | def progress_bar(percent_done): 445 | return f"[{' ' * 50}]" 446 | 447 | def bytes_display(num_bytes): 448 | return 0.0, "KB" 449 | 450 | def pad_string(foo): 451 | return foo 452 | 453 | def total_progress_str(self): 454 | return "0% done 0.0KB/0.0KB 0.0KB/s (0/3 chunks completed)" 455 | 456 | def redraw_screen(self, txt_lines): 457 | print(txt_lines) 458 | 459 | chunks = [ 460 | ["somelongname0", 1024], 461 | ["somelongname1", 1024], 462 | ["somelongname2", 1024], 463 | ] 464 | total_file_size = 3072 465 | monkeypatch.setattr("splitcopy.progress.progress_bar", progress_bar) 466 | monkeypatch.setattr("splitcopy.progress.bytes_display", bytes_display) 467 | monkeypatch.setattr("splitcopy.progress.pad_string", pad_string) 468 | monkeypatch.setattr(Progress, "total_progress_str", total_progress_str) 469 | monkeypatch.setattr(Progress, "redraw_screen", redraw_screen) 470 | progress = Progress() 471 | progress.add_chunks(total_file_size, chunks) 472 | progress.update_screen_contents() 473 | captured = capsys.readouterr() 474 | result = captured.out 475 | assert result == expected 476 | 477 | def test_print_error_nocurses(self, capsys, monkeypatch: MonkeyPatch): 478 | expected = "\nfoo\n" 479 | progress = Progress() 480 | progress.print_error("foo") 481 | captured = capsys.readouterr() 482 | result = captured.out 483 | assert result == expected 484 | 485 | def test_print_error_curses(self, monkeypatch: MonkeyPatch): 486 | expected = ["", "", "", "foo"] 487 | 488 | def pad_string(foo): 489 | return foo 490 | 491 | progress = Progress() 492 | progress.curses = True 493 | monkeypatch.setattr("splitcopy.progress.pad_string", pad_string) 494 | progress.print_error("foo") 495 | result = progress.error_list 496 | assert result == expected 497 | 498 | def test_start_progress(self, monkeypatch: MonkeyPatch): 499 | expected = True 500 | 501 | def prepare_curses(): 502 | return True 503 | 504 | def check_term_size(self, foo): 505 | return True 506 | 507 | def initiate_timer_thread(self): 508 | pass 509 | 510 | monkeypatch.setattr("splitcopy.progress.prepare_curses", prepare_curses) 511 | monkeypatch.setattr(Progress, "check_term_size", check_term_size) 512 | monkeypatch.setattr(Progress, "initiate_timer_thread", initiate_timer_thread) 513 | progress = Progress() 514 | total_file_size, chunks = get_chunk_data() 515 | progress.add_chunks(total_file_size, chunks) 516 | foo = True 517 | progress.start_progress(foo) 518 | result = progress.stdscr 519 | assert result == expected 520 | 521 | def test_stop_progress(self, monkeypatch: MonkeyPatch): 522 | expected = None 523 | 524 | def abandon_curses(): 525 | pass 526 | 527 | def stop_timer_thread(self): 528 | pass 529 | 530 | monkeypatch.setattr("splitcopy.progress.abandon_curses", abandon_curses) 531 | monkeypatch.setattr(Progress, "stop_timer_thread", stop_timer_thread) 532 | progress = Progress() 533 | total_file_size, chunks = get_chunk_data() 534 | progress.add_chunks(total_file_size, chunks) 535 | result = progress.stop_progress() 536 | assert result == expected 537 | 538 | def test_redraw_screen(self, capsys): 539 | expected = "[[0, 0, 'foo'], [1, 0, 'bar']]\n" 540 | 541 | class MockCurses: 542 | def __init__(self): 543 | self.lines = [] 544 | 545 | def addstr(self, y, x, str): 546 | self.lines.append([y, x, str]) 547 | 548 | def refresh(self): 549 | print(self.lines) 550 | 551 | progress = Progress() 552 | progress.stdscr = MockCurses() 553 | progress.redraw_screen(["foo", "bar"]) 554 | captured = capsys.readouterr() 555 | result = captured.out 556 | assert result == expected 557 | -------------------------------------------------------------------------------- /src/splitcopy/tests/test_splitcopy.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import re 3 | from argparse import Namespace 4 | from socket import gaierror 5 | 6 | import splitcopy.splitcopy as splitcopy 7 | from pytest import MonkeyPatch, raises 8 | 9 | 10 | class MockOpen: 11 | def __init__(self, file, perms, newline=None): 12 | self.data = ["abcdef0123456789"] 13 | 14 | def __enter__(self): 15 | return self 16 | 17 | def __exit__(self, *args): 18 | pass 19 | 20 | 21 | class MockSplitCopyGet: 22 | def __init__(self, **kwargs): 23 | pass 24 | 25 | def get(*args): 26 | return (datetime.datetime.now(), datetime.datetime.now()) 27 | 28 | 29 | class MockSplitCopyPut: 30 | def __init__(self, **kwargs): 31 | pass 32 | 33 | def put(*args): 34 | return (datetime.datetime.now(), datetime.datetime.now()) 35 | 36 | 37 | def test_parse_args(monkeypatch: MonkeyPatch): 38 | def parse_args(*args): 39 | return Namespace(source="/var/tmp/foo", target="192.168.64.7:/var/tmp/") 40 | 41 | monkeypatch.setattr("argparse.ArgumentParser.parse_args", parse_args) 42 | result = splitcopy.parse_args() 43 | assert result == Namespace(source="/var/tmp/foo", target="192.168.64.7:/var/tmp/") 44 | 45 | 46 | def test_windows_path(monkeypatch: MonkeyPatch): 47 | def splitdrive(*args): 48 | return ("C:", "\\windows\\system32") 49 | 50 | monkeypatch.setattr("os.path.splitdrive", splitdrive) 51 | result = splitcopy.windows_path("C:\\windows\\system32") 52 | assert result == True 53 | 54 | 55 | def test_windows_path_fail(monkeypatch: MonkeyPatch): 56 | def splitdrive(*args): 57 | return ("", "/var/tmp") 58 | 59 | monkeypatch.setattr("os.path.splitdrive", splitdrive) 60 | result = splitcopy.windows_path("/var/tmp") 61 | assert result == False 62 | 63 | 64 | def test_parse_src_arg_as_local_abspath(monkeypatch: MonkeyPatch): 65 | def expanduser(*args): 66 | return "/var/tmp/foo" 67 | 68 | def abspath(*args): 69 | return "/var/tmp/foo" 70 | 71 | monkeypatch.setattr("os.path.expanduser", expanduser) 72 | monkeypatch.setattr("os.path.abspath", abspath) 73 | monkeypatch.setattr("builtins.open", MockOpen) 74 | result = splitcopy.parse_src_arg_as_local("/var/tmp/foo") 75 | assert result == ("foo", "/var/tmp", "/var/tmp/foo") 76 | 77 | 78 | def test_parse_src_arg_as_local_tilda(monkeypatch: MonkeyPatch): 79 | def expanduser(*args): 80 | return "/homes/foobar/tmp/foo" 81 | 82 | def abspath(*args): 83 | return "/homes/foobar/tmp/foo" 84 | 85 | monkeypatch.setattr("os.path.expanduser", expanduser) 86 | monkeypatch.setattr("os.path.abspath", abspath) 87 | monkeypatch.setattr("builtins.open", MockOpen) 88 | result = splitcopy.parse_src_arg_as_local("~/tmp/foo") 89 | assert result == ("foo", "/homes/foobar/tmp", "/homes/foobar/tmp/foo") 90 | 91 | 92 | def test_parse_src_arg_as_local_fileonly(monkeypatch: MonkeyPatch): 93 | def expanduser(*args): 94 | return "foo" 95 | 96 | def abspath(*args): 97 | return "/homes/foobar/foo" 98 | 99 | monkeypatch.setattr("builtins.open", MockOpen) 100 | monkeypatch.setattr("os.path.expanduser", expanduser) 101 | monkeypatch.setattr("os.path.abspath", abspath) 102 | result = splitcopy.parse_src_arg_as_local("foo") 103 | assert result == ("foo", "/homes/foobar", "/homes/foobar/foo") 104 | 105 | 106 | def test_parse_src_arg_as_local_dotfile(monkeypatch: MonkeyPatch): 107 | def expanduser(*args): 108 | return "./foo" 109 | 110 | def abspath(*args): 111 | return "/homes/foobar/foo" 112 | 113 | monkeypatch.setattr("os.path.abspath", abspath) 114 | monkeypatch.setattr("os.path.expanduser", expanduser) 115 | monkeypatch.setattr("builtins.open", MockOpen) 116 | result = splitcopy.parse_src_arg_as_local("./foo") 117 | assert result == ("foo", "/homes/foobar", "/homes/foobar/foo") 118 | 119 | 120 | def test_parse_src_arg_as_local_permerror(monkeypatch: MonkeyPatch): 121 | class MockOpen2(MockOpen): 122 | def __enter__(self): 123 | raise PermissionError 124 | 125 | def expanduser(*args): 126 | return "/var/tmp/foo" 127 | 128 | def abspath(*args): 129 | return "/var/tmp/foo" 130 | 131 | monkeypatch.setattr("os.path.expanduser", expanduser) 132 | monkeypatch.setattr("os.path.abspath", abspath) 133 | monkeypatch.setattr("builtins.open", MockOpen2) 134 | with raises(PermissionError): 135 | splitcopy.parse_src_arg_as_local("/var/tmp/foo") 136 | 137 | 138 | def test_parse_src_arg_as_local_filenotfounderror(monkeypatch: MonkeyPatch): 139 | class MockOpen2(MockOpen): 140 | def __enter__(self): 141 | raise FileNotFoundError 142 | 143 | def expanduser(*args): 144 | return "/var/tmp/foo" 145 | 146 | def abspath(*args): 147 | return "/var/tmp/foo" 148 | 149 | monkeypatch.setattr("os.path.expanduser", expanduser) 150 | monkeypatch.setattr("os.path.abspath", abspath) 151 | monkeypatch.setattr("builtins.open", MockOpen2) 152 | with raises(FileNotFoundError): 153 | splitcopy.parse_src_arg_as_local("/var/tmp/foo") 154 | 155 | 156 | def test_parse_src_arg_as_local_isdirerror(monkeypatch: MonkeyPatch): 157 | class MockOpen2(MockOpen): 158 | def __enter__(self): 159 | raise IsADirectoryError 160 | 161 | def expanduser(*args): 162 | return "/var/tmp/foo" 163 | 164 | def abspath(*args): 165 | return "/var/tmp/foo" 166 | 167 | monkeypatch.setattr("os.path.expanduser", expanduser) 168 | monkeypatch.setattr("os.path.abspath", abspath) 169 | monkeypatch.setattr("builtins.open", MockOpen2) 170 | with raises(IsADirectoryError): 171 | splitcopy.parse_src_arg_as_local("/var/tmp") 172 | 173 | 174 | def test_parse_arg_as_remote_incorrect_format(monkeypatch: MonkeyPatch): 175 | with raises(ValueError): 176 | splitcopy.parse_arg_as_remote("someone@foobar") 177 | 178 | 179 | def test_parse_arg_as_remote_inc_username(monkeypatch: MonkeyPatch): 180 | result = splitcopy.parse_arg_as_remote("someone@foobar:/var/tmp/foo") 181 | assert result == ("someone", "foobar", "/var/tmp/foo") 182 | 183 | 184 | def test_parse_arg_as_remote_without_username(monkeypatch: MonkeyPatch): 185 | def getuser(): 186 | return "someone" 187 | 188 | monkeypatch.setattr("getpass.getuser", getuser) 189 | result = splitcopy.parse_arg_as_remote("foobar:/var/tmp/foo") 190 | assert result == ("someone", "foobar", "/var/tmp/foo") 191 | 192 | 193 | def test_parse_arg_as_remote_nodir(monkeypatch: MonkeyPatch): 194 | result = splitcopy.parse_arg_as_remote("someone@foobar:foo") 195 | assert result == ("someone", "foobar", "foo") 196 | 197 | 198 | def test_parse_arg_as_remote_dotdir(monkeypatch: MonkeyPatch): 199 | result = splitcopy.parse_arg_as_remote("someone@foobar:./foo") 200 | assert result == ("someone", "foobar", "./foo") 201 | 202 | 203 | def test_parse_arg_as_remote_tilda(monkeypatch: MonkeyPatch): 204 | result = splitcopy.parse_arg_as_remote("someone@foobar:~/foo") 205 | assert result == ("someone", "foobar", "~/foo") 206 | 207 | 208 | def test_parse_arg_as_remote_nofile(monkeypatch: MonkeyPatch): 209 | result = splitcopy.parse_arg_as_remote("someone@foobar:") 210 | assert result == ("someone", "foobar", "") 211 | 212 | 213 | def test_open_ssh_keyfile_filenotfounderror(monkeypatch: MonkeyPatch): 214 | class MockOpen2(MockOpen): 215 | def __enter__(self): 216 | raise FileNotFoundError 217 | 218 | def expanduser(*args): 219 | return "/var/tmp/sshkey" 220 | 221 | def abspath(*args): 222 | return "/var/tmp/sshkey" 223 | 224 | monkeypatch.setattr("os.path.abspath", abspath) 225 | monkeypatch.setattr("os.path.expanduser", expanduser) 226 | monkeypatch.setattr("builtins.open", MockOpen2) 227 | with raises(FileNotFoundError): 228 | splitcopy.open_ssh_keyfile("/var/tmp/sshkey") 229 | 230 | 231 | def test_open_ssh_keyfile_permerror(monkeypatch: MonkeyPatch): 232 | class MockOpen2(MockOpen): 233 | def __enter__(self): 234 | raise PermissionError 235 | 236 | def expanduser(*args): 237 | return "/var/tmp/sshkey" 238 | 239 | def abspath(*args): 240 | return "/var/tmp/sshkey" 241 | 242 | monkeypatch.setattr("os.path.abspath", abspath) 243 | monkeypatch.setattr("os.path.expanduser", expanduser) 244 | monkeypatch.setattr("builtins.open", MockOpen2) 245 | with raises(PermissionError): 246 | splitcopy.open_ssh_keyfile("/var/tmp/sshkey") 247 | 248 | 249 | def test_open_ssh_keyfile_isdirerror(monkeypatch: MonkeyPatch): 250 | class MockOpen2(MockOpen): 251 | def __enter__(self): 252 | raise IsADirectoryError 253 | 254 | def expanduser(*args): 255 | return "/var/tmp/sshkey" 256 | 257 | def abspath(*args): 258 | return "/var/tmp/sshkey" 259 | 260 | monkeypatch.setattr("os.path.abspath", abspath) 261 | monkeypatch.setattr("os.path.expanduser", expanduser) 262 | monkeypatch.setattr("builtins.open", MockOpen2) 263 | with raises(IsADirectoryError): 264 | splitcopy.open_ssh_keyfile("/var/tmp/sshkey") 265 | 266 | 267 | def test_open_ssh_keyfile(monkeypatch: MonkeyPatch): 268 | def expanduser(*args): 269 | return "/var/tmp/sshkey" 270 | 271 | def abspath(*args): 272 | return "/var/tmp/sshkey" 273 | 274 | monkeypatch.setattr("os.path.abspath", abspath) 275 | monkeypatch.setattr("os.path.expanduser", expanduser) 276 | monkeypatch.setattr("builtins.open", MockOpen) 277 | result = splitcopy.open_ssh_keyfile("/var/tmp/sshkey") 278 | assert result == True 279 | 280 | 281 | def test_process_args_src_permerror(monkeypatch: MonkeyPatch): 282 | def parse_src_arg_as_local(*args): 283 | raise PermissionError 284 | 285 | monkeypatch.setattr( 286 | "splitcopy.splitcopy.parse_src_arg_as_local", parse_src_arg_as_local 287 | ) 288 | source = "/var/tmp/foo" 289 | target = "192.168.64.7:/var/tmp/" 290 | with raises( 291 | SystemExit, 292 | match=( 293 | f"'{source}' exists, but file cannot be read due to a permissions error" 294 | ), 295 | ): 296 | splitcopy.process_args(source, target) 297 | 298 | 299 | def test_process_args_src_isadirerror(monkeypatch: MonkeyPatch): 300 | def parse_src_arg_as_local(*args): 301 | raise IsADirectoryError 302 | 303 | monkeypatch.setattr( 304 | "splitcopy.splitcopy.parse_src_arg_as_local", parse_src_arg_as_local 305 | ) 306 | source = "/var/tmp" 307 | target = "192.168.64.7:/var/tmp/" 308 | with raises(SystemExit, match=f"'{source}' is a directory, not a file"): 309 | splitcopy.process_args(source, target) 310 | 311 | 312 | def test_process_args_both_args_remote(monkeypatch: MonkeyPatch): 313 | def parse_src_arg_as_local(*args): 314 | raise FileNotFoundError 315 | 316 | def parse_arg_as_remote(*args): 317 | return (None, "192.168.65.2", "/var/tmp/foo") 318 | 319 | monkeypatch.setattr( 320 | "splitcopy.splitcopy.parse_src_arg_as_local", parse_src_arg_as_local 321 | ) 322 | monkeypatch.setattr("splitcopy.splitcopy.parse_arg_as_remote", parse_arg_as_remote) 323 | source = "192.168.85.2:/var/tmp/foo" 324 | target = "192.168.25.2:/var/tmp/foo" 325 | with raises( 326 | SystemExit, 327 | match=( 328 | f"both '{source}' and '{target}' are remote paths - " 329 | "one path must be local, the other remote" 330 | ), 331 | ): 332 | splitcopy.process_args(source, target) 333 | 334 | 335 | def test_process_args_put(monkeypatch: MonkeyPatch): 336 | def parse_src_arg_as_local(*args): 337 | return "foo", "/var/tmp", "/var/tmp/foo" 338 | 339 | def parse_arg_as_remote(*args): 340 | if args[0] == "/var/tmp/foo": 341 | raise ValueError 342 | else: 343 | return (None, "192.168.65.2", "/var/tmp/foo") 344 | 345 | monkeypatch.setattr( 346 | "splitcopy.splitcopy.parse_src_arg_as_local", parse_src_arg_as_local 347 | ) 348 | monkeypatch.setattr("splitcopy.splitcopy.parse_arg_as_remote", parse_arg_as_remote) 349 | source = "/var/tmp/foo" 350 | target = "192.168.25.2:/var/tmp/foo" 351 | result = splitcopy.process_args(source, target) 352 | assert result == { 353 | "user": None, 354 | "host": "192.168.65.2", 355 | "remote_path": "/var/tmp/foo", 356 | "local_dir": "/var/tmp", 357 | "local_file": "foo", 358 | "local_path": "/var/tmp/foo", 359 | "copy_op": "put", 360 | "target": "192.168.25.2:/var/tmp/foo", 361 | } 362 | 363 | 364 | def test_process_args_bad_target_format(monkeypatch: MonkeyPatch): 365 | def parse_src_arg_as_local(*args): 366 | return "foo", "/var/tmp", "/var/tmp/foo" 367 | 368 | def parse_arg_as_remote(*args): 369 | raise ValueError 370 | 371 | monkeypatch.setattr("splitcopy.splitcopy.parse_arg_as_remote", parse_arg_as_remote) 372 | monkeypatch.setattr( 373 | "splitcopy.splitcopy.parse_src_arg_as_local", parse_src_arg_as_local 374 | ) 375 | source = "/var/tmp/foo" 376 | target = "foo@192.168.3.2" 377 | with raises( 378 | SystemExit, 379 | match=f"file '{source}' found, remote path '{target}' is not in the correct format \\[user@\\]host:path", 380 | ): 381 | splitcopy.process_args(source, target) 382 | 383 | 384 | def test_process_args_both_args_local(monkeypatch: MonkeyPatch): 385 | def parse_src_arg_as_local(*args): 386 | return "foo", "/var/tmp", "/var/tmp/foo" 387 | 388 | def parse_arg_as_remote(*args): 389 | raise ValueError 390 | 391 | monkeypatch.setattr( 392 | "splitcopy.splitcopy.parse_src_arg_as_local", parse_src_arg_as_local 393 | ) 394 | monkeypatch.setattr("splitcopy.splitcopy.parse_arg_as_remote", parse_arg_as_remote) 395 | source = "/var/tmp/foo" 396 | target = "/var/tmp/foo2" 397 | with raises( 398 | SystemExit, 399 | match=( 400 | f"file '{source}' found, remote path '{target}' is not in the correct format \\[user@\\]host:path" 401 | ), 402 | ): 403 | splitcopy.process_args(source, target) 404 | 405 | 406 | def test_process_args_no_local_file(monkeypatch: MonkeyPatch): 407 | def parse_src_arg_as_local(*args): 408 | raise FileNotFoundError 409 | 410 | def parse_arg_as_remote(*args): 411 | if args[0] == "foo@192.168.3.2:": 412 | return ("foo", "192.168.3.2", "") 413 | else: 414 | raise ValueError 415 | 416 | monkeypatch.setattr("splitcopy.splitcopy.parse_arg_as_remote", parse_arg_as_remote) 417 | monkeypatch.setattr( 418 | "splitcopy.splitcopy.parse_src_arg_as_local", parse_src_arg_as_local 419 | ) 420 | source = "/var/tmp/foo" 421 | target = "foo@192.168.3.2:" 422 | with raises( 423 | SystemExit, 424 | match=f"'{source}' file not found", 425 | ): 426 | splitcopy.process_args(source, target) 427 | 428 | 429 | def test_process_args_both_args_local_no_local(monkeypatch: MonkeyPatch): 430 | def parse_src_arg_as_local(*args): 431 | raise FileNotFoundError 432 | 433 | def parse_arg_as_remote(*args): 434 | raise ValueError 435 | 436 | monkeypatch.setattr( 437 | "splitcopy.splitcopy.parse_src_arg_as_local", parse_src_arg_as_local 438 | ) 439 | monkeypatch.setattr("splitcopy.splitcopy.parse_arg_as_remote", parse_arg_as_remote) 440 | source = "/var/tmp/foo" 441 | target = "/var/tmp/foo2" 442 | with raises( 443 | SystemExit, 444 | match=f"'{source}' file not found", 445 | ): 446 | splitcopy.process_args(source, target) 447 | 448 | 449 | def test_process_args_no_remote_filepath(monkeypatch: MonkeyPatch): 450 | def parse_src_arg_as_local(*args): 451 | raise FileNotFoundError 452 | 453 | def parse_arg_as_remote(*args): 454 | if args[0] == "foo@192.168.3.2:": 455 | return ("foo", "192.168.3.2", "") 456 | else: 457 | raise ValueError 458 | 459 | monkeypatch.setattr("splitcopy.splitcopy.parse_arg_as_remote", parse_arg_as_remote) 460 | monkeypatch.setattr( 461 | "splitcopy.splitcopy.parse_src_arg_as_local", parse_src_arg_as_local 462 | ) 463 | source = "foo@192.168.3.2:" 464 | target = "/var/tmp/foo" 465 | with raises(SystemExit, match=f"'{source}' does not specify a filepath"): 466 | splitcopy.process_args(source, target) 467 | 468 | 469 | def test_process_args_get(monkeypatch: MonkeyPatch): 470 | def parse_src_arg_as_local(*args): 471 | raise FileNotFoundError 472 | 473 | def parse_arg_as_remote(*args): 474 | if args[0] == "/var/tmp/foo": 475 | raise ValueError 476 | else: 477 | return (None, "192.168.25.2", "/var/tmp/foo") 478 | 479 | monkeypatch.setattr( 480 | "splitcopy.splitcopy.parse_src_arg_as_local", parse_src_arg_as_local 481 | ) 482 | monkeypatch.setattr("splitcopy.splitcopy.parse_arg_as_remote", parse_arg_as_remote) 483 | source = "192.168.25.2:/var/tmp/foo" 484 | target = "/var/tmp/foo" 485 | result = splitcopy.process_args(source, target) 486 | assert result == { 487 | "user": None, 488 | "host": "192.168.25.2", 489 | "remote_path": "/var/tmp/foo", 490 | "local_dir": "", 491 | "local_file": "", 492 | "local_path": "", 493 | "copy_op": "get", 494 | "target": "/var/tmp/foo", 495 | } 496 | 497 | 498 | def test_process_args_resolution_fail(monkeypatch: MonkeyPatch): 499 | def parse_src_arg_as_local(*args): 500 | return "foo", "/var/tmp", "/var/tmp/foo" 501 | 502 | def parse_arg_as_remote(*args): 503 | if args[0] == "foo@foo:/var/tmp/foobar": 504 | return ("foo", "foo", "/var/tmp/foobar") 505 | else: 506 | raise ValueError 507 | 508 | def gethostbyname(*args): 509 | raise gaierror 510 | 511 | monkeypatch.setattr( 512 | "splitcopy.splitcopy.parse_src_arg_as_local", parse_src_arg_as_local 513 | ) 514 | monkeypatch.setattr("splitcopy.splitcopy.parse_arg_as_remote", parse_arg_as_remote) 515 | monkeypatch.setattr("socket.gethostbyname", gethostbyname) 516 | source = "/var/tmp/foo" 517 | target = "foo@foo:/var/tmp/foobar" 518 | with raises( 519 | SystemExit, match=f"Could not resolve hostname 'foo', resolution failed" 520 | ): 521 | splitcopy.process_args(source, target) 522 | 523 | 524 | def test_main_get_scp_success(monkeypatch: MonkeyPatch): 525 | def parse_args(*args): 526 | return Namespace( 527 | source="192.168.64.7:/var/tmp/foobar", 528 | target="/var/tmp/bar", 529 | pwd="lab123", 530 | ssh_key=None, 531 | scp=True, 532 | noverify=False, 533 | split_timeout=None, 534 | ssh_port=None, 535 | overwrite=False, 536 | nocurses=True, 537 | log=None, 538 | ) 539 | 540 | def process_args(*args): 541 | return {"copy_op": "get"} 542 | 543 | monkeypatch.setattr("splitcopy.splitcopy.parse_args", parse_args) 544 | monkeypatch.setattr("splitcopy.splitcopy.process_args", process_args) 545 | result = splitcopy.main(MockSplitCopyGet, MockSplitCopyPut) 546 | assert result == True 547 | 548 | 549 | def test_main_put_ftp_success(monkeypatch: MonkeyPatch): 550 | def parse_args(*args): 551 | return Namespace( 552 | source="/var/tmp/foobar", 553 | target="192.168.64.7:/var/tmp/", 554 | pwd=None, 555 | ssh_key=None, 556 | scp=False, 557 | noverify=False, 558 | split_timeout=None, 559 | ssh_port=None, 560 | overwrite=False, 561 | nocurses=False, 562 | log=None, 563 | ) 564 | 565 | def process_args(*args): 566 | return {"copy_op": "put"} 567 | 568 | monkeypatch.setattr("splitcopy.splitcopy.parse_args", parse_args) 569 | monkeypatch.setattr("splitcopy.splitcopy.process_args", process_args) 570 | result = splitcopy.main(MockSplitCopyGet, MockSplitCopyPut) 571 | assert result == True 572 | 573 | 574 | def test_main_get_scp_loglevel(monkeypatch: MonkeyPatch): 575 | def parse_args(*args): 576 | return Namespace( 577 | source="192.168.64.7:/var/tmp/foobar", 578 | target="/var/tmp/bar", 579 | pwd=None, 580 | ssh_key=None, 581 | scp=True, 582 | noverify=False, 583 | split_timeout=None, 584 | ssh_port=None, 585 | overwrite=False, 586 | nocurses=False, 587 | log=["debug"], 588 | ) 589 | 590 | def process_args(*args): 591 | return {"copy_op": "get"} 592 | 593 | monkeypatch.setattr("splitcopy.splitcopy.parse_args", parse_args) 594 | monkeypatch.setattr("splitcopy.splitcopy.process_args", process_args) 595 | result = splitcopy.main(MockSplitCopyGet, MockSplitCopyPut) 596 | assert result == True 597 | 598 | 599 | def test_main_get_scp_bad_loglevel(monkeypatch: MonkeyPatch): 600 | def parse_args(*args): 601 | return Namespace( 602 | source="192.168.64.7:/var/tmp/foobar", 603 | target="/var/tmp/bar", 604 | pwd=None, 605 | ssh_key=None, 606 | scp=True, 607 | noverify=False, 608 | split_timeout=None, 609 | ssh_port=None, 610 | overwrite=False, 611 | nocurses=False, 612 | log=["123"], 613 | ) 614 | 615 | monkeypatch.setattr("splitcopy.splitcopy.parse_args", parse_args) 616 | with raises(ValueError, match=f"Invalid log level: 123"): 617 | splitcopy.main(MockSplitCopyGet, MockSplitCopyPut) 618 | 619 | 620 | def test_main_put_sshkey_filenotfounderror(monkeypatch: MonkeyPatch): 621 | def parse_args(*args): 622 | return Namespace( 623 | source="/var/tmp/foobar", 624 | target="somerandomhost:/var/tmp/", 625 | pwd=None, 626 | ssh_key=["/var/tmp/sshkey"], 627 | scp=True, 628 | noverify=False, 629 | split_timeout=None, 630 | ssh_port=None, 631 | overwrite=False, 632 | nocurses=False, 633 | log=None, 634 | ) 635 | 636 | def open_ssh_keyfile(*args): 637 | raise FileNotFoundError 638 | 639 | monkeypatch.setattr("splitcopy.splitcopy.parse_args", parse_args) 640 | monkeypatch.setattr("splitcopy.splitcopy.open_ssh_keyfile", open_ssh_keyfile) 641 | with raises(SystemExit, match="'/var/tmp/sshkey' file does not exist"): 642 | splitcopy.main(MockSplitCopyGet, MockSplitCopyPut) 643 | 644 | 645 | def test_main_put_sshkey_permerror(monkeypatch: MonkeyPatch): 646 | def parse_args(*args): 647 | return Namespace( 648 | source="/var/tmp/foobar", 649 | target="somerandomhost:/var/tmp/", 650 | pwd=None, 651 | ssh_key=["/var/tmp/sshkey"], 652 | scp=True, 653 | noverify=False, 654 | split_timeout=None, 655 | ssh_port=None, 656 | overwrite=False, 657 | nocurses=False, 658 | log=None, 659 | ) 660 | 661 | def open_ssh_keyfile(*args): 662 | raise PermissionError 663 | 664 | monkeypatch.setattr("splitcopy.splitcopy.parse_args", parse_args) 665 | monkeypatch.setattr("splitcopy.splitcopy.open_ssh_keyfile", open_ssh_keyfile) 666 | with raises( 667 | SystemExit, 668 | match="'/var/tmp/sshkey' exists, but file cannot be read due to a permissions error", 669 | ): 670 | splitcopy.main(MockSplitCopyGet, MockSplitCopyPut) 671 | 672 | 673 | def test_main_put_sshkey_isadirerror(monkeypatch: MonkeyPatch): 674 | def parse_args(*args): 675 | return Namespace( 676 | source="/var/tmp/foobar", 677 | target="somerandomhost:/var/tmp/", 678 | pwd=None, 679 | ssh_key=["/var/tmp/sshkey"], 680 | scp=True, 681 | noverify=False, 682 | split_timeout=None, 683 | ssh_port=None, 684 | overwrite=False, 685 | nocurses=False, 686 | log=None, 687 | ) 688 | 689 | def open_ssh_keyfile(*args): 690 | raise IsADirectoryError 691 | 692 | monkeypatch.setattr("splitcopy.splitcopy.parse_args", parse_args) 693 | monkeypatch.setattr("splitcopy.splitcopy.open_ssh_keyfile", open_ssh_keyfile) 694 | with raises(SystemExit, match="'/var/tmp/sshkey' is a directory, not a file"): 695 | splitcopy.main(MockSplitCopyGet, MockSplitCopyPut) 696 | 697 | 698 | def test_main_put_ftp_sshport_notint(capsys, monkeypatch: MonkeyPatch): 699 | def parse_args(*args): 700 | return Namespace( 701 | source="/var/tmp/foobar", 702 | target="192.168.64.7:/var/tmp/", 703 | pwd=None, 704 | ssh_key=None, 705 | scp=True, 706 | noverify=False, 707 | split_timeout=None, 708 | ssh_port=["foo"], 709 | overwrite=False, 710 | nocurses=False, 711 | log=None, 712 | ) 713 | 714 | monkeypatch.setattr("splitcopy.splitcopy.parse_args", parse_args) 715 | with raises(SystemExit, match="ssh_port must be an integer"): 716 | splitcopy.main(MockSplitCopyGet, MockSplitCopyPut) 717 | 718 | 719 | def test_main_put_ftp_split_timeout_notint(monkeypatch: MonkeyPatch): 720 | def parse_args(*args): 721 | return Namespace( 722 | source="/var/tmp/foobar", 723 | target="192.168.64.7:/var/tmp/", 724 | pwd=None, 725 | ssh_key=None, 726 | scp=True, 727 | noverify=False, 728 | split_timeout=["foo"], 729 | ssh_port=None, 730 | overwrite=False, 731 | nocurses=False, 732 | log=None, 733 | ) 734 | 735 | monkeypatch.setattr("splitcopy.splitcopy.parse_args", parse_args) 736 | with raises(SystemExit, match="split_timeout must be an integer"): 737 | splitcopy.main(MockSplitCopyGet, MockSplitCopyPut) 738 | 739 | 740 | def test_handlesigint(): 741 | with raises(SystemExit): 742 | splitcopy.handlesigint("SigInt", "stack") 743 | --------------------------------------------------------------------------------