├── .github └── workflows │ └── pylint.yml ├── .gitignore ├── .gitmodules ├── LICENSE ├── Makefile ├── README.md ├── generate_openssl_selfsigned_certificate.sh ├── py-ssh3 ├── __init__.py ├── asgi │ ├── htdocs │ │ ├── robots.txt │ │ └── style.css │ └── templates │ │ ├── index.html │ │ └── logs.html ├── auth │ ├── __init__.py │ └── openid_connect.py ├── client_cli.py ├── http3 │ ├── __init__.py │ ├── http3_client.py │ ├── http3_hijacker.py │ └── http3_server.py ├── linux_server │ ├── __init__.py │ ├── auth.py │ ├── authorized_identities.py │ └── handlers.py ├── message │ ├── __init__.py │ ├── channel_request.py │ ├── message.py │ └── message_type.py ├── server_cli.py ├── ssh3 │ ├── channel.py │ ├── conversation.py │ ├── identity.py │ ├── known_host.py │ ├── resources_manager.py │ ├── ssh3_client.py │ ├── ssh3_server.py │ └── version.py ├── test │ ├── __init__.py │ ├── integration_test │ │ ├── __init__.py │ │ └── ssh3_test.py │ └── unit_test │ │ └── __init__.py ├── util │ ├── __init__.py │ ├── globals.py │ ├── linux_util │ │ ├── __init__.py │ │ ├── agent.py │ │ ├── cmd.py │ │ └── linux_user.py │ ├── quic_util.py │ ├── type.py │ ├── util.py │ ├── waitgroup.py │ └── wire.py └── winsize │ ├── __init__.py │ ├── common.py │ ├── winsize.py │ └── winsize_windows.py ├── qlogs └── .gitkeep ├── qlogs_client └── .gitkeep └── setup.py /.github/workflows/pylint.yml: -------------------------------------------------------------------------------- 1 | name: Pylint 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: ["3.8", "3.9", "3.10"] 11 | steps: 12 | - uses: actions/checkout@v3 13 | - name: Set up Python ${{ matrix.python-version }} 14 | uses: actions/setup-python@v3 15 | with: 16 | python-version: ${{ matrix.python-version }} 17 | - name: Install dependencies 18 | run: | 19 | python -m pip install --upgrade pip 20 | pip install pylint 21 | - name: Analysing the code with pylint 22 | run: | 23 | pylint $(git 
ls-files '*.py') 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | *.pem 29 | *.key 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | ssh3_env/ 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged 
into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 164 | #.idea/ 165 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "py-ssh3/aioquic"] 2 | path = py-ssh3/aioquic 3 | url = https://github.com/aiortc/aioquic.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # TODO: pin the Python 3 version (minimum 3.6) 2 | env: 3 | test -d ssh3_env || python3 -m venv ssh3_env 4 | 5 | install: 6 | . ssh3_env/bin/activate && python3 -m pip install wheel && cd py-ssh3/aioquic/ && python3 -m pip install . && cd ../../ && python3 -m pip install . 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PySSH3 2 | 3 | Translation of the [SSH3](https://github.com/francoismichel/ssh3/tree/c39bb79cdce479f6095ab154a32a168e14d73b57) project (from commit `c39bb79cdce479f6095ab154a32a168e14d73b57`) into a Python 3 library. Check the original project for more information!
4 | 5 | ## Installation 6 | 7 | ### Python 3.6 (TODO) 8 | 9 | TODO 10 | 11 | ### Requirements 12 | 13 | ```bash 14 | make env; make install; 15 | ``` 16 | 17 | ## Usage 18 | 19 | ### PySSH3 server 20 | 21 | ```bash 22 | . ssh3_env/bin/activate && sudo -E env PATH=$PATH python3 py-ssh3/server_cli.py --help 23 | . ssh3_env/bin/activate && sudo -E env PATH=$PATH python3 py-ssh3/server_cli.py --generateSelfSignedCert --enablePasswordLogin --bind "127.0.0.1:4443" --urlPath "/my-secret-path" --verbose --insecure 24 | ``` 25 | 26 | #### Authorized keys and authorized identities 27 | TODO 28 | 29 | ### PySSH3 client 30 | ```bash 31 | . ssh3_env/bin/activate && python3 py-ssh3/client_cli.py --help 32 | . ssh3_env/bin/activate && python3 py-ssh3/client_cli.py --url "https://localhost:4443/my-secret-path?user=elniak" --verbose --usePassword --insecure 33 | . ssh3_env/bin/activate && python3 py-ssh3/client_cli.py --url "https://localhost:4443/my-secret-path?user=elniak" --verbose --privkey ~/.ssh/id_rsa --insecure 34 | ``` 35 | 36 | #### Private-key authentication 37 | TODO 38 | #### Agent-based private key authentication 39 | TODO 40 | #### Password authentication 41 | TODO 42 | #### Config-based session establishment 43 | TODO 44 | #### OpenID Connect authentication (TODO) 45 | TODO 46 | 47 | ## TODO 48 | - [ ] Add tests 49 | - [ ] Add documentation 50 | - [ ] Add examples 51 | - [ ] Add more features 52 | - [ ] Add threading support 53 | - [ ] Take more inspiration from [paramiko](https://www.paramiko.org/) 54 | - [ ] Secure version 55 | - [ ] request.url.scheme == "ssh3" -------------------------------------------------------------------------------- /generate_openssl_selfsigned_certificate.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | # Copy paste from original SSH3 repository 3 | # we strongly recommend using classical certificates through e.g.
Let's Encrypt 4 | # however, if you want a comparable security level to OpenSSH's host keys, 5 | # you can use this script to generate a self-signed certificate for every host 6 | # and every IP address to install on your server 7 | openssl req -x509 -sha256 -nodes -newkey rsa:4096 -keyout priv.key -days 3660 -out cert.pem -subj "/C=XX/O=Default Company/OU=XX/CN=selfsigned.ssh3" -addext "subjectAltName = DNS:selfsigned.ssh3,DNS:*" 8 | -------------------------------------------------------------------------------- /py-ssh3/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ElNiak/PySSH3/842429c479551afe8b7f6e219dd515015d53987f/py-ssh3/__init__.py -------------------------------------------------------------------------------- /py-ssh3/asgi/htdocs/robots.txt: -------------------------------------------------------------------------------- 1 | User-agent: * 2 | Disallow: /logs 3 | -------------------------------------------------------------------------------- /py-ssh3/asgi/htdocs/style.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: Arial, sans-serif; 3 | font-size: 16px; 4 | margin: 0 auto; 5 | width: 40em; 6 | } 7 | 8 | table.logs { 9 | width: 100%; 10 | } 11 | -------------------------------------------------------------------------------- /py-ssh3/asgi/templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | aioquic 6 | 7 | 8 | 9 |

Welcome to aioquic

10 |

11 | This is a test page for aioquic, 12 | a QUIC and HTTP/3 implementation written in Python. 13 |

14 | {% if request.scope["http_version"] == "3" %} 15 |

16 | Congratulations, you loaded this page using HTTP/3! 17 |

18 | {% endif %} 19 |

Available endpoints

20 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /py-ssh3/asgi/templates/logs.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | aioquic - logs 6 | 7 | 8 | 9 |

QLOG files

10 | 11 | 12 | 13 | 14 | 15 | 16 | {% for log in logs %} 17 | 18 | 22 | 23 | 24 | 25 | {% endfor %} 26 |
namedate (UTC)size
19 | {{ log.name }} 20 | [qvis] 21 | {{ log.date }}{{ log.size }}
27 | 28 | 29 | -------------------------------------------------------------------------------- /py-ssh3/auth/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ElNiak/PySSH3/842429c479551afe8b7f6e219dd515015d53987f/py-ssh3/auth/__init__.py -------------------------------------------------------------------------------- /py-ssh3/auth/openid_connect.py: -------------------------------------------------------------------------------- 1 | import os 2 | from urllib import request 3 | import webbrowser 4 | import http.server 5 | import socketserver 6 | import threading 7 | import base64 8 | import json 9 | import logging 10 | from authlib.integrations.requests_client import OAuth2Session 11 | from oauthlib.oauth2 import WebApplicationClient 12 | from cryptography.hazmat.primitives import hashes 13 | from cryptography.hazmat.backends import default_backend 14 | 15 | class OIDCConfig: 16 | def __init__(self, issuer_url, client_id, client_secret): 17 | self.issuer_url = issuer_url 18 | self.client_id = client_id 19 | self.client_secret = client_secret 20 | 21 | def connect(oidc_config: OIDCConfig, issuer_url: str, do_pkce: bool): 22 | client = WebApplicationClient(oidc_config.client_id) 23 | 24 | # Discover the provider 25 | # Note: Discovery endpoint can vary by provider 26 | discovery_url = f"{issuer_url}/.well-known/openid-configuration" 27 | oidc_provider_config = request.get(discovery_url).json() 28 | 29 | authorization_endpoint = oidc_provider_config["authorization_endpoint"] 30 | token_endpoint = oidc_provider_config["token_endpoint"] 31 | 32 | # Create a random secret URL 33 | random_secret = os.urandom(32) 34 | random_secret_url = base64.urlsafe_b64encode(random_secret).decode() 35 | 36 | # Start a local webserver to handle the OAuth2 callback 37 | # ... 
38 | 39 | # Open a browser window for authorization 40 | oauth_session = OAuth2Session(client_id=oidc_config.client_id, redirect_uri=f"http://localhost:{port}/{random_secret_url}") 41 | authorization_url, state = oauth_session.authorization_url(authorization_endpoint) 42 | 43 | webbrowser.open_new(authorization_url) 44 | # Wait for the callback to be handled 45 | # ... 46 | 47 | # Exchange the authorization code for an access token 48 | # ... 49 | 50 | return raw_id_token 51 | 52 | def oauth2_callback_handler(client_id: str, oauth_session: OAuth2Session, token_endpoint: str, token_channel): 53 | class Handler(http.server.BaseHTTPRequestHandler): 54 | def do_GET(self): 55 | if self.path.startswith(f"/{random_secret_url}"): 56 | # Extract the authorization code and exchange it for a token 57 | # ... 58 | 59 | # Send the token back through the channel 60 | # ... 61 | 62 | self.send_response(200) 63 | self.end_headers() 64 | self.wfile.write(b"You can now close this tab") 65 | 66 | return Handler 67 | 68 | def verify_raw_token(client_id: str, issuer_url: str, raw_id_token: str): 69 | # Discover the provider and create a verifier 70 | # ... 71 | 72 | # Verify the ID token 73 | # ... 
74 | 75 | return id_token 76 | -------------------------------------------------------------------------------- /py-ssh3/http3/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ElNiak/PySSH3/842429c479551afe8b7f6e219dd515015d53987f/py-ssh3/http3/__init__.py -------------------------------------------------------------------------------- /py-ssh3/http3/http3_hijacker.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from aioquic.asyncio import QuicConnectionProtocol 3 | from aioquic.asyncio.protocol import QuicStreamHandler 4 | from aioquic.quic.events import StreamDataReceived, StreamReset, ConnectionTerminated 5 | from aioquic.quic.connection import QuicConnection 6 | 7 | class HTTPStreamer: 8 | """ 9 | Allows taking over a HTTP/3 stream. This is a simplified version that 10 | assumes the stream is already established. 11 | """ 12 | def __init__(self, stream_reader, stream_writer): 13 | self.stream_reader = stream_reader 14 | self.stream_writer = stream_writer 15 | self.stream_id = stream_writer._transport.stream_id 16 | 17 | async def read(self, size): 18 | return await self.stream_reader.read(size) 19 | 20 | async def write(self, data): 21 | self.stream_writer.write(data) 22 | await self.stream_writer.drain() 23 | 24 | def close(self): 25 | self.stream_writer.close() 26 | 27 | class StreamCreator: 28 | """ 29 | This class represents an entity capable of creating QUIC streams. 
30 | """ 31 | def __init__(self, protocol: QuicConnectionProtocol): 32 | self.protocol = protocol 33 | 34 | async def open_stream(self) -> HTTPStreamer: 35 | reader, writer = await self.protocol.create_stream() 36 | return HTTPStreamer(reader, writer) 37 | 38 | async def open_uni_stream(self) -> HTTPStreamer: 39 | reader, writer = await self.protocol.create_stream(is_unidirectional=True) 40 | return HTTPStreamer(reader, writer) 41 | 42 | def local_addr(self): 43 | return self.protocol._quic._local_endpoint 44 | 45 | def remote_addr(self): 46 | return self.protocol._quic._peer_endpoint 47 | 48 | def connection_state(self): 49 | return self.protocol._quic._state 50 | 51 | class Hijacker: 52 | """ 53 | Allows hijacking of the stream creating part of a QuicConnectionProtocol. 54 | """ 55 | def __init__(self, protocol: QuicConnectionProtocol): 56 | self.protocol = protocol 57 | 58 | def stream_creator(self) -> StreamCreator: 59 | return StreamCreator(self.protocol) 60 | 61 | class Body: 62 | """ 63 | The body of a HTTP Request or Response. 64 | """ 65 | def __init__(self, stream: HTTPStreamer): 66 | self.stream = stream 67 | self.was_hijacked = False 68 | 69 | async def read(self, size) -> bytes: 70 | return await self.stream.read(size) 71 | 72 | def http_stream(self) -> HTTPStreamer: 73 | self.was_hijacked = True 74 | return self.stream 75 | 76 | async def close(self): 77 | self.stream.close() 78 | 79 | # Example usage 80 | # async def main(): 81 | # # Example usage - this will vary depending on how you establish your QUIC connection. 82 | # # Replace with actual connection and protocol setup. 83 | # protocol = QuicConnectionProtocol(QuicConnection(...)) 84 | # stream_creator = StreamCreator(protocol) 85 | # http_stream = await stream_creator.open_stream() 86 | # # Now you can read from or write to http_stream as needed. 
87 | 88 | # if __name__ == "__main__": 89 | # asyncio.run(main()) 90 | -------------------------------------------------------------------------------- /py-ssh3/http3/http3_server.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | import importlib 4 | import logging 5 | import time 6 | from collections import deque 7 | from email.utils import formatdate 8 | from typing import Callable, Deque, Dict, List, Optional, Union, cast 9 | 10 | import aioquic 11 | import wsproto 12 | import wsproto.events 13 | from aioquic.asyncio import QuicConnectionProtocol, serve 14 | from aioquic.h0.connection import H0_ALPN, H0Connection 15 | from aioquic.h3.connection import H3_ALPN, H3Connection 16 | from aioquic.h3.events import ( 17 | DatagramReceived, 18 | DataReceived, 19 | H3Event, 20 | HeadersReceived, 21 | WebTransportStreamDataReceived, 22 | ) 23 | from aioquic.h3.exceptions import NoAvailablePushIDError 24 | from aioquic.quic.configuration import QuicConfiguration 25 | from aioquic.quic.events import DatagramFrameReceived, ProtocolNegotiated, QuicEvent 26 | from aioquic.quic.logger import QuicFileLogger 27 | from aioquic.tls import SessionTicket 28 | import util.globals as glob 29 | from http3.http3_hijacker import *  # NOTE(review): wildcard import; prefer importing explicitly the hijacker names this module actually uses 30 | try: 31 | import uvloop 32 | except ImportError: 33 | uvloop = None 34 | 35 | log = logging.getLogger(__name__) 36 | 37 | from ssh3.version import get_current_version  # NOTE(review): import placed after module-level code (pylint wrong-import-position) 38 | 39 | 40 | AsgiApplication = Callable 41 | HttpConnection = Union[H0Connection, H3Connection] 42 | 43 | SERVER_NAME = get_current_version() 44 | 45 | 46 | class HttpRequestHandler:  # Bridges one HTTP request stream to an ASGI app: H3 events are queued for receive(), send() writes the response on the same stream 47 | def __init__( 48 | self, 49 | *, 50 | authority: bytes, 51 | connection: HttpConnection, 52 | protocol: QuicConnectionProtocol, 53 | scope: Dict, 54 | stream_ended: bool, 55 | stream_id: int, 56 | transmit: Callable[[], None], 57 | ) -> None: 58 | self.authority = authority 59 | self.connection = connection 60 | self.protocol =
protocol 61 | self.queue: asyncio.Queue[Dict] = asyncio.Queue() 62 | self.scope = scope 63 | self.stream_id = stream_id 64 | self.transmit = transmit 65 | 66 | if stream_ended: 67 | self.queue.put_nowait({"type": "http.request"}) 68 | 69 | def http_event_received(self, event: H3Event) -> None:  # translates DataReceived and stream-ending HeadersReceived into ASGI http.request messages 70 | log.debug("HTTP event received: %s", event) 71 | if isinstance(event, DataReceived): 72 | self.queue.put_nowait( 73 | { 74 | "type": "http.request", 75 | "body": event.data, 76 | "more_body": not event.stream_ended, 77 | } 78 | ) 79 | elif isinstance(event, HeadersReceived) and event.stream_ended: 80 | self.queue.put_nowait( 81 | {"type": "http.request", "body": b"", "more_body": False} 82 | ) 83 | 84 | async def run_asgi(self, app: AsgiApplication) -> None: 85 | log.debug("Running ASGI app with parameters: %s", self.scope) 86 | await app(self.scope, self.receive, self.send) 87 | 88 | async def receive(self) -> Dict: 89 | return await self.queue.get() 90 | 91 | async def send(self, message: Dict) -> None: 92 | log.debug(message) 93 | if message["type"] == "http.response.start": 94 | self.connection.send_headers( 95 | stream_id=self.stream_id, 96 | headers=[ 97 | (b":status", str(message["status"]).encode()), 98 | (b"server", SERVER_NAME.encode()), 99 | (b"date", formatdate(time.time(), usegmt=True).encode()), 100 | ] 101 | + [(k, v) for k, v in message["headers"]], 102 | ) 103 | elif message["type"] == "http.response.body": 104 | self.connection.send_data( 105 | stream_id=self.stream_id, 106 | data=message.get("body", b""), 107 | end_stream=not message.get("more_body", False), 108 | ) 109 | elif message["type"] == "http.response.push" and isinstance( 110 | self.connection, H3Connection 111 | ): 112 | request_headers = [ 113 | (b":method", b"GET"), 114 | (b":scheme", b"https"), 115 | (b":authority", self.authority), 116 | (b":path", message["path"].encode()), 117 | ] + [(k, v) for k, v in message["headers"]] 118 | 119 | # send push promise 120 | try: 121 | push_stream_id
= self.connection.send_push_promise( 122 | stream_id=self.stream_id, headers=request_headers 123 | ) 124 | except NoAvailablePushIDError: 125 | return 126 | 127 | # fake request 128 | cast(HttpServerProtocol, self.protocol).http_event_received(  # replays the promised request through the protocol as if the client had sent it 129 | HeadersReceived( 130 | headers=request_headers, stream_ended=True, stream_id=push_stream_id 131 | ) 132 | ) 133 | self.transmit() 134 | 135 | 136 | class WebSocketHandler:  # ASGI websocket bridge over an HTTP stream; DataReceived events that arrive before websocket.accept are buffered in http_event_queue 137 | def __init__( 138 | self, 139 | *, 140 | connection: HttpConnection, 141 | scope: Dict, 142 | stream_id: int, 143 | transmit: Callable[[], None], 144 | ) -> None: 145 | self.closed = False 146 | self.connection = connection 147 | self.http_event_queue: Deque[DataReceived] = deque() 148 | self.queue: asyncio.Queue[Dict] = asyncio.Queue() 149 | self.scope = scope 150 | self.stream_id = stream_id 151 | self.transmit = transmit 152 | self.websocket: Optional[wsproto.Connection] = None 153 | 154 | def http_event_received(self, event: H3Event) -> None: 155 | if isinstance(event, DataReceived) and not self.closed: 156 | if self.websocket is not None: 157 | self.websocket.receive_data(event.data) 158 | 159 | for ws_event in self.websocket.events(): 160 | self.websocket_event_received(ws_event) 161 | else: 162 | # delay event processing until we get `websocket.accept` 163 | # from the ASGI application 164 | self.http_event_queue.append(event) 165 | 166 | def websocket_event_received(self, event: wsproto.events.Event) -> None: 167 | if isinstance(event, wsproto.events.TextMessage): 168 | self.queue.put_nowait({"type": "websocket.receive", "text": event.data}) 169 | elif isinstance(event, wsproto.events.Message): 170 | self.queue.put_nowait({"type": "websocket.receive", "bytes": event.data}) 171 | elif isinstance(event, wsproto.events.CloseConnection): 172 | self.queue.put_nowait({"type": "websocket.disconnect", "code": event.code}) 173 | 174 | async def run_asgi(self, app: AsgiApplication) -> None: 175 | self.queue.put_nowait({"type":
"websocket.connect"}) 176 | 177 | try: 178 | await app(self.scope, self.receive, self.send) 179 | finally: 180 | if not self.closed: 181 | await self.send({"type": "websocket.close", "code": 1000}) 182 | 183 | async def receive(self) -> Dict: 184 | return await self.queue.get() 185 | 186 | async def send(self, message: Dict) -> None: 187 | data = b"" 188 | end_stream = False 189 | if message["type"] == "websocket.accept": 190 | subprotocol = message.get("subprotocol") 191 | 192 | self.websocket = wsproto.Connection(wsproto.ConnectionType.SERVER) 193 | 194 | headers = [ 195 | (b":status", b"200"), 196 | (b"server", SERVER_NAME.encode()), 197 | (b"date", formatdate(time.time(), usegmt=True).encode()), 198 | ] 199 | if subprotocol is not None: 200 | headers.append((b"sec-websocket-protocol", subprotocol.encode())) 201 | self.connection.send_headers(stream_id=self.stream_id, headers=headers) 202 | 203 | # consume backlog 204 | while self.http_event_queue: 205 | self.http_event_received(self.http_event_queue.popleft()) 206 | 207 | elif message["type"] == "websocket.close": 208 | if self.websocket is not None: 209 | data = self.websocket.send( 210 | wsproto.events.CloseConnection(code=message["code"]) 211 | ) 212 | else: 213 | self.connection.send_headers( 214 | stream_id=self.stream_id, headers=[(b":status", b"403")] 215 | ) 216 | end_stream = True 217 | elif message["type"] == "websocket.send":  # NOTE(review): self.websocket is still None if the app sends before websocket.accept, which would raise AttributeError 218 | if message.get("text") is not None: 219 | data = self.websocket.send( 220 | wsproto.events.TextMessage(data=message["text"]) 221 | ) 222 | elif message.get("bytes") is not None: 223 | data = self.websocket.send( 224 | wsproto.events.Message(data=message["bytes"]) 225 | ) 226 | 227 | if data:  # NOTE(review): on the 403 rejection path data stays empty, so send_data (and its end_stream) is skipped and the stream may never be finished; confirm 228 | self.connection.send_data( 229 | stream_id=self.stream_id, data=data, end_stream=end_stream 230 | ) 231 | if end_stream: 232 | self.closed = True 233 | self.transmit() 234 | 235 | 236 | class WebTransportHandler: 237 | def __init__( 238 | self, 239 | *, 240 | connection:
HttpConnection, 241 | scope: Dict, 242 | stream_id: int, 243 | transmit: Callable[[], None], 244 | ) -> None: 245 | self.accepted = False 246 | self.closed = False 247 | self.connection = connection 248 | self.http_event_queue: Deque[DataReceived] = deque() 249 | self.queue: asyncio.Queue[Dict] = asyncio.Queue() 250 | self.scope = scope 251 | self.stream_id = stream_id 252 | self.transmit = transmit 253 | 254 | def http_event_received(self, event: H3Event) -> None: 255 | if not self.closed: 256 | if self.accepted: 257 | if isinstance(event, DatagramReceived): 258 | self.queue.put_nowait( 259 | { 260 | "data": event.data, 261 | "type": "webtransport.datagram.receive", 262 | } 263 | ) 264 | elif isinstance(event, WebTransportStreamDataReceived): 265 | self.queue.put_nowait( 266 | { 267 | "data": event.data, 268 | "stream": event.stream_id, 269 | "type": "webtransport.stream.receive", 270 | } 271 | ) 272 | else: 273 | # delay event processing until we get `webtransport.accept` 274 | # from the ASGI application 275 | self.http_event_queue.append(event) 276 | 277 | async def run_asgi(self, app: AsgiApplication) -> None: 278 | self.queue.put_nowait({"type": "webtransport.connect"}) 279 | 280 | try: 281 | await app(self.scope, self.receive, self.send) 282 | finally: 283 | if not self.closed: 284 | await self.send({"type": "webtransport.close"}) 285 | 286 | async def receive(self) -> Dict: 287 | return await self.queue.get() 288 | 289 | async def send(self, message: Dict) -> None: 290 | data = b"" 291 | end_stream = False 292 | 293 | if message["type"] == "webtransport.accept": 294 | self.accepted = True 295 | 296 | headers = [ 297 | (b":status", b"200"), 298 | (b"server", SERVER_NAME.encode()), 299 | (b"date", formatdate(time.time(), usegmt=True).encode()), 300 | (b"sec-webtransport-http3-draft", b"draft02"), 301 | ] 302 | self.connection.send_headers(stream_id=self.stream_id, headers=headers) 303 | 304 | # consume backlog 305 | while self.http_event_queue: 306 | 
self.http_event_received(self.http_event_queue.popleft()) 307 | elif message["type"] == "webtransport.close": 308 | if not self.accepted: 309 | self.connection.send_headers( 310 | stream_id=self.stream_id, headers=[(b":status", b"403")] 311 | ) 312 | end_stream = True 313 | elif message["type"] == "webtransport.datagram.send": 314 | self.connection.send_datagram( 315 | stream_id=self.stream_id, data=message["data"] 316 | ) 317 | elif message["type"] == "webtransport.stream.send": 318 | self.connection._quic.send_stream_data( 319 | stream_id=message["stream"], data=message["data"] 320 | ) 321 | 322 | if data or end_stream: 323 | self.connection.send_data( 324 | stream_id=self.stream_id, data=data, end_stream=end_stream 325 | ) 326 | if end_stream: 327 | self.closed = True 328 | self.transmit() 329 | 330 | 331 | Handler = Union[HttpRequestHandler, WebSocketHandler, WebTransportHandler] 332 | 333 | 334 | class HttpServerProtocol(QuicConnectionProtocol): 335 | def __init__(self, *args, **kwargs) -> None: 336 | super().__init__(*args, **kwargs) 337 | self._handlers: Dict[int, Handler] = {} 338 | self._http: Optional[HttpConnection] = None 339 | 340 | 341 | def http_event_received(self, event: H3Event) -> None: 342 | log.debug("HTTP event received: %s", event) 343 | if isinstance(event, HeadersReceived) and event.stream_id not in self._handlers: 344 | authority = None 345 | headers = [] 346 | http_version = "0.9" if isinstance(self._http, H0Connection) else "3" 347 | raw_path = b"" 348 | method = "" 349 | protocol = None 350 | for header, value in event.headers: 351 | if header == b":authority": 352 | authority = value 353 | headers.append((b"host", value)) 354 | elif header == b":method": 355 | method = value.decode() 356 | elif header == b":path": 357 | raw_path = value 358 | elif header == b":protocol": 359 | protocol = value.decode() 360 | elif header and not header.startswith(b":"): 361 | headers.append((header, value)) 362 | 363 | if b"?" 
in raw_path: 364 | path_bytes, query_string = raw_path.split(b"?", maxsplit=1) 365 | else: 366 | path_bytes, query_string = raw_path, b"" 367 | path = path_bytes.decode() 368 | self._quic._logger.info("HTTP request %s %s %s", method, path, protocol) 369 | 370 | # FIXME: add a public API to retrieve peer address 371 | client_addr = self._http._quic._network_paths[0].addr 372 | client = (client_addr[0], client_addr[1]) 373 | 374 | handler: Handler 375 | scope: Dict 376 | if method == "CONNECT" and protocol == "websocket": 377 | subprotocols: List[str] = [] 378 | for header, value in event.headers: 379 | if header == b"sec-websocket-protocol": 380 | subprotocols = [x.strip() for x in value.decode().split(",")] 381 | scope = { 382 | "client": client, 383 | "headers": headers, 384 | "http_version": http_version, 385 | "method": method, 386 | "path": path, 387 | "query_string": query_string, 388 | "raw_path": raw_path, 389 | "root_path": "", 390 | "scheme": "wss", 391 | "subprotocols": subprotocols, 392 | "type": "websocket", 393 | } 394 | handler = WebSocketHandler( 395 | connection=self._http, 396 | scope=scope, 397 | stream_id=event.stream_id, 398 | transmit=self.transmit, 399 | ) 400 | elif method == "CONNECT" and protocol == "webtransport": 401 | scope = { 402 | "client": client, 403 | "headers": headers, 404 | "http_version": http_version, 405 | "method": method, 406 | "path": path, 407 | "query_string": query_string, 408 | "raw_path": raw_path, 409 | "root_path": "", 410 | "scheme": "https", 411 | "type": "webtransport", 412 | } 413 | handler = WebTransportHandler( 414 | connection=self._http, 415 | scope=scope, 416 | stream_id=event.stream_id, 417 | transmit=self.transmit, 418 | ) 419 | else: 420 | scheme = protocol 421 | extensions: Dict[str, Dict] = {} 422 | if isinstance(self._http, H3Connection): 423 | extensions["http.response.push"] = {} 424 | scope = { 425 | "client": client, 426 | "extensions": extensions, 427 | "headers": headers, 428 | "http_version": 
http_version, 429 | "method": method, 430 | "path": path, 431 | "query_string": query_string, 432 | "raw_path": raw_path, 433 | "root_path": "", 434 | "scheme": scheme, 435 | "type": "http", 436 | "stream_ended":event.stream_ended, 437 | "stream_id":event.stream_id, 438 | } 439 | log.debug("HTTP request 2: %s %s %s %s", method, path, protocol, scope) 440 | handler = HttpRequestHandler( 441 | authority=authority, 442 | connection=self._http, 443 | protocol=self, 444 | scope=scope, 445 | stream_ended=event.stream_ended, 446 | stream_id=event.stream_id, 447 | transmit=self.transmit, 448 | ) 449 | self._handlers[event.stream_id] = handler 450 | asyncio.ensure_future(handler.run_asgi(glob.APPLICATION)) 451 | elif ( 452 | isinstance(event, (DataReceived, HeadersReceived)) 453 | and event.stream_id in self._handlers 454 | ): 455 | handler = self._handlers[event.stream_id] 456 | handler.http_event_received(event) 457 | elif isinstance(event, DatagramReceived): 458 | handler = self._handlers[event.stream_id] 459 | handler.http_event_received(event) 460 | elif isinstance(event, WebTransportStreamDataReceived): 461 | handler = self._handlers[event.session_id] 462 | handler.http_event_received(event) 463 | 464 | def quic_event_received(self, event: QuicEvent) -> None: 465 | log.debug("QUIC event received: %s", event) 466 | if isinstance(event, ProtocolNegotiated): 467 | if event.alpn_protocol in H3_ALPN: 468 | log.debug("Negotiated HTTP/3") 469 | self._http = H3Connection(self._quic, enable_webtransport=True) 470 | elif event.alpn_protocol in H0_ALPN: 471 | self._http = H0Connection(self._quic) 472 | elif isinstance(event, DatagramFrameReceived): 473 | if event.data == b"quack": 474 | self._quic.send_datagram_frame(b"quack-ack") 475 | 476 | #  pass event to the HTTP layer 477 | if self._http is not None: 478 | for http_event in self._http.handle_event(event): 479 | self.http_event_received(http_event) 480 | 481 | 482 | class SessionTicketStore: 483 | """ 484 | Simple in-memory 
store for session tickets. 485 | """ 486 | 487 | def __init__(self) -> None: 488 | self.tickets: Dict[bytes, SessionTicket] = {} 489 | 490 | def add(self, ticket: SessionTicket) -> None: 491 | self.tickets[ticket.ticket] = ticket 492 | 493 | def pop(self, label: bytes) -> Optional[SessionTicket]: 494 | return self.tickets.pop(label, None) 495 | 496 | 497 | async def main( 498 | host: str, 499 | port: int, 500 | configuration: QuicConfiguration, 501 | session_ticket_store: SessionTicketStore, 502 | retry: bool, 503 | ) -> None: 504 | await serve( 505 | host, 506 | port, 507 | configuration=configuration, 508 | create_protocol=HttpServerProtocol, 509 | session_ticket_fetcher=session_ticket_store.pop, 510 | session_ticket_handler=session_ticket_store.add, 511 | retry=retry, 512 | ) 513 | await asyncio.Future() 514 | 515 | 516 | if __name__ == "__main__": 517 | defaults = QuicConfiguration(is_client=False) 518 | 519 | parser = argparse.ArgumentParser(description="QUIC server") 520 | parser.add_argument( 521 | "app", 522 | type=str, 523 | nargs="?", 524 | default="demo:app", 525 | help="the ASGI application as :", 526 | ) 527 | parser.add_argument( 528 | "-c", 529 | "--certificate", 530 | type=str, 531 | required=True, 532 | help="load the TLS certificate from the specified file", 533 | ) 534 | parser.add_argument( 535 | "--congestion-control-algorithm", 536 | type=str, 537 | default="reno", 538 | help="use the specified congestion control algorithm", 539 | ) 540 | parser.add_argument( 541 | "--host", 542 | type=str, 543 | default="::", 544 | help="listen on the specified address (defaults to ::)", 545 | ) 546 | parser.add_argument( 547 | "--port", 548 | type=int, 549 | default=4433, 550 | help="listen on the specified port (defaults to 4433)", 551 | ) 552 | parser.add_argument( 553 | "-k", 554 | "--private-key", 555 | type=str, 556 | help="load the TLS private key from the specified file", 557 | ) 558 | parser.add_argument( 559 | "-l", 560 | "--secrets-log", 561 | 
type=str, 562 | help="log secrets to a file, for use with Wireshark", 563 | ) 564 | parser.add_argument( 565 | "--max-datagram-size", 566 | type=int, 567 | default=defaults.max_datagram_size, 568 | help="maximum datagram size to send, excluding UDP or IP overhead", 569 | ) 570 | parser.add_argument( 571 | "-q", 572 | "--quic-log", 573 | type=str, 574 | help="log QUIC events to QLOG files in the specified directory", 575 | ) 576 | parser.add_argument( 577 | "--retry", 578 | action="store_true", 579 | help="send a retry for new connections", 580 | ) 581 | parser.add_argument( 582 | "-v", "--verbose", action="store_true", help="increase logging verbosity" 583 | ) 584 | args = parser.parse_args() 585 | 586 | logging.basicConfig( 587 | format="%(asctime)s %(levelname)s %(name)s %(message)s", 588 | level=logging.DEBUG if args.verbose else logging.INFO, 589 | ) 590 | 591 | # import ASGI application 592 | module_str, attr_str = args.app.split(":", maxsplit=1) 593 | module = importlib.import_module(module_str) 594 | application = getattr(module, attr_str) 595 | 596 | # create QUIC logger 597 | if args.quic_log: 598 | quic_logger = QuicFileLogger(args.quic_log) 599 | else: 600 | quic_logger = None 601 | 602 | # open SSL log file 603 | if args.secrets_log: 604 | secrets_log_file = open(args.secrets_log, "a") 605 | else: 606 | secrets_log_file = None 607 | 608 | configuration = QuicConfiguration( 609 | alpn_protocols=H3_ALPN + H0_ALPN + ["siduck"], 610 | congestion_control_algorithm=args.congestion_control_algorithm, 611 | is_client=False, 612 | max_datagram_frame_size=65536, 613 | max_datagram_size=args.max_datagram_size, 614 | quic_logger=quic_logger, 615 | secrets_log_file=secrets_log_file, 616 | ) 617 | 618 | # load SSL certificate and key 619 | configuration.load_cert_chain(args.certificate, args.private_key) 620 | 621 | if uvloop is not None: 622 | uvloop.install() 623 | 624 | try: 625 | asyncio.run( 626 | main( 627 | host=args.host, 628 | port=args.port, 629 | 
configuration=configuration, 630 | session_ticket_store=SessionTicketStore(), 631 | retry=args.retry, 632 | ) 633 | ) 634 | except KeyboardInterrupt: 635 | pass 636 | -------------------------------------------------------------------------------- /py-ssh3/linux_server/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ElNiak/PySSH3/842429c479551afe8b7f6e219dd515015d53987f/py-ssh3/linux_server/__init__.py -------------------------------------------------------------------------------- /py-ssh3/linux_server/auth.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import logging 3 | import util.globals as glob 4 | from typing import Callable 5 | from util.linux_util.linux_user import * 6 | from aioquic.asyncio.server import QuicServer 7 | from linux_server.handlers import * 8 | from ssh3.version import * 9 | from ssh3.conversation import * 10 | from http3.http3_server import HttpServerProtocol 11 | from aioquic.h3.connection import H3_ALPN, H3Connection 12 | from aioquic.h3.events import ( 13 | DatagramReceived, 14 | DataReceived, 15 | H3Event, 16 | HeadersReceived, 17 | WebTransportStreamDataReceived, 18 | ) 19 | from aioquic.quic.events import DatagramFrameReceived, ProtocolNegotiated, QuicEvent 20 | from aioquic.quic.connection import * 21 | from ssh3.version import * 22 | from starlette.responses import PlainTextResponse, Response 23 | from aioquic.tls import * 24 | from http3.http3_hijacker import * 25 | logging.basicConfig(level=logging.DEBUG) 26 | logger = logging.getLogger(__name__) 27 | 28 | SERVER_NAME = get_current_version() 29 | 30 | class AuthHttpServerProtocol(HttpServerProtocol): 31 | def __init__(self, *args, **kwargs) -> None: 32 | super().__init__(*args, **kwargs) 33 | self.hijacker = Hijacker(self) 34 | 35 | def http_event_received(self, event: H3Event) -> None: 36 | for header, value in event.headers: 37 | if header 
== b"user-agent": 38 | try: 39 | major, minor, patch = parse_version(value.decode()) 40 | except InvalidSSHVersion: 41 | logger.debug(f"Received invalid SSH version: {value}") 42 | return 43 | logger.debug(f"Received HTTP event: {event} with version {major}.{minor}.{patch}") 44 | super().http_event_received(event) 45 | 46 | def quic_event_received(self, event: QuicEvent) -> None: 47 | super().quic_event_received(event) 48 | 49 | 50 | # import ASGI application 51 | async def handle_auths( 52 | request 53 | ): 54 | """ 55 | Handle different types of authentication for a given HTTP request. 56 | enable_password_login: bool, 57 | default_max_packet_size: int, 58 | handler_func: callable, 59 | quic_server: QuicServer 60 | """ 61 | logger.info(f"Auth - Received request {request}") 62 | logger.info(f"Auth - Received request headers {request.headers}") 63 | # Set response server header 64 | content = "" 65 | status = 200 66 | header = { 67 | b"Server": SERVER_NAME 68 | } 69 | 70 | # Check SSH3 version 71 | user_agent = b"" 72 | for h,v in request["headers"]: 73 | if h == b"user-agent": 74 | try: 75 | user_agent = v.decode() 76 | except InvalidSSHVersion: 77 | logger.debug(f"Received invalid SSH version: {v}") 78 | status = 400 79 | return Response(content=b"Invalid version", 80 | headers=header, 81 | status_code=status) 82 | 83 | major, minor, patch = parse_version(user_agent) # Implement this function 84 | if major != MAJOR or minor != MINOR: 85 | return Response(content=b"Unsupported version", 86 | headers=header, 87 | status_code=status) 88 | 89 | # For the response 90 | protocols_keys = list(glob.QUIC_SERVER._protocols.keys()) 91 | prot = glob.QUIC_SERVER._protocols[protocols_keys[-1]] 92 | hijacker = prot.hijacker 93 | if not hijacker: 94 | logger.debug(f"failed to hijack") 95 | status = 400 96 | return Response(content=b"failed to hijack", 97 | headers=header, 98 | status_code=status) 99 | stream_creator = hijacker.stream_creator() 100 | tls_state = 
stream_creator.connection_state() 101 | logger.info(f"TLS state is {tls_state}") 102 | if tls_state != QuicConnectionState.CONNECTED: 103 | logger.debug(f"Too early connection") 104 | status = 400 105 | return Response(content=b"Too early connection", 106 | headers=header, 107 | status_code=status) 108 | 109 | # Create a new conversation 110 | # Implement NewServerConversation based on your protocol's specifics 111 | # From the request TODO 112 | stream = await stream_creator.open_stream() 113 | logger.info(f"Received stream {stream}") 114 | conv = await new_server_conversation( 115 | max_packet_size=glob.DEFAULT_MAX_PACKET_SIZE, 116 | queue_size=10, 117 | tls_state= tls_state, 118 | control_stream=stream, 119 | stream_creator=stream_creator, 120 | ) 121 | logger.info(f"Created new conversation {conv}") 122 | # Handle authentication 123 | authorization = b"" 124 | for h,v in request["headers"]: 125 | if h == b"authorization": 126 | try: 127 | authorization = v.decode() 128 | except Exception: 129 | logger.debug(f"Received invalid authorization version: {v}") 130 | status = 400 131 | return Response(content=b"Invalid authorization", 132 | headers=header, 133 | status_code=status) 134 | logger.info(f"Received authorization {authorization}") 135 | if glob.ENABLE_PASSWORD_LOGIN and authorization.startswith("Basic "): 136 | logger.info("Handling basic auth") 137 | return await handle_basic_auth(request=request, conv=conv) 138 | elif authorization.startswith("Bearer "): # TODO 139 | logger.info("Handling bearer auth") 140 | username = request.headers.get(b":path").decode().split("?", 1)[0].lstrip("/") 141 | conv_id = base64.b64encode(conv.id).decode() 142 | return await handle_bearer_auth(username, conv_id) 143 | else: 144 | logger.info("Handling no auth") 145 | header[b"www-authenticate"] = b"Basic" 146 | status = 401 147 | return Response(content=content, 148 | headers=header, 149 | status_code=status) 150 | 
-------------------------------------------------------------------------------- /py-ssh3/linux_server/authorized_identities.py: -------------------------------------------------------------------------------- 1 | import os 2 | import jwt 3 | import logging 4 | from cryptography.hazmat.primitives.serialization import load_ssh_public_key 5 | from cryptography.hazmat.primitives.asymmetric import rsa, ed25519 6 | from cryptography.hazmat.primitives.asymmetric import ec 7 | from cryptography.hazmat.backends import default_backend 8 | from util.type import * 9 | 10 | class Identity: 11 | def verify(self, candidate, base64_conversation_id): 12 | pass 13 | 14 | class PubKeyIdentity(Identity): 15 | def __init__(self, username, pubkey): 16 | self.username = username 17 | self.pubkey = pubkey 18 | 19 | def verify(self, candidate, base64_conversation_id): 20 | if isinstance(candidate, JWTTokenString): 21 | try: 22 | token = jwt.decode(candidate.token, self.pubkey, algorithms=["RS256", "EdDSA"], issuer=self.username, subject="ssh3", audience="unused") 23 | # Perform additional checks on claims here 24 | claims = token.get("claims", {}) 25 | if "exp" not in claims: 26 | return False 27 | if "client_id" not in claims or claims["client_id"] != f"ssh3-{self.username}": 28 | return False 29 | if "jti" not in claims or claims["jti"] != base64_conversation_id: 30 | logging.error("RSA verification failed: the jti claim does not contain the base64-encoded conversation ID") 31 | return False 32 | return True 33 | except Exception as e: 34 | logging.error(f"Invalid private key token: {str(e)}") 35 | return False 36 | else: 37 | return False 38 | 39 | 40 | class OpenIDConnectIdentity(Identity): 41 | def __init__(self, client_id, issuer_url, email): 42 | self.client_id = client_id 43 | self.issuer_url = issuer_url 44 | self.email = email 45 | 46 | def verify(self, candidate, base64_conversation_id): 47 | if isinstance(candidate, JWTTokenString): 48 | try: 49 | token = 
verify_raw_token(self.client_id, self.issuer_url, candidate.token) 50 | if token.issuer != self.issuer_url or not token.email_verified or token.email != self.email: 51 | return False 52 | return True 53 | except Exception as e: 54 | logging.error(f"Cannot verify raw token: {str(e)}") 55 | return False 56 | return False 57 | 58 | def default_identities_file_names(user): 59 | return [ 60 | os.path.join(user.dir, ".ssh3", "authorized_identities"), 61 | os.path.join(user.dir, ".ssh", "authorized_keys") 62 | ] 63 | 64 | def parse_identity(user, identity_str): 65 | try: 66 | pubkey = load_ssh_public_key(identity_str.encode(), backend=default_backend()) 67 | return PubKeyIdentity(user.username, pubkey) 68 | except Exception as e: 69 | if identity_str.startswith("oidc"): 70 | tokens = identity_str.split() 71 | if len(tokens) != 4: 72 | raise ValueError("Bad identity format for oidc identity") 73 | client_id, issuer_url, email = tokens[1:4] 74 | return OpenIDConnectIdentity(client_id, issuer_url, email) 75 | raise ValueError("Unknown identity format") 76 | 77 | def parse_authorized_identities_file(user, file_path): 78 | identities = [] 79 | with open(file_path, 'r') as file: 80 | for line_number, line in enumerate(file, 1): 81 | line = line.strip() 82 | if not line or line.startswith('#'): 83 | continue 84 | try: 85 | identity = parse_identity(user, line) 86 | identities.append(identity) 87 | except Exception as e: 88 | logging.error(f"Cannot parse identity line {line_number}: {str(e)}") 89 | return identities 90 | 91 | -------------------------------------------------------------------------------- /py-ssh3/linux_server/handlers.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Tuple, Callable 3 | import base64 4 | import logging 5 | from http3.http3_server import HttpRequestHandler 6 | from util.linux_util.linux_user import * 7 | from linux_server.authorized_identities import * 8 | import util.globals 
as glob 9 | from starlette.responses import PlainTextResponse, Response 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | def bearer_auth(headers: dict) -> Tuple[str, bool]: 14 | """ 15 | Extracts the bearer token from the Authorization header. 16 | """ 17 | auth = headers.get(":authorization", "") 18 | if not auth: 19 | return "", False 20 | return parse_bearer_auth(auth) 21 | 22 | def parse_bearer_auth(auth: str) -> Tuple[str, bool]: 23 | """ 24 | Parses an HTTP Bearer Authentication string. 25 | """ 26 | prefix = "Bearer " 27 | if not auth.lower().startswith(prefix.lower()): 28 | return "", False 29 | return auth[len(prefix):], True 30 | 31 | def handle_bearer_auth(username: str, base64_conv_id: str) -> Callable: 32 | """ 33 | HTTP handler function to handle Bearer authentication. 34 | """ 35 | async def inner_handler(request_handler: HttpRequestHandler): 36 | bearer_string, ok = bearer_auth(request_handler.scope["headers"]) 37 | if not ok: 38 | request_handler.send_unauthorized_response() 39 | return 40 | await glob.HANDLER_FUNC(bearer_string, base64_conv_id, request_handler) 41 | 42 | return inner_handler 43 | 44 | async def handle_jwt_auth(username: str, new_conv: object) -> Callable: 45 | """ 46 | Validates JWT token and calls the handler function if authentication is successful. 
47 | """ 48 | async def inner_handler(unauth_bearer_string: str, base64_conv_id: str, request_handler: HttpRequestHandler): 49 | user = get_user(username) # Replace with your user retrieval method 50 | if user is None: 51 | request_handler.send_unauthorized_response() 52 | return 53 | 54 | filenames = default_identities_file_names(user) 55 | identities = [] 56 | for filename in filenames: 57 | try: 58 | with open(filename, 'r') as identities_file: 59 | new_identities = parse_authorized_identities_file(user, identities_file) 60 | identities.extend(new_identities) 61 | except FileNotFoundError: 62 | pass # File not found, continue with the next file 63 | except Exception as e: 64 | logging.error(f"Error could not open {filename}: {e}") 65 | request_handler.send_unauthorized_response() 66 | return 67 | 68 | for identity in identities: 69 | if identity.verify(unauth_bearer_string, base64_conv_id): 70 | await glob.HANDLER_FUNC(username, new_conv, request_handler) 71 | return 72 | 73 | request_handler.send_unauthorized_response() 74 | 75 | return inner_handler 76 | 77 | 78 | async def handle_basic_auth(request, conv): 79 | # Extract Basic Auth credentials 80 | username, password, ok = extract_basic_auth(request) 81 | if not ok: 82 | logger.error(f"Invalid basic auth credentials extraction") 83 | status = 401 84 | return Response(status_code=status) 85 | 86 | # Replace this with your own authentication method 87 | ok = user_password_authentication(username, password) 88 | if not ok: 89 | logger.error(f"Invalid basic auth credentials") 90 | status = 401 91 | return Response(status_code=status) 92 | 93 | return await glob.HANDLER_FUNC(username, conv, request) 94 | 95 | def extract_basic_auth(request): 96 | auth_header = request.headers.get('authorization') 97 | logger.info(f"Received authorization header {auth_header}") 98 | if not auth_header: 99 | return None, None, False 100 | 101 | # Basic Auth Parsing 102 | try: 103 | auth_type, auth_info = auth_header.split(' ', 1) 
104 | if auth_type.lower() != 'basic': 105 | logger.error(f"Invalid auth type {auth_type}") 106 | return None, None, False 107 | 108 | username, password = base64.b64decode(auth_info).decode().split(':', 1) 109 | logger.info(f"Received username {username} and password {password}") 110 | return username, password, True 111 | except Exception as e: 112 | return None, None, False -------------------------------------------------------------------------------- /py-ssh3/message/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ElNiak/PySSH3/842429c479551afe8b7f6e219dd515015d53987f/py-ssh3/message/__init__.py -------------------------------------------------------------------------------- /py-ssh3/message/channel_request.py: -------------------------------------------------------------------------------- 1 | import struct 2 | import io 3 | import util.util as util 4 | import util.quic_util as quic_util 5 | import util.type as stype 6 | from typing import Tuple 7 | import ipaddress 8 | import logging 9 | logger = logging.getLogger(__name__) 10 | 11 | class PtyRequest: 12 | def __init__(self, term, char_width, char_height, pixel_width, pixel_height, encoded_terminal_modes): 13 | logger.debug("Creating PtyRequest object") 14 | self.term = term 15 | self.char_width = char_width 16 | self.char_height = char_height 17 | self.pixel_width = pixel_width 18 | self.pixel_height = pixel_height 19 | self.encoded_terminal_modes = encoded_terminal_modes 20 | 21 | def length(self): 22 | logger.debug("Calculating length of PtyRequest") 23 | return util.ssh_string_len(self.term) + \ 24 | quic_util.var_int_len(self.char_width) + \ 25 | quic_util.var_int_len(self.char_height) + \ 26 | quic_util.var_int_len(self.pixel_width) + \ 27 | quic_util.var_int_len(self.pixel_height) + \ 28 | util.ssh_string_len(self.encoded_terminal_modes) 29 | 30 | def write(self, buf): 31 | logger.debug("Writing PtyRequest to buffer") 32 
| if len(buf) < self.length(): 33 | raise ValueError("Buffer too small to write PTY request") 34 | 35 | consumed = 0 36 | n = util.write_ssh_string(buf, self.term) 37 | consumed += n 38 | 39 | for attr in [self.char_width, self.char_height, self.pixel_width, self.pixel_height]: 40 | buf[consumed:consumed+quic_util.var_int_len(attr)] = quic_util.var_int_to_bytes(attr) 41 | consumed += quic_util.var_int_len(attr) 42 | 43 | n = util.write_ssh_string(buf[consumed:], self.encoded_terminal_modes) 44 | consumed += n 45 | 46 | return consumed 47 | 48 | def request_type_str(self): 49 | logger.debug("Getting request type string for PtyRequest") 50 | return "pty-req" 51 | 52 | def parse_pty_request(buf): 53 | logger.debug("Parsing PtyRequest from buffer") 54 | term = util.parse_ssh_string(buf) 55 | char_width = quic_util.read_var_int(buf) 56 | char_height = quic_util.read_var_int(buf) 57 | pixel_width = quic_util.read_var_int(buf) 58 | pixel_height = quic_util.read_var_int(buf) 59 | encoded_terminal_modes = util.parse_ssh_string(buf) 60 | return PtyRequest(term, char_width, char_height, pixel_width, pixel_height, encoded_terminal_modes) 61 | 62 | class X11Request: 63 | def __init__(self, single_connection, x11_authentication_protocol, x11_authentication_cookie, x11_screen_number): 64 | logger.debug("Creating X11Request object") 65 | self.single_connection = single_connection 66 | self.x11_authentication_protocol = x11_authentication_protocol 67 | self.x11_authentication_cookie = x11_authentication_cookie 68 | self.x11_screen_number = x11_screen_number 69 | 70 | def length(self): 71 | logger.debug("Calculating length of X11Request") 72 | return 1 + \ 73 | util.ssh_string_len(self.x11_authentication_protocol) + \ 74 | util.ssh_string_len(self.x11_authentication_cookie) + \ 75 | quic_util.var_int_len(self.x11_screen_number) 76 | 77 | def write(self, buf): 78 | logger.debug("Writing X11Request to buffer") 79 | if len(buf) < self.length(): 80 | raise ValueError("Buffer too small 
to write X11 request") 81 | 82 | consumed = 0 83 | buf[consumed] = 1 if self.single_connection else 0 84 | consumed += 1 85 | 86 | n = util.write_ssh_string(buf[consumed:], self.x11_authentication_protocol) 87 | consumed += n 88 | 89 | n = util.write_ssh_string(buf[consumed:], self.x11_authentication_cookie) 90 | consumed += n 91 | 92 | buf[consumed:consumed+quic_util.var_int_len(self.x11_screen_number)] = quic_util.var_int_to_bytes(self.x11_screen_number) 93 | consumed += quic_util.var_int_len(self.x11_screen_number) 94 | 95 | return consumed 96 | 97 | def request_type_str(self): 98 | logger.debug("Getting request type string for X11Request") 99 | return "x11-req" 100 | 101 | def parse_x11_request(buf): 102 | logger.debug("Parsing X11Request from buffer") 103 | single_connection = util.read_boolean(buf) 104 | x11_authentication_protocol = util.parse_ssh_string(buf) 105 | x11_authentication_cookie = util.parse_ssh_string(buf) 106 | x11_screen_number = quic_util.read_var_int(buf) 107 | return X11Request(single_connection, x11_authentication_protocol, x11_authentication_cookie, x11_screen_number) 108 | 109 | class ShellRequest: 110 | def length(self): 111 | logger.debug("Calculating length of ShellRequest") 112 | return 0 113 | 114 | def request_type_str(self): 115 | logger.debug("Getting request type string for ShellRequest") 116 | return "shell" 117 | 118 | def write(self, buf): 119 | logger.debug("Writing ShellRequest to buffer") 120 | return 0 121 | 122 | def parse_shell_request(buf): 123 | logger.debug("Parsing ShellRequest from buffer") 124 | return ShellRequest() 125 | 126 | class ExecRequest: 127 | def __init__(self, command): 128 | logger.debug("Creating ExecRequest object") 129 | self.command = command 130 | 131 | def length(self): 132 | logger.debug("Calculating length of ExecRequest") 133 | return util.ssh_string_len(self.command) 134 | 135 | def request_type_str(self): 136 | logger.debug("Getting request type string for ExecRequest") 137 | return "exec" 
138 | 139 | def write(self, buf): 140 | logger.debug("Writing ExecRequest to buffer") 141 | return util.write_ssh_string(buf, self.command) 142 | 143 | def parse_exec_request(buf): 144 | logger.debug("Parsing ExecRequest from buffer") 145 | command = util.parse_ssh_string(buf) 146 | return ExecRequest(command) 147 | 148 | class SubsystemRequest: 149 | def __init__(self, subsystem_name): 150 | logger.debug("Creating SubsystemRequest object") 151 | self.subsystem_name = subsystem_name 152 | 153 | def length(self): 154 | logger.debug("Calculating length of SubsystemRequest") 155 | return util.ssh_string_len(self.subsystem_name) 156 | 157 | def request_type_str(self): 158 | logger.debug("Getting request type string for SubsystemRequest") 159 | return "subsystem" 160 | 161 | def write(self, buf): 162 | logger.debug("Writing SubsystemRequest to buffer") 163 | return util.write_ssh_string(buf, self.subsystem_name) 164 | 165 | def parse_subsystem_request(buf): 166 | logger.debug("Parsing SubsystemRequest from buffer") 167 | subsystem_name = util.parse_ssh_string(buf) 168 | return SubsystemRequest(subsystem_name) 169 | 170 | class WindowChangeRequest: 171 | def __init__(self, char_width, char_height, pixel_width, pixel_height): 172 | logger.debug("Creating WindowChangeRequest object") 173 | self.char_width = char_width 174 | self.char_height = char_height 175 | self.pixel_width = pixel_width 176 | self.pixel_height = pixel_height 177 | 178 | def length(self): 179 | logger.debug("Calculating length of WindowChangeRequest") 180 | return sum(quic_util.var_int_len(attr) for attr in [self.char_width, self.char_height, self.pixel_width, self.pixel_height]) 181 | 182 | def request_type_str(self): 183 | logger.debug("Getting request type string for WindowChangeRequest") 184 | return "window-change" 185 | 186 | def write(self, buf): 187 | logger.debug("Writing WindowChangeRequest to buffer") 188 | consumed = 0 189 | for attr in [self.char_width, self.char_height, self.pixel_width, 
self.pixel_height]: 190 | buf[consumed:consumed+quic_util.var_int_len(attr)] = quic_util.var_int_to_bytes(attr) 191 | consumed += quic_util.var_int_len(attr) 192 | return consumed 193 | 194 | def parse_window_change_request(buf): 195 | logger.debug("Parsing WindowChangeRequest from buffer") 196 | char_width = quic_util.read_var_int(buf) 197 | char_height = quic_util.read_var_int(buf) 198 | pixel_width = quic_util.read_var_int(buf) 199 | pixel_height = quic_util.read_var_int(buf) 200 | return WindowChangeRequest(char_width, char_height, pixel_width, pixel_height) 201 | 202 | class SignalRequest: 203 | def __init__(self, signal_name_without_sig): 204 | logger.debug("Creating SignalRequest object") 205 | self.signal_name_without_sig = signal_name_without_sig 206 | 207 | def length(self): 208 | logger.debug("Calculating length of SignalRequest") 209 | return util.ssh_string_len(self.signal_name_without_sig) 210 | 211 | def request_type_str(self): 212 | logger.debug("Getting request type string for SignalRequest") 213 | return "signal" 214 | 215 | def write(self, buf): 216 | logger.debug("Writing SignalRequest to buffer") 217 | return util.write_ssh_string(buf, self.signal_name_without_sig) 218 | 219 | def parse_signal_request(buf): 220 | logger.debug("Parsing SignalRequest from buffer") 221 | signal_name_without_sig = util.parse_ssh_string(buf) 222 | return SignalRequest(signal_name_without_sig) 223 | 224 | class ExitStatusRequest: 225 | def __init__(self, exit_status): 226 | logger.debug("Creating ExitStatusRequest object") 227 | self.exit_status = exit_status 228 | 229 | def length(self): 230 | logger.debug("Calculating length of ExitStatusRequest") 231 | return quic_util.var_int_len(self.exit_status) 232 | 233 | def request_type_str(self): 234 | logger.debug("Getting request type string for ExitStatusRequest") 235 | return "exit-status" 236 | 237 | def write(self, buf): 238 | logger.debug("Writing ExitStatusRequest to buffer") 239 | 
buf[:quic_util.var_int_len(self.exit_status)] = quic_util.var_int_to_bytes(self.exit_status) 240 | return quic_util.var_int_len(self.exit_status) 241 | 242 | def parse_exit_status_request(buf): 243 | logger.debug("Parsing ExitStatusRequest from buffer") 244 | exit_status = quic_util.read_var_int(buf) 245 | return ExitStatusRequest(exit_status) 246 | 247 | class ExitSignalRequest: 248 | def __init__(self, signal_name_without_sig, core_dumped, error_message_utf8, language_tag): 249 | logger.debug("Creating ExitSignalRequest object") 250 | self.signal_name_without_sig = signal_name_without_sig 251 | self.core_dumped = core_dumped 252 | self.error_message_utf8 = error_message_utf8 253 | self.language_tag = language_tag 254 | 255 | def length(self): 256 | logger.debug("Calculating length of ExitSignalRequest") 257 | return util.ssh_string_len(self.signal_name_without_sig) + 1 + \ 258 | util.ssh_string_len(self.error_message_utf8) + util.ssh_string_len(self.language_tag) 259 | 260 | def request_type_str(self): 261 | logger.debug("Getting request type string for ExitSignalRequest") 262 | return "exit-signal" 263 | 264 | def write(self, buf): 265 | logger.debug("Writing ExitSignalRequest to buffer") 266 | consumed = util.write_ssh_string(buf, self.signal_name_without_sig) 267 | buf[consumed] = 1 if self.core_dumped else 0 268 | consumed += 1 269 | consumed += util.write_ssh_string(buf[consumed:], self.error_message_utf8) 270 | consumed += util.write_ssh_string(buf[consumed:], self.language_tag) 271 | return consumed 272 | 273 | def parse_exit_signal_request(buf): 274 | logger.debug("Parsing ExitSignalRequest from buffer") 275 | signal_name_without_sig = util.parse_ssh_string(buf) 276 | core_dumped = util.read_boolean(buf) 277 | error_message_utf8 = util.parse_ssh_string(buf) 278 | language_tag = util.parse_ssh_string(buf) 279 | return ExitSignalRequest(signal_name_without_sig, core_dumped, error_message_utf8, language_tag) 280 | 281 | 282 | class ForwardingRequest: 283 | 
def __init__(self, protocol, address_family, ip_address, port): 284 | logger.debug("Creating ForwardingRequest object") 285 | self.protocol = protocol 286 | self.address_family = address_family 287 | self.ip_address = ip_address 288 | self.port = port 289 | 290 | def length(self) -> int: 291 | logger.debug("Calculating length of ForwardingRequest") 292 | return quic_util.var_int_len(self.protocol) + \ 293 | quic_util.var_int_len(self.address_family) + \ 294 | len(self.ip_address.packed) + \ 295 | 2 # Length of port 296 | 297 | def request_type_str(self) -> str: 298 | logger.debug("Getting request type string for ForwardingRequest") 299 | return "forward-port" 300 | 301 | def write(self, buf: bytearray) -> int: 302 | logger.debug("Writing ForwardingRequest to buffer") 303 | consumed = 0 304 | buf.extend(quic_util.var_int_to_bytes(self.protocol)) 305 | consumed += quic_util.var_int_len(self.protocol) 306 | 307 | buf.extend(quic_util.var_int_to_bytes(self.address_family)) 308 | consumed += quic_util.var_int_len(self.address_family) 309 | 310 | buf.extend(self.ip_address.packed) 311 | consumed += len(self.ip_address.packed) 312 | 313 | buf.extend(struct.pack('!H', self.port)) # Network byte order (big endian) 314 | consumed += 2 315 | 316 | return consumed 317 | 318 | def parse_forwarding_request(buf: io.BytesIO) -> Tuple[ForwardingRequest, Exception]: 319 | logger.debug("Parsing ForwardingRequest from buffer") 320 | protocol, err = quic_util.read_var_int(buf) 321 | if err: 322 | return None, err 323 | 324 | if protocol not in [stype.SSHForwardingProtocolTCP, stype.SSHProtocolUDP]: 325 | return None, ValueError(f"Invalid protocol number: {protocol}") 326 | 327 | address_family, err = quic_util.read_var_int(buf) 328 | if err: 329 | return None, err 330 | 331 | if address_family == util.SSHAFIpv4: 332 | address_bytes = buf.read(4) 333 | elif address_family == util.SSHAFIpv6: 334 | address_bytes = buf.read(16) 335 | else: 336 | return None, ValueError(f"Invalid address 
family: {address_family}") 337 | 338 | address = ipaddress.ip_address(address_bytes) 339 | 340 | port_buf = buf.read(2) 341 | port = struct.unpack('!H', port_buf)[0] # Network byte order (big endian) 342 | 343 | return ForwardingRequest(protocol, address_family, address, port), None 344 | 345 | 346 | channel_request_parse_funcs = { 347 | "pty-req": parse_pty_request, 348 | "x11-req": parse_x11_request, 349 | "shell": parse_shell_request, 350 | "exec": parse_exec_request, 351 | "subsystem": parse_subsystem_request, 352 | "window-change": parse_window_change_request, 353 | "signal": parse_signal_request, 354 | "exit-status": parse_exit_status_request, 355 | "exit-signal": parse_exit_signal_request, 356 | "forward-port": parse_forwarding_request 357 | } 358 | -------------------------------------------------------------------------------- /py-ssh3/message/message.py: -------------------------------------------------------------------------------- 1 | import io 2 | import struct 3 | from util.util import * 4 | from util.quic_util import * 5 | from util.wire import append_varint 6 | from message.channel_request import channel_request_parse_funcs 7 | from message.message_type import * 8 | 9 | # Enum for SSH data types 10 | class SSHDataType: 11 | SSH_EXTENDED_DATA_NONE = 0 12 | SSH_EXTENDED_DATA_STDERR = 1 13 | 14 | class Message: 15 | def write(self, buf): 16 | pass 17 | 18 | def length(self): 19 | pass 20 | 21 | def parse_request_message(buf): 22 | request_type, err = parse_ssh_string(buf) 23 | if err: 24 | return None, err 25 | 26 | want_reply = struct.unpack('>b', buf.read(1))[0] 27 | parse_func = channel_request_parse_funcs.get(request_type) 28 | if not parse_func: 29 | return None, ValueError(f"Invalid request message type {request_type}") 30 | 31 | channel_request, err = parse_func(buf) 32 | if err and not isinstance(err, io.EOFError): 33 | return None, err 34 | 35 | return ChannelRequestMessage(want_reply, channel_request), err 36 | 37 | class 
ChannelRequestMessage(Message): 38 | def __init__(self, want_reply, channel_request): 39 | self.want_reply = want_reply 40 | self.channel_request = channel_request 41 | 42 | def length(self): 43 | # msg type + request type + wantReply + request content 44 | return len(var_int_len(SSH_MSG_CHANNEL_REQUEST)) + \ 45 | ssh_string_len(self.channel_request.request_type_str()) + 1 + \ 46 | self.channel_request.length() 47 | 48 | def write(self, buf): 49 | if len(buf) < self.length(): 50 | raise ValueError(f"Buffer too small to write message for channel request of type {type(self.channel_request)}: {len(buf)} < {self.length()}") 51 | 52 | consumed = 0 53 | msg_type_buf = append_varint(None, SSH_MSG_CHANNEL_REQUEST) 54 | buf[consumed:consumed+len(msg_type_buf)] = msg_type_buf 55 | consumed += len(msg_type_buf) 56 | 57 | n = write_ssh_string(buf[consumed:], self.channel_request.request_type_str()) 58 | consumed += n 59 | 60 | buf[consumed] = 1 if self.want_reply else 0 61 | consumed += 1 62 | 63 | n = self.channel_request.write(buf[consumed:]) 64 | consumed += n 65 | 66 | return consumed 67 | 68 | 69 | class ChannelOpenConfirmationMessage(Message): 70 | def __init__(self, max_packet_size): 71 | self.max_packet_size = max_packet_size 72 | 73 | def write(self, buf): 74 | msg_type = SSH_MSG_CHANNEL_OPEN_CONFIRMATION 75 | buf.write(struct.pack('>Q', msg_type)) 76 | buf.write(struct.pack('>Q', self.max_packet_size)) 77 | 78 | def length(self): 79 | return 2 * struct.calcsize('>Q') 80 | 81 | class ChannelOpenFailureMessage(Message): 82 | def __init__(self, reason_code, error_message_utf8, language_tag): 83 | self.reason_code = reason_code 84 | self.error_message_utf8 = error_message_utf8 85 | self.language_tag = language_tag 86 | 87 | def write(self, buf): 88 | msg_type = SSH_MSG_CHANNEL_OPEN_FAILURE 89 | buf.write(struct.pack('>Q', msg_type)) 90 | buf.write(struct.pack('>Q', self.reason_code)) 91 | self._write_ssh_string(buf, self.error_message_utf8) 92 | 
self._write_ssh_string(buf, self.language_tag) 93 | 94 | def length(self): 95 | msg_type_len = struct.calcsize('>Q') 96 | reason_code_len = struct.calcsize('>Q') 97 | error_message_len = self._ssh_string_length(self.error_message_utf8) 98 | language_tag_len = self._ssh_string_length(self.language_tag) 99 | return msg_type_len + reason_code_len + error_message_len + language_tag_len 100 | 101 | def _write_ssh_string(self, buf, value): 102 | encoded_value = value.encode('utf-8') 103 | buf.write(struct.pack('>I', len(encoded_value))) 104 | buf.write(encoded_value) 105 | 106 | def _ssh_string_length(self, value): 107 | return struct.calcsize('>I') + len(value.encode('utf-8')) 108 | 109 | class DataOrExtendedDataMessage(Message): 110 | def __init__(self, data_type, data): 111 | self.data_type = data_type 112 | self.data = data 113 | 114 | def write(self, buf): 115 | if self.data_type == SSHDataType.SSH_EXTENDED_DATA_NONE: 116 | msg_type = SSH_MSG_CHANNEL_DATA 117 | else: 118 | msg_type = SSH_MSG_CHANNEL_EXTENDED_DATA 119 | buf.write(struct.pack('>Q', self.data_type)) 120 | buf.write(struct.pack('>Q', msg_type)) 121 | self._write_ssh_string(buf, self.data) 122 | 123 | def length(self): 124 | msg_type_len = struct.calcsize('>Q') 125 | if self.data_type == SSHDataType.SSH_EXTENDED_DATA_NONE: 126 | return msg_type_len + self._ssh_string_length(self.data) 127 | data_type_len = struct.calcsize('>Q') 128 | return msg_type_len + data_type_len + self._ssh_string_length(self.data) 129 | 130 | def _write_ssh_string(self, buf, value): 131 | encoded_value = value.encode('utf-8') 132 | buf.write(struct.pack('>I', len(encoded_value))) 133 | buf.write(encoded_value) 134 | 135 | def _ssh_string_length(self, value): 136 | return struct.calcsize('>I') + len(value.encode('utf-8')) 137 | 138 | def parse_channel_open_confirmation_message(buf): 139 | max_packet_size = struct.unpack('>Q', buf.read(8))[0] 140 | return ChannelOpenConfirmationMessage(max_packet_size) 141 | 142 | def 
parse_channel_open_failure_message(buf): 143 | reason_code = struct.unpack('>Q', buf.read(8))[0] 144 | error_message_utf8 = parse_ssh_string(buf) 145 | language_tag = parse_ssh_string(buf) 146 | return ChannelOpenFailureMessage(reason_code, error_message_utf8, language_tag) 147 | 148 | def parse_data_message(buf): 149 | data = parse_ssh_string(buf) 150 | return DataOrExtendedDataMessage(SSHDataType.SSH_EXTENDED_DATA_NONE, data) 151 | 152 | def parse_extended_data_message(buf): 153 | data_type = struct.unpack('>Q', buf.read(8))[0] 154 | data = parse_ssh_string(buf) 155 | return DataOrExtendedDataMessage(data_type, data) 156 | 157 | def parse_message(r): 158 | type_id = struct.unpack('>Q', r.read(8))[0] 159 | if type_id == SSH_MSG_CHANNEL_REQUEST: 160 | pass # Implement ParseRequestMessage function here 161 | elif type_id == SSH_MSG_CHANNEL_OPEN_CONFIRMATION: 162 | return parse_channel_open_confirmation_message(r) 163 | elif type_id == SSH_MSG_CHANNEL_OPEN_FAILURE: 164 | return parse_channel_open_failure_message(r) 165 | elif type_id in (SSH_MSG_CHANNEL_DATA, SSH_MSG_CHANNEL_EXTENDED_DATA): 166 | if type_id == SSH_MSG_CHANNEL_DATA: 167 | return parse_data_message(r) 168 | else: 169 | return parse_extended_data_message(r) 170 | else: 171 | raise ValueError("Not implemented") 172 | 173 | # Example usage: 174 | if __name__ == "__main__": 175 | # You can test the code here 176 | pass 177 | -------------------------------------------------------------------------------- /py-ssh3/message/message_type.py: -------------------------------------------------------------------------------- 1 | # Constants for SSH message types 2 | SSH_MSG_DISCONNECT = 1 3 | SSH_MSG_IGNORE = 2 4 | SSH_MSG_UNIMPLEMENTED = 3 5 | SSH_MSG_DEBUG = 4 6 | SSH_MSG_SERVICE_REQUEST = 5 7 | SSH_MSG_SERVICE_ACCEPT = 6 8 | SSH_MSG_KEXINIT = 20 9 | SSH_MSG_NEWKEYS = 21 10 | SSH_MSG_USERAUTH_REQUEST = 50 11 | SSH_MSG_USERAUTH_FAILURE = 51 12 | SSH_MSG_USERAUTH_SUCCESS = 52 13 | SSH_MSG_USERAUTH_BANNER = 53 14 | 
SSH_MSG_GLOBAL_REQUEST = 80 15 | SSH_MSG_REQUEST_SUCCESS = 81 16 | SSH_MSG_REQUEST_FAILURE = 82 17 | SSH_MSG_CHANNEL_OPEN = 90 18 | SSH_MSG_CHANNEL_OPEN_CONFIRMATION = 91 19 | SSH_MSG_CHANNEL_OPEN_FAILURE = 92 20 | SSH_MSG_CHANNEL_WINDOW_ADJUST = 93 21 | SSH_MSG_CHANNEL_DATA = 94 22 | SSH_MSG_CHANNEL_EXTENDED_DATA = 95 23 | SSH_MSG_CHANNEL_EOF = 96 24 | SSH_MSG_CHANNEL_CLOSE = 97 25 | SSH_MSG_CHANNEL_REQUEST = 98 26 | SSH_MSG_CHANNEL_SUCCESS = 99 27 | SSH_MSG_CHANNEL_FAILURE = 100 28 | -------------------------------------------------------------------------------- /py-ssh3/server_cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | import signal 3 | # import asyncio 4 | import signal 5 | import fcntl 6 | import struct 7 | import termios 8 | import logging 9 | # import asyncio 10 | import os 11 | import logging 12 | from util.linux_util import * 13 | import message.message as ssh3_message 14 | import message.channel_request as ssh3_channel 15 | import argparse 16 | import sys 17 | import util.waitgroup as sync 18 | from http3.http3_server import * 19 | from aioquic.quic.configuration import QuicConfiguration 20 | from ssh3.ssh3_server import SSH3Server 21 | from ssh3.conversation import Conversation 22 | from ssh3.channel import * 23 | from util.linux_util.linux_user import User, get_user 24 | from linux_server.auth import * 25 | import util.globals as glob 26 | from http3.http3_server import * 27 | 28 | from starlette.applications import Router 29 | from starlette.applications import Starlette 30 | from starlette.types import Receive, Scope, Send 31 | from starlette.routing import Route 32 | 33 | log = logging.getLogger(__name__) 34 | 35 | # Define signal mappings 36 | signals = { 37 | "SIGABRT": signal.SIGABRT, 38 | "SIGALRM": signal.SIGALRM, 39 | "SIGBUS": signal.SIGBUS, 40 | "SIGCHLD": signal.SIGCHLD, 41 | "SIGCONT": signal.SIGCONT, 42 | "SIGFPE": signal.SIGFPE, 43 | "SIGHUP": signal.SIGHUP, 44 | "SIGILL": 
signal.SIGILL, 45 | "SIGINT": signal.SIGINT, 46 | "SIGIO": signal.SIGIO, 47 | "SIGIOT": signal.SIGIOT, 48 | "SIGKILL": signal.SIGKILL, 49 | "SIGPIPE": signal.SIGPIPE, 50 | "SIGPOLL": signal.SIGPOLL, 51 | "SIGPROF": signal.SIGPROF, 52 | "SIGPWR": signal.SIGPWR, 53 | "SIGQUIT": signal.SIGQUIT, 54 | "SIGRTMAX": signal.SIGRTMAX, 55 | "SIGRTMIN": signal.SIGRTMIN, 56 | "SIGSEGV": signal.SIGSEGV, 57 | "SIGSTOP": signal.SIGSTOP, 58 | "SIGSYS": signal.SIGSYS, 59 | "SIGTERM": signal.SIGTERM, 60 | "SIGTRAP": signal.SIGTRAP, 61 | "SIGTSTP": signal.SIGTSTP, 62 | "SIGTTIN": signal.SIGTTIN, 63 | "SIGTTOU": signal.SIGTTOU, 64 | "SIGURG": signal.SIGURG, 65 | "SIGUSR1": signal.SIGUSR1, 66 | "SIGUSR2": signal.SIGUSR2, 67 | "SIGVTALRM": signal.SIGVTALRM, 68 | "SIGWINCH": signal.SIGWINCH, 69 | "SIGXCPU": signal.SIGXCPU, 70 | "SIGXFSZ": signal.SIGXFSZ 71 | } 72 | 73 | 74 | class ChannelType: 75 | LARVAL = 0 76 | OPEN = 1 77 | 78 | channel_type = ChannelType() 79 | 80 | class OpenPty: 81 | def __init__(self, pty, tty, win_size, term): 82 | self.pty = pty 83 | self.tty = tty 84 | self.win_size = win_size 85 | self.term = term 86 | 87 | class RunningCommand: 88 | def __init__(self, stdout_r, stderr_r, stdin_w): 89 | self.stdout_r = stdout_r 90 | self.stderr_r = stderr_r 91 | self.stdin_w = stdin_w 92 | 93 | class RunningSession: 94 | def __init__(self): 95 | self.channel_state = None 96 | self.pty = None 97 | self.running_cmd = None 98 | self.auth_agent_socket_path = None 99 | 100 | running_sessions = {} 101 | 102 | def set_winsize(f, char_width, char_height, pix_width, pix_height): 103 | winsize = struct.pack("HHHH", char_height, char_width, pix_width, pix_height) 104 | fcntl.ioctl(f, termios.TIOCSWINSZ, winsize) 105 | 106 | def setup_env(user, running_command, auth_agent_socket_path): 107 | # Set up the environment variables for the subprocess 108 | running_command.cmd.env.append(f"HOME={user.dir}") 109 | running_command.cmd.env.append(f"USER={user.username}") 110 | 
running_command.cmd.env.append("PATH=/usr/bin:/bin:/usr/sbin:/sbin") 111 | if auth_agent_socket_path != "": 112 | running_command.cmd.env.append(f"SSH_AUTH_SOCK={auth_agent_socket_path}") 113 | 114 | async def forward_udp_in_background(ctx, channel, conn): 115 | async def receive_datagram(): 116 | while True: 117 | try: 118 | datagram, err = await channel.receive_datagram(ctx) 119 | if err is not None: 120 | log.error(f"could not receive datagram: {err}") 121 | return 122 | await conn.write(datagram) 123 | except asyncio.CancelledError: 124 | return 125 | 126 | async def send_datagram(): 127 | buf = bytearray(1500) 128 | while True: 129 | try: 130 | n, err = await conn.read(buf) 131 | if err is not None: 132 | log.error(f"could read datagram on UDP socket: {err}") 133 | return 134 | await channel.send_datagram(buf[:n]) 135 | except asyncio.CancelledError: 136 | return 137 | 138 | receive_task = asyncio.create_task(receive_datagram()) 139 | send_task = asyncio.create_task(send_datagram()) 140 | 141 | try: 142 | await asyncio.gather(receive_task, send_task) 143 | finally: 144 | receive_task.cancel() 145 | send_task.cancel() 146 | await asyncio.gather(receive_task, send_task, return_exceptions=True) 147 | 148 | async def forward_tcp_in_background(ctx, channel, conn): 149 | async def read_from_tcp_socket(): 150 | try: 151 | while True: 152 | data = await conn.recv(4096) 153 | if not data: 154 | break 155 | await channel.send_data(data) 156 | except asyncio.CancelledError: 157 | pass 158 | except Exception as e: 159 | log.error(f"could read data on TCP socket: {e}") 160 | 161 | async def read_from_ssh_channel(): 162 | buf = bytearray(channel.max_packet_size()) 163 | try: 164 | while True: 165 | data = await channel.receive_data() 166 | if not data: 167 | break 168 | await conn.sendall(data) 169 | except asyncio.CancelledError: 170 | pass 171 | except Exception as e: 172 | log.error(f"could send data on channel: {e}") 173 | 174 | read_tcp_task = 
asyncio.create_task(read_from_tcp_socket()) 175 | read_ssh_task = asyncio.create_task(read_from_ssh_channel()) 176 | 177 | try: 178 | await asyncio.gather(read_tcp_task, read_ssh_task) 179 | finally: 180 | read_tcp_task.cancel() 181 | read_ssh_task.cancel() 182 | await asyncio.gather(read_tcp_task, read_ssh_task, return_exceptions=True) 183 | 184 | async def exec_cmd_in_background(channel, open_pty, user, running_command, auth_agent_socket_path): 185 | # Execute command in background and handle its output 186 | pass 187 | 188 | def new_pty_req(user, channel, request, want_reply): 189 | # Handle PTY request 190 | pass 191 | 192 | def new_x11_req(user, channel, request, want_reply): 193 | # Handle X11 request (if applicable) 194 | pass 195 | 196 | def new_command(user, channel, login_shell, command, args): 197 | # Execute a new command 198 | pass 199 | 200 | def new_shell_req(user, channel, want_reply): 201 | # Handle shell request 202 | pass 203 | 204 | def new_command_in_shell_req(user, channel, want_reply, command): 205 | # Execute command within shell 206 | pass 207 | 208 | def new_subsystem_req(user, channel, request, want_reply): 209 | # Handle subsystem request 210 | pass 211 | 212 | def new_window_change_req(user, channel, request, want_reply): 213 | # Handle window change request 214 | pass 215 | 216 | def new_signal_req(user, channel, request, want_reply): 217 | # Handle signal request 218 | pass 219 | 220 | def new_exit_status_req(user, channel, request, want_reply): 221 | # Handle exit status request 222 | pass 223 | 224 | def new_exit_signal_req(user, channel, request, want_reply): 225 | # Handle exit signal request 226 | pass 227 | 228 | async def handle_udp_forwarding_channel(user: User, conv: Conversation, channel: UDPForwardingChannelImpl): 229 | """ 230 | Handle UDP forwarding for a specific channel in an SSH3 conversation. 231 | 232 | Args: 233 | user: The user object containing user information. 
# TODO seems not used 234 | conv: The SSH3 conversation object. # TODO seems not used 235 | channel: The UDP forwarding channel implementation. 236 | 237 | Returns: 238 | None if successful, or an error if any occurs. 239 | """ 240 | # Note: Rights for socket creation are not checked 241 | # The socket is opened with the process's uid and gid 242 | try: 243 | # Create a UDP socket 244 | conn = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 245 | # Connect the socket to the remote address 246 | conn.connect(channel.remote_addr) 247 | 248 | # Start forwarding UDP in background 249 | await forward_udp_in_background(channel, conn) 250 | except Exception as e: 251 | return e 252 | return None 253 | 254 | async def handle_tcp_forwarding_channel(user: User, conv: Conversation, channel:TCPForwardingChannelImpl): 255 | """ 256 | Handle TCP forwarding for a specific channel in an SSH3 conversation. 257 | 258 | Args: 259 | user: The user object containing user information. 260 | conv: The SSH3 conversation object. 261 | channel: The TCP forwarding channel implementation. 262 | 263 | Returns: 264 | None if successful, or an error if any occurs. 
265 | """ 266 | # Note: Rights for socket creation are not checked 267 | # The socket is opened with the process's uid and gid 268 | try: 269 | # Create a TCP socket 270 | conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 271 | # Connect the socket to the remote address 272 | conn.connect(channel.remote_addr) 273 | 274 | # Start forwarding TCP in background 275 | await forward_tcp_in_background(channel, conn) 276 | except Exception as e: 277 | return e 278 | return None 279 | 280 | def new_data_req(user, channel, request): 281 | # Handle data request 282 | running_session, ok = running_sessions[channel] 283 | if not ok: 284 | return Exception("could not find running session for channel") 285 | if running_session.channel_state == channel_type.LARVAL: 286 | return Exception("invalid data on ssh channel with LARVAL state") 287 | if channel.channel_type =="session": 288 | if running_session.running_cmd != None: 289 | if request.data_type == ssh3_message.SSHDataType.SSH_EXTENDED_DATA_NONE: 290 | running_session.running_cmd.stdin_w.write(request.data) 291 | else: 292 | return Exception("invalid data type on ssh channel with session channel type pty") 293 | else: 294 | return Exception("could not find running command for channel") 295 | return None 296 | 297 | async def handle_auth_agent_socket_conn(conn, conversation): 298 | # Handle authentication agent socket connection 299 | pass 300 | 301 | async def listen_and_accept_auth_sockets(conversation, listener): 302 | # Listen and accept authentication agent sockets 303 | pass 304 | 305 | async def open_agent_socket_and_forward_agent(conv, user): 306 | # Open an agent socket and forward agent 307 | pass 308 | 309 | def file_exists(path): 310 | # Check if a file exists 311 | return os.path.exists(path) 312 | 313 | async def main(): 314 | parser = argparse.ArgumentParser() 315 | parser.add_argument("--bind", default="[::]:443", help="the address:port pair to listen to, e.g. 
0.0.0.0:443") 316 | parser.add_argument("-v", "--verbose", action="store_true", help="verbose mode, if set") 317 | parser.add_argument("--enablePasswordLogin", action="store_true", help="if set, enable password authentication (disabled by default)") 318 | parser.add_argument("--urlPath", default="/ssh3-term", help="the secret URL path on which the ssh3 server listens") 319 | parser.add_argument("--generateSelfSignedCert", action="store_true", help="if set, generates a self-self-signed cerificate and key that will be stored at the paths indicated by the -cert and -key args (they must not already exist)") 320 | parser.add_argument("--certPath", default="./cert.pem", help="the filename of the server certificate (or fullchain)") 321 | parser.add_argument("--keyPath", default="./priv.pem", help="the filename of the certificate private key") 322 | parser.add_argument("-l","--secrets-log", type=str, help="log secrets to a file, for use with Wireshark") 323 | args = parser.parse_args() 324 | 325 | router = Starlette( 326 | debug=True, 327 | routes=[ 328 | Route(path=args.urlPath, 329 | endpoint=handle_auths, 330 | methods=["CONNECT"]) 331 | ] 332 | ) 333 | 334 | async def app(scope: Scope, receive: Receive, send: Send) -> None: 335 | log.info(f"Received scope: {scope}") 336 | log.info(f"Received receive: {receive}") 337 | log.info(f"Received send: {send}") 338 | log.info(f"Nb Router: {len(router.router.routes)}") 339 | for route in router.router.routes: 340 | log.info(f"Route: {route.path}") 341 | await router(scope, receive, send) 342 | 343 | glob.APPLICATION = app 344 | 345 | if args.verbose: 346 | logging.basicConfig(level=logging.DEBUG) 347 | util.configure_logger("debug") 348 | else: 349 | log_level = os.getenv("SSH3_LOG_LEVEL") 350 | if log_level: 351 | util.configure_logger(log_level) 352 | numeric_level = getattr(log, log_level.upper(), None) 353 | if not isinstance(numeric_level, int): 354 | raise ValueError(f"Invalid log level: {log_level}") 355 | 
logging.basicConfig(level=numeric_level) 356 | 357 | logFileName = os.getenv("SSH3_LOG_FILE") 358 | if not logFileName or logFileName == "": 359 | logFileName = "ssh3_server.log" 360 | logging.basicConfig(filename=logFileName, level=logging.INFO) 361 | 362 | if not args.enablePasswordLogin: 363 | log.error("password login is currently disabled") 364 | 365 | certPathExists = file_exists(args.certPath) 366 | keyPathExists = file_exists(args.keyPath) 367 | 368 | if not args.generateSelfSignedCert: 369 | if not certPathExists: 370 | log.error(f"the \"{args.certPath}\" certificate file does not exist") 371 | if not keyPathExists: 372 | log.error(f"the \"{args.keyPath}\" certificate private key file does not exist") 373 | if not certPathExists or not keyPathExists: 374 | log.error("If you have no certificate and want a security comparable to traditional SSH host keys, you can generate a self-signed certificate using the -generate-selfsigned-cert arg or using the following script:") 375 | log.error("https://github.com/ElNiak/py-ssh3/blob/main/generate_openssl_selfsigned_certificate.sh") 376 | sys.exit(-1) 377 | log.info(f"Using certificate \"{args.certPath}\" and private key \"{args.keyPath}\"") 378 | else: 379 | if certPathExists: 380 | log.error(f"asked for generating a certificate but the \"{args.certPath}\" file already exists") 381 | if keyPathExists: 382 | log.error(f"asked for generating a private key but the \"{args.keyPath}\" file already exists") 383 | if certPathExists or keyPathExists: 384 | #sys.exit(-1) 385 | pass 386 | pubkey, privkey, err = util.generate_key() 387 | if err != None: 388 | log.error(f"could not generate private key: {err}") 389 | sys.exit(-1) 390 | cert, err = util.generate_cert(privkey, pubkey) 391 | if err != None: 392 | log.error(f"could not generate certificate: {err}") 393 | sys.exit(-1) 394 | 395 | err = util.dump_cert_and_key_to_files(cert, privkey, args.certPath, args.keyPath) 396 | if err != None: 397 | log.error(f"could not save 
certificate and key to files: {err}") 398 | sys.exit(-1) 399 | 400 | # TODO aioquic does not support this yet (disable or not 0rtt) 401 | 402 | # quicConf = defaults. 403 | 404 | # open SSL log file 405 | if args.secrets_log: 406 | secrets_log_file = open(args.secrets_log, "a") 407 | else: 408 | secrets_log_file = None 409 | 410 | defaults = QuicConfiguration(is_client=False) 411 | configuration = QuicConfiguration( 412 | alpn_protocols=H3_ALPN, 413 | is_client=False, 414 | max_datagram_frame_size=65536, 415 | max_datagram_size=defaults.max_datagram_size, 416 | quic_logger = QuicFileLogger("qlogs/"), 417 | secrets_log_file=secrets_log_file 418 | ) 419 | 420 | # load SSL certificate and key 421 | configuration.load_cert_chain(args.certPath, args.keyPath) 422 | 423 | glob.CONFIGURATION = configuration 424 | 425 | wg = sync.WaitGroup() 426 | wg.add(1) 427 | 428 | log.info(f"Starting server on {args.bind}") 429 | 430 | 431 | async def handle_conv(authenticatedUsername: str, conv: Conversation): 432 | log.info(f"Handling authentification for {authenticatedUsername}") 433 | authUser = get_user(authenticatedUsername) 434 | 435 | log.info(f"Handling conversation {conv}") 436 | channel, err = await conv.accept_channel() 437 | if err != None: 438 | log.error(f"could not accept channel: {err}") 439 | return 440 | 441 | if channel.isinstance(UDPForwardingChannelImpl): 442 | handle_udp_forwarding_channel(authUser, conv, channel) 443 | elif channel.isinstance(TCPForwardingChannelImpl): 444 | handle_tcp_forwarding_channel(authUser, conv, channel) 445 | 446 | # Default 447 | running_sessions[channel] = RunningSession( 448 | channel_state=channel_type.LARVAL, 449 | pty=None, 450 | running_cmd=None 451 | ) 452 | 453 | def handle_session_channel(): 454 | generic_message, err = channel.next_message() 455 | if err != None: 456 | log.error(f"could not get next message: {err}") 457 | return 458 | if generic_message is None: 459 | return 460 | if 
generic_message.isinstance(ssh3_message.ChannelRequestMessage): 461 | if generic_message.channel_request.isinstance(ssh3_channel.PtyRequest): 462 | err = new_pty_req(authUser, channel, generic_message.channel_request, generic_message.want_reply) 463 | elif generic_message.channel_request.isinstance(ssh3_channel.X11Request): 464 | err = new_x11_req(authUser, channel, generic_message.channel_request, generic_message.want_reply) 465 | elif generic_message.channel_request.isinstance(ssh3_channel.ExecRequest): 466 | err = new_command(authUser, channel, False, generic_message.channel_request.command, generic_message.channel_request.args) 467 | elif generic_message.channel_request.isinstance(ssh3_channel.ShellRequest): 468 | err = new_shell_req(authUser, channel, generic_message.want_reply) 469 | elif generic_message.channel_request.isinstance(ssh3_channel.CommandInShellRequest): 470 | err = new_command_in_shell_req(authUser, channel, generic_message.want_reply, generic_message.channel_request.command) 471 | elif generic_message.channel_request.isinstance(ssh3_channel.SubsystemRequest): 472 | err = err = new_subsystem_req(authUser, channel, generic_message.channel_request, generic_message.want_reply) 473 | elif generic_message.channel_request.isinstance(ssh3_channel.WindowChangeRequest): 474 | err = new_window_change_req(authUser, channel, generic_message.channel_request, generic_message.want_reply) 475 | elif generic_message.channel_request.isinstance(ssh3_channel.SignalRequest): 476 | err = new_signal_req(authUser, channel, generic_message.channel_request, generic_message.want_reply) 477 | elif generic_message.channel_request.isinstance(ssh3_channel.ExitStatusRequest): 478 | err = new_exit_status_req(authUser, channel, generic_message.channel_request, generic_message.want_reply) 479 | elif generic_message.isinstance(ssh3_message.DataOrExtendedDataMessage): 480 | running_session, ok = running_sessions[channel] 481 | if not ok: 482 | log.error("could not find running 
session for channel") 483 | return 484 | if running_session.channel_state == channel_type.LARVAL: 485 | if generic_message.data == "forward)agent": 486 | running_session.auth_agent_socket_path, err = open_agent_socket_and_forward_agent(conv, authUser) 487 | else: 488 | err = Exception("invalid data on ssh channel with LARVAL state") 489 | else: 490 | err = new_data_req(authUser, channel, generic_message) 491 | if err != None: 492 | log.error(f"error while processing message: {generic_message}: {err}",) 493 | return 494 | 495 | handle_session_channel() 496 | 497 | log.info(f"Nb Router: {len(router.router.routes)}") 498 | for route in router.router.routes: 499 | log.info(f"Route: {route.path}") 500 | 501 | # authenticated_username, new_conv, request_handler 502 | # asyncio.create_task(ssh3Handler(qconn, new_conv, conversations_manager)) 503 | session_ticket_store = SessionTicketStore() 504 | glob.SESSION_TICKET_HANDLER = session_ticket_store.add 505 | quic_server = await serve( 506 | args.bind.split(":")[0], 507 | args.bind.split(":")[1], 508 | configuration=configuration, 509 | create_protocol=AuthHttpServerProtocol, 510 | session_ticket_fetcher=session_ticket_store.pop, 511 | session_ticket_handler=glob.SESSION_TICKET_HANDLER, 512 | ) 513 | 514 | ssh3Server = SSH3Server( 515 | 30000,quic_server, 516 | 10, 517 | conversation_handler=handle_conv) 518 | ssh3Handler = ssh3Server.get_http_handler_func() 519 | 520 | glob.ENABLE_PASSWORD_LOGIN = args.enablePasswordLogin 521 | glob.HANDLER_FUNC = ssh3Handler 522 | glob.QUIC_SERVER = quic_server 523 | 524 | log.info(f"Listening on {args.bind} with URL path {args.urlPath}") 525 | 526 | await asyncio.Future() 527 | 528 | # if err != None: 529 | # log.error(f"could not serve: {err}") 530 | # sys.exit(-1) 531 | 532 | wg.done() 533 | 534 | wg.wait() 535 | 536 | 537 | if __name__ == "__main__": 538 | asyncio.run(main()) 539 | -------------------------------------------------------------------------------- 
/py-ssh3/ssh3/channel.py: -------------------------------------------------------------------------------- 1 | 2 | import struct 3 | from typing import Tuple, Optional, Callable 4 | import ipaddress 5 | from util import * 6 | from abc import ABC, abstractmethod 7 | import util.type as stype 8 | import util.util as util 9 | import util.quic_util as quic_util 10 | import util.wire as wire 11 | from ssh3.conversation import * 12 | from message.message import * 13 | from message.channel_request import * 14 | from aioquic.quic.stream import QuicStreamReceiver 15 | import socket 16 | import logging 17 | import logging 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | SSH_FRAME_TYPE = 0xaf3627e6 22 | 23 | class ChannelOpenFailure(Exception): 24 | def __init__(self, reason_code, error_msg): 25 | self.reason_code = reason_code # uint64 26 | self.error_msg = error_msg # string 27 | super().__init__(f"Channel open failure: reason: {reason_code}: {error_msg}") 28 | 29 | class MessageOnNonConfirmedChannel(Exception): 30 | def __init__(self, message: Message): 31 | self.message = message 32 | super().__init__(f"A message of type {type(self.message)} has been received on a non-confirmed channel") 33 | 34 | class MessageOnNonConfirmedChannel(Exception): 35 | def __init__(self, channel_id: int): 36 | self.channel_id = channel_id 37 | super().__init__(f"A datagram has been received on non-datagram channel {self.channel_id}") 38 | 39 | class SentDatagramOnNonDatagramChannel(Exception): 40 | def __init__(self, channel_id: int): 41 | self.channel_id = channel_id 42 | super().__init__(f"A datagram has been sent on non-datagram channel {self.channel_id}") 43 | 44 | 45 | class ChannelInfo: 46 | def __init__(self, 47 | max_packet_size, 48 | conv_stream_id, 49 | conv_id, 50 | channel_id, 51 | channel_type): 52 | self.max_packet_size = max_packet_size 53 | self.conv_stream_id = conv_stream_id 54 | self.conv_id = conv_id 55 | self.channel_id = channel_id 56 | self.channel_type = 
channel_type 57 | 58 | 59 | class Channel(ABC): 60 | @abstractmethod 61 | def channel_id(self) -> stype.ChannelID: 62 | pass 63 | 64 | @abstractmethod 65 | def conversation_id(self) -> ConversationID: 66 | pass 67 | 68 | @abstractmethod 69 | def conversation_stream_id(self) -> int: 70 | logger.debug("conversation_stream_id() called") 71 | pass 72 | 73 | @abstractmethod 74 | def next_message(self) -> Message: 75 | logger.debug("next_message() called") 76 | pass 77 | 78 | @abstractmethod 79 | def receive_datagram(self) -> bytes: 80 | logger.debug("receive_datagram() called") 81 | pass 82 | 83 | @abstractmethod 84 | def send_datagram(self, datagram: bytes) -> None: 85 | logger.debug(f"send_datagram() called with datagram: {datagram}") 86 | pass 87 | 88 | @abstractmethod 89 | def send_request(self, r: ChannelRequestMessage) -> None: 90 | logger.debug(f"send_request() called with request: {r}") 91 | pass 92 | 93 | @abstractmethod 94 | def cancel_read(self) -> None: 95 | logger.debug("cancel_read() called") 96 | pass 97 | 98 | @abstractmethod 99 | def close(self) -> None: 100 | logger.debug("close() called") 101 | pass 102 | 103 | @abstractmethod 104 | def max_packet_size(self) -> int: 105 | logger.debug("max_packet_size() called") 106 | pass 107 | 108 | @abstractmethod 109 | def write_data(self, data_buf: bytes, data_type: SSHDataType) -> int: 110 | logger.debug(f"write_data() called with data_buf: {data_buf}, data_type: {data_type}") 111 | pass 112 | 113 | @abstractmethod 114 | def channel_type(self) -> str: 115 | logger.debug("channel_type() called") 116 | pass 117 | 118 | @abstractmethod 119 | def confirm_channel(self, max_packet_size: int) -> None: 120 | logger.debug(f"confirm_channel() called with max_packet_size: {max_packet_size}") 121 | pass 122 | 123 | @abstractmethod 124 | def set_datagram_sender(self, sender: Callable[[bytes], None]) -> None: 125 | logger.debug(f"set_datagram_sender() called with sender: {sender}") 126 | pass 127 | 128 | @abstractmethod 129 | 
def wait_add_datagram(self, datagram: bytes) -> None: 130 | logger.debug(f"wait_add_datagram() called with datagram: {datagram}") 131 | pass 132 | 133 | @abstractmethod 134 | def add_datagram(self, datagram: bytes) -> bool: 135 | logger.debug(f"add_datagram() called with datagram: {datagram}") 136 | pass 137 | 138 | @abstractmethod 139 | def maybe_send_header(self) -> None: 140 | logger.debug("maybe_send_header() called") 141 | pass 142 | 143 | @abstractmethod 144 | def set_datagram_queue(self, queue: util.DatagramsQueue) -> None: 145 | logger.debug(f"set_datagram_queue() called with queue: {queue}") 146 | pass 147 | 148 | 149 | class ChannelImpl(Channel): 150 | def __init__(self, channel_info: ChannelInfo, recv: QuicStreamReceiver, send): 151 | logger.debug(f"Creating ChannelImpl object with channel_info: {channel_info}, recv: {recv}, send: {send}") 152 | self.channel_info = channel_info 153 | self.confirm_sent = False 154 | self.confirm_received = False 155 | self.header = [] 156 | self.datagram_sender = None 157 | self.channel_close_listener = None 158 | self.recv = recv 159 | self.send = send 160 | self.datagrams_queue = None 161 | 162 | # Handlers and data handling attributes 163 | self.pty_req_handler = None 164 | self.x11_req_handler = None 165 | self.shell_req_handler = None 166 | self.exec_req_handler = None 167 | self.subsystem_req_handler = None 168 | self.window_change_req_handler = None 169 | self.signal_req_handler = None 170 | self.exit_status_req_handler = None 171 | self.exit_signal_req_handler = None 172 | self.channel_data_handler = None 173 | 174 | def __init__(self, conversation_stream_id: int, conversation_id: ConversationID, channel_id: int, 175 | channel_type: str, max_packet_size: int, recv: QuicStreamReceiver, send, 176 | datagram_sender: Callable, channel_close_listener: Callable, send_header: bool, 177 | confirm_sent: bool, confirm_received: bool, datagrams_queue_size: int, additional_header_bytes: bytes): 178 | logger.debug(f"Creating 
ChannelImpl object with conversation_stream_id: {conversation_stream_id}, conversation_id: {conversation_id}, channel_id: {channel_id}, channel_type: {channel_type}, max_packet_size: {max_packet_size}, recv: {recv}, send: {send}, datagram_sender: {datagram_sender}, channel_close_listener: {channel_close_listener}, send_header: {send_header}, confirm_sent: {confirm_sent}, confirm_received: {confirm_received}, datagrams_queue_size: {datagrams_queue_size}, additional_header_bytes: {additional_header_bytes}") 179 | self.channel_info = ChannelInfo(max_packet_size, conversation_stream_id, conversation_id, channel_id, channel_type) 180 | self.recv = recv 181 | self.send = send 182 | self.datagrams_queue = stype.DatagramsQueue(datagrams_queue_size) 183 | self.datagram_sender = datagram_sender 184 | self.channel_close_listener = channel_close_listener 185 | self.header = build_header(conversation_stream_id, channel_type, max_packet_size, additional_header_bytes) if send_header else None 186 | self.confirm_sent = confirm_sent 187 | self.confirm_received = confirm_received 188 | 189 | # Handlers and data handling attributes 190 | self.pty_req_handler = None 191 | self.x11_req_handler = None 192 | self.shell_req_handler = None 193 | self.exec_req_handler = None 194 | self.subsystem_req_handler = None 195 | self.window_change_req_handler = None 196 | self.signal_req_handler = None 197 | self.exit_status_req_handler = None 198 | self.exit_signal_req_handler = None 199 | self.channel_data_handler = None 200 | 201 | def channel_id(self): 202 | logger.debug("channel_id() called") 203 | return self.channel_info.channel_id 204 | 205 | def conversation_stream_id(self): 206 | logger.debug("conversation_stream_id() called") 207 | return self.channel_info.conversation_stream_id 208 | 209 | def conversation_id(self): 210 | logger.debug("conversation_id() called") 211 | return self.channel_info.conversation_id 212 | 213 | def next_message(self): 214 | logger.debug("next_message() called") 
215 | # The error is EOF only if no bytes were read. If an EOF happens 216 | # after reading some but not all the bytes, next_message returns 217 | # ErrUnexpectedEOF. 218 | return parse_message(self.recv) # Assuming parse_message is defined 219 | 220 | def next_message(self): 221 | logger.debug("next_message() called") 222 | generic_message, err = self.next_message() 223 | if err: 224 | return None, err 225 | 226 | if isinstance(generic_message, ChannelOpenConfirmationMessage): 227 | self.confirm_received = True 228 | return self.next_message() 229 | elif isinstance(generic_message, ChannelOpenFailureMessage): 230 | return None, ChannelOpenFailure(generic_message.reason_code, generic_message.error_message_utf8) 231 | 232 | if not self.confirm_sent: 233 | return None, MessageOnNonConfirmedChannel(generic_message) 234 | return generic_message, None 235 | 236 | def maybe_send_header(self): 237 | logger.debug("maybe_send_header() called") 238 | if self.header: 239 | written, err = self.send.write(self.header) 240 | if err: 241 | return err 242 | self.header = self.header[written:] 243 | return None 244 | 245 | def write_data(self, data_buf, data_type): 246 | logger.debug(f"write_data() called with data_buf: {data_buf}, data_type: {data_type}") 247 | err = self.maybe_send_header() 248 | if err: 249 | return 0, err 250 | written = 0 251 | while data_buf: 252 | data_msg = DataOrExtendedDataMessage(data_type, "") 253 | empty_msg_len = data_msg.length() 254 | msg_len = min(self.channel_info.max_packet_size - empty_msg_len, len(data_buf)) 255 | 256 | data_msg.data = data_buf[:msg_len] 257 | data_buf = data_buf[msg_len:] 258 | 259 | msg_buf = data_msg.write() 260 | n, err = self.send.write(msg_buf) 261 | written += n 262 | if err: 263 | return written, err 264 | return written, None 265 | 266 | def confirm_channel(self, max_packet_size): 267 | logger.debug(f"confirm_channel() called with max_packet_size: {max_packet_size}") 268 | err = 
self.send_message(ssh3.ChannelOpenConfirmationMessage(max_packet_size)) 269 | if not err: 270 | self.confirm_sent = True 271 | return err 272 | 273 | def send_message(self, message): 274 | logger.debug(f"send_message() called with message: {message}") 275 | err = self.maybe_send_header() 276 | if err: 277 | return err 278 | buf = message.write() 279 | self.send.write(buf) 280 | return None 281 | 282 | def wait_add_datagram(self, datagram): 283 | logger.debug(f"wait_add_datagram() called with datagram: {datagram}") 284 | return self.datagrams_queue.wait_add(datagram) 285 | 286 | def add_datagram(self, datagram): 287 | logger.debug(f"add_datagram() called with datagram: {datagram}") 288 | return self.datagrams_queue.add(datagram) 289 | 290 | def receive_datagram(self): 291 | logger.debug("receive_datagram() called") 292 | return self.datagrams_queue.wait_next() 293 | 294 | def send_datagram(self, datagram): 295 | logger.debug(f"send_datagram() called with datagram: {datagram}") 296 | self.maybe_send_header() 297 | if not self.datagram_sender: 298 | return SentDatagramOnNonDatagramChannel(self.channel_id()) 299 | return self.datagram_sender(datagram) 300 | 301 | def send_request(self, request): 302 | logger.debug(f"send_request() called with request: {request}") 303 | # TODO: make it thread safe 304 | return self.send_message(request) 305 | 306 | def cancel_read(self): 307 | logger.debug("cancel_read() called") 308 | self.recv.cancel_read(42) 309 | 310 | def close(self): 311 | logger.debug("close() called") 312 | self.send.close() 313 | 314 | def max_packet_size(self): 315 | logger.debug("max_packet_size() called") 316 | return self.channel_info.max_packet_size 317 | 318 | def channel_type(self): 319 | logger.debug("channel_type() called") 320 | return self.channel_info.channel_type 321 | 322 | def set_datagram_sender(self, datagram_sender): 323 | logger.debug(f"set_datagram_sender() called with datagram_sender: {datagram_sender}") 324 | self.datagram_sender = 
datagram_sender 325 | 326 | def set_datagram_queue(self, queue): 327 | logger.debug(f"set_datagram_queue() called with queue: {queue}") 328 | self.datagrams_queue = queue 329 | 330 | class UDPForwardingChannelImpl(ChannelImpl): 331 | def __init__(self, channel_info, remote_addr): 332 | logger.debug(f"Creating UDPForwardingChannelImpl object with channel_info: {channel_info}, remote_addr: {remote_addr}") 333 | super().__init__(channel_info) 334 | self.remote_addr = remote_addr 335 | # Additional initialization 336 | 337 | class TCPForwardingChannelImpl(ChannelImpl): 338 | def __init__(self, channel_info, remote_addr): 339 | logger.debug(f"Creating TCPForwardingChannelImpl object with channel_info: {channel_info}, remote_addr: {remote_addr}") 340 | super().__init__(channel_info) 341 | self.remote_addr = remote_addr 342 | # Additional initialization 343 | 344 | def build_header(conversation_stream_id: int, channel_type: str, max_packet_size: int, additional_bytes: Optional[bytes]) -> bytes: 345 | logger.debug(f"build_header() called with conversation_stream_id: {conversation_stream_id}, channel_type: {channel_type}, max_packet_size: {max_packet_size}, additional_bytes: {additional_bytes}") 346 | channel_type_buf = util.write_ssh_string(channel_type) 347 | buf = wire.append_varint(b'', SSH_FRAME_TYPE) 348 | buf += wire.append_varint(buf, conversation_stream_id) 349 | buf += channel_type_buf 350 | buf += wire.append_varint(buf, max_packet_size) 351 | if additional_bytes: 352 | buf += additional_bytes 353 | return buf 354 | 355 | 356 | def build_forwarding_channel_additional_bytes(remote_addr: ipaddress.IPv4Address, port: int) -> bytes: 357 | logger.debug(f"build_forwarding_channel_additional_bytes() called with remote_addr: {remote_addr}, port: {port}") 358 | buf = b'' 359 | address_family = stype.SSHAFIpv4 if len(remote_addr) == 4 else stype.SSHAFIpv6 360 | buf += wire.append_varint(buf, address_family) 361 | buf += remote_addr 362 | port_buf = struct.pack('>H', port) 
# Big-endian format for uint16 363 | buf += port_buf 364 | return buf 365 | 366 | def parse_header(channel_id: int, reader) -> Tuple[int, str, int, Optional[Exception]]: 367 | logger.debug(f"parse_header() called with channel_id: {channel_id}, reader: {reader}") 368 | conversation_control_stream_id, err = quic_util.read_var_int(reader) 369 | if err: 370 | return 0, "", 0, err 371 | channel_type, err = util.parse_ssh_string(reader) 372 | if err: 373 | return 0, "", 0, err 374 | max_packet_size, err = quic_util.read_var_int(reader) 375 | if err: 376 | return 0, "", 0, err 377 | return conversation_control_stream_id, channel_type, max_packet_size, None 378 | 379 | def parse_forwarding_header(channel_id, buf): 380 | logger.debug(f"parse_forwarding_header() called with channel_id: {channel_id}, buf: {buf}") 381 | address_family, err = quic_util.read_var_int(buf) 382 | if err: 383 | return None, 0, err 384 | 385 | if address_family == stype.SSHAFIpv4: 386 | address = buf.read(4) 387 | elif address_family == stype.SSHAFIpv6: 388 | address = buf.read(16) 389 | else: 390 | return None, 0, ValueError(f"Invalid address family: {address_family}") 391 | 392 | port_buf = buf.read(2) 393 | if not port_buf: 394 | return None, 0, ValueError("Port buffer read failed") 395 | 396 | port = struct.unpack('>H', port_buf)[0] # Unpack big-endian uint16 397 | return address, port, None 398 | 399 | 400 | def parse_udp_forwarding_header(channel_id, buf): 401 | logger.debug(f"parse_udp_forwarding_header() called with channel_id: {channel_id}, buf: {buf}") 402 | address, port, err = parse_forwarding_header(channel_id, buf) 403 | if err: 404 | return None, err 405 | return socket.getaddrinfo(address, port, socket.AF_UNSPEC, socket.SOCK_DGRAM), None 406 | 407 | def parse_tcp_forwarding_header(channel_id, buf): 408 | logger.debug(f"parse_tcp_forwarding_header() called with channel_id: {channel_id}, buf: {buf}") 409 | address, port, err = parse_forwarding_header(channel_id, buf) 410 | if err: 411 | 
return None, err 412 | return socket.getaddrinfo(address, port, socket.AF_UNSPEC, socket.SOCK_STREAM), None 413 | 414 | # Define types for SSH channel request handlers 415 | PtyReqHandler = Callable[[Channel, PtyRequest, bool], None] 416 | X11ReqHandler = Callable[[Channel, X11Request, bool], None] 417 | ShellReqHandler = Callable[[Channel, ShellRequest, bool], None] 418 | ExecReqHandler = Callable[[Channel, ExecRequest, bool], None] 419 | SubsystemReqHandler = Callable[[Channel, SubsystemRequest, bool], None] 420 | WindowChangeReqHandler = Callable[[Channel, WindowChangeRequest, bool], None] 421 | SignalReqHandler = Callable[[Channel, SignalRequest, bool], None] 422 | ExitStatusReqHandler = Callable[[Channel, ExitStatusRequest, bool], None] 423 | ExitSignalReqHandler = Callable[[Channel, ExitSignalRequest, bool], None] 424 | 425 | # Define a type for handling SSH channel data 426 | ChannelDataHandler = Callable[[Channel, SSHDataType, str], None] 427 | 428 | # Define an interface for channel close listeners 429 | class ChannelCloseListener: 430 | def onChannelClose(self, channel: Channel): 431 | logger.debug(f"onChannelClose() called with channel: {channel}") 432 | pass 433 | -------------------------------------------------------------------------------- /py-ssh3/ssh3/conversation.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import base64 3 | import ssl 4 | import logging 5 | import contextlib 6 | from typing import Callable, Tuple 7 | import util.util as util 8 | from ssh3.version import parse_version 9 | from http3.http3_client import * 10 | from ssh3.resources_manager import * 11 | import secrets 12 | 13 | log = logging.getLogger(__name__) 14 | 15 | class ConversationID: 16 | def __init__(self, value: bytes): 17 | self.value = value # 32 bytes 18 | assert len(value) <= 32 19 | log.debug(f"ConversationID object created with value: {value}") 20 | 21 | def __str__(self): 22 | return 
base64.b64encode(self.value).decode('utf-8') 23 | 24 | from ssh3.channel import * 25 | 26 | 27 | def random_bytes(length: int) -> bytes: 28 | return secrets.token_bytes(length) 29 | 30 | 31 | def generate_conversation_id(tls_connection_state: ssl.SSLObject) -> Tuple[bytes, Exception]: 32 | log.debug("generate_conversation_id function called") 33 | result = random_bytes(32), None 34 | log.debug(f"generate_conversation_id function returned result: {result}") 35 | return result 36 | # try: # TODO 37 | # if not tls_connection_state: 38 | # return b'', Exception("TLS connection state is None") 39 | # key_material = tls_connection_state.export_keying_material("EXPORTER-SSH3", 32) 40 | # if len(key_material) != 32: 41 | # raise ValueError(f"TLS returned a tls-exporter with the wrong length ({len(key_material)} instead of 32)") 42 | # return key_material, None 43 | # except Exception as e: 44 | # return b'', e 45 | 46 | class Conversation: 47 | def __init__(self, control_stream, max_packet_size, default_datagrams_queue_size, stream_creator, message_sender, channels_manager, conversation_id): 48 | self.control_stream = control_stream 49 | self.max_packet_size = max_packet_size 50 | self.default_datagrams_queue_size = default_datagrams_queue_size 51 | self.stream_creator = stream_creator 52 | self.message_sender = message_sender 53 | self.channels_manager = channels_manager 54 | self.context = None # Will be set using context manager 55 | self.cancel_context = None # Will be set using context manager 56 | self.conversation_id = conversation_id 57 | self.channels_accept_queue = None # Set to an appropriate queue type 58 | log.debug(f"Conversation ({self}) object created with control_stream: {control_stream}, max_packet_size: {max_packet_size}, default_datagrams_queue_size: {default_datagrams_queue_size}, stream_creator: {stream_creator}, message_sender: {message_sender}, channels_manager: {channels_manager}, conversation_id: {conversation_id}") 59 | 60 | def __init__(self, 
max_packet_size, default_datagrams_queue_size, tls: ssl.SSLContext): 61 | self.control_stream = None 62 | self.channels_accept_queue = util.AcceptQueue() # Assuming a suitable implementation 63 | self.stream_creator = None 64 | self.max_packet_size = max_packet_size 65 | self.default_datagrams_queue_size = default_datagrams_queue_size 66 | self.channels_manager = ChannelsManager() # Assuming a suitable implementation 67 | self.conversation_id, err = generate_conversation_id(tls) 68 | if err: 69 | log.error(f"could not generate conversation ID: {err}") 70 | raise err 71 | log.debug(f"Conversation ({self}) object created with max_packet_size: {max_packet_size}, default_datagrams_queue_size: {default_datagrams_queue_size}, tls: {tls}") 72 | 73 | def __str__(self) -> str: 74 | return f"Conversation {self.conversation_id}, {self.control_stream}, {self.channels_accept_queue}, {self.stream_creator}, {self.max_packet_size}, {self.default_datagrams_queue_size}, {self.channels_manager}" 75 | 76 | async def accept_channel(self): 77 | while True: 78 | if not self.channels_accept_queue.empty(): 79 | channel = await self.channels_accept_queue.get() 80 | channel.confirm_channel(self.max_packet_size) 81 | self.channels_manager.add_channel(channel) 82 | return channel 83 | else: 84 | await asyncio.sleep(0.1) # Small delay to prevent busy waiting 85 | 86 | 87 | async def establish_client_conversation(self, request:HttpRequest, round_tripper: RoundTripper): 88 | log.debug(f"establish_client_conversation function called with request: {request}, round_tripper: {round_tripper}") 89 | # Stream hijacker 90 | def stream_hijacker(frame_type, stream_id, data, end_stream): 91 | # Your stream hijacking logic 92 | """ 93 | Process data received on a hijacked stream. 
94 | 95 | :param frame_type: The type of frame received (inferred from the data) 96 | :param stream_id: The ID of the stream 97 | :param data: The data received on the stream 98 | :param end_stream: Flag indicating if the stream has ended 99 | """ 100 | log.debug(f"Stream hijacker called with frame_type: {frame_type}, stream_id: {stream_id}, data: {data}, end_stream: {end_stream}") 101 | if frame_type != SSH_FRAME_TYPE: 102 | # If the frame type is not what we're interested in, ignore it 103 | return False, None 104 | try: 105 | # Parse the header from the data 106 | control_stream_id, channel_type, max_packet_size = parse_header(stream_id, data) 107 | # Create a new channel 108 | channel_info = ChannelInfo( 109 | conversation_id=self.conversation_id, 110 | conversation_stream_id=control_stream_id, 111 | channel_id=stream_id, 112 | channel_type=channel_type, 113 | max_packet_size=max_packet_size 114 | ) 115 | new_channel = ChannelImpl( 116 | channel_info.conversation_stream_id, 117 | channel_info.conversation_id, 118 | channel_info.channel_id, 119 | channel_info.channel_type, 120 | channel_info.max_packet_size, 121 | stream_reader=None, # Replace with the actual stream reader 122 | stream_writer=None, # Replace with the actual stream writer 123 | channels_manager=self.channels_manager, 124 | default_datagrams_queue_size=self.default_datagrams_queue_size 125 | ) 126 | # Set the datagram sender and add the new channel to the queue 127 | new_channel.set_datagram_sender(self.get_datagram_sender_for_channel(new_channel.channel_id)) 128 | self.channels_accept_queue.add(new_channel) 129 | return True, None 130 | except Exception as e: 131 | # Log the error and return False with the error 132 | log.error(f"Error in stream hijacker: {e}") 133 | return False, e 134 | 135 | # Assigning the hijacker to the round_tripper 136 | round_tripper.hijack_stream = stream_hijacker 137 | 138 | log.debug(f"Establishing conversation with server: {request}") 139 | response = await 
round_tripper.round_trip_opt(request=request, opt=RoundTripOpt(dont_close_request_stream=True)) 140 | 141 | log.debug(f"Established conversation with server: {response}") 142 | # Performing the HTTP request 143 | for http_event in response: 144 | if isinstance(http_event, HeadersReceived): 145 | log.debug(f"Established conversation with server: {http_event}") 146 | server_version = -1 147 | major, minor, patch = -1,-1,-1 148 | status = 500 149 | for h,v in http_event.headers: 150 | log.debug(f"Established conversation with server: {h} : {v}") 151 | if b':status' in h: 152 | status = int(v.decode('utf-8')) 153 | elif b'server' in h: 154 | server_version = v.decode('utf-8') 155 | major, minor, patch = parse_version(server_version) 156 | log.debug(f"Established conversation with server: {server_version}") 157 | if major == -1 or minor == -1 or patch == -1: 158 | raise Exception(f"Could not parse server version: {server_version}") 159 | if status == 200: # TODO 160 | self.control_stream = http_event.http_stream 161 | self.stream_creator = http_event.stream_creator 162 | self.message_sender = http_event.http_connection._quic 163 | await self.handle_datagrams(round_tripper) 164 | return None 165 | elif status == 401: 166 | raise Exception("Authentication failed") 167 | else: 168 | raise Exception(f"Returned non-200 and non-401 status code: {status}") 169 | 170 | async def handle_datagrams(self, connection): 171 | log.debug("handle_datagrams function called with connection: {connection}") 172 | while True: 173 | try: 174 | datagram = await connection.datagram_received() 175 | # Process datagram 176 | # ... 177 | except asyncio.CancelledError: 178 | break 179 | 180 | async def close(self): 181 | log.debug("close function called") 182 | # Close the conversation 183 | # ... 
184 | # self.control_stream.close() # TODO not implemented in aioquic 185 | self.control_stream = None 186 | 187 | async def add_datagram(self,dgram): 188 | log.debug(f"add_datagram function called with dgram: {dgram}") 189 | # Add a datagram to the conversation 190 | # ... 191 | self.message_sender.send_datagram(dgram) 192 | 193 | async def new_client_conversation(max_packet_size, queue_size, tls_state): 194 | log.debug(f"new_client_conversation function called with max_packet_size: {max_packet_size}, queue_size: {queue_size}, tls_state: {tls_state}") 195 | # Additional logic for creating a new client conversation 196 | result = Conversation(max_packet_size, queue_size, tls_state) 197 | log.debug(f"new_client_conversation function returned result: {result}") 198 | return result 199 | 200 | async def new_server_conversation(max_packet_size, queue_size, tls_state,control_stream,stream_creator): 201 | log.debug(f"new_client_conversation function called with max_packet_size: {max_packet_size}, queue_size: {queue_size}, tls_state: {tls_state}") 202 | # Additional logic for creating a new client conversation 203 | result = Conversation(max_packet_size, queue_size, tls_state) 204 | result.control_stream = control_stream 205 | result.stream_creator = stream_creator 206 | log.debug(f"new_client_conversation function returned result: {result}") 207 | return result -------------------------------------------------------------------------------- /py-ssh3/ssh3/identity.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | class Identity(ABC): 4 | @abstractmethod 5 | def set_authorization_header(self, request, username: str, conversation): 6 | pass 7 | 8 | @abstractmethod 9 | def auth_hint(self) -> str: 10 | pass 11 | 12 | @abstractmethod 13 | def __str__(self): 14 | pass 15 | -------------------------------------------------------------------------------- /py-ssh3/ssh3/known_host.py: 
import base64
import binascii
import os
import ssl
from typing import Dict, List, Optional, Tuple


class InvalidKnownHost(Exception):
    """Raised for a malformed line of a known-hosts file."""

    def __init__(self, line: str):
        super().__init__(line)
        self.line = line

    def __str__(self):
        return f"invalid known host line: {self.line}"


def parse_known_hosts(filename: str) -> Tuple[Dict[str, List[str]], List[int], Optional[Exception]]:
    """Parse a known-hosts file made of ``<host> x509-certificate <base64 DER>`` lines.

    :param filename: path of the known-hosts file; a missing file means "no known host".
    :returns: ``(known_hosts, invalid_lines, error)`` where ``known_hosts`` maps each host
        to the PEM-encoded certificates listed for it, ``invalid_lines`` holds the
        0-based indices of lines that could not be parsed, and ``error`` is always ``None``.
    """
    known_hosts: Dict[str, List[str]] = {}
    invalid_lines: List[int] = []

    if not os.path.exists(filename):
        # The known hosts file simply does not exist yet, so there is no known host
        return known_hosts, invalid_lines, None

    with open(filename, "r", encoding="utf-8") as file:
        for i, line in enumerate(file):
            fields = line.split()
            if len(fields) != 3 or fields[1] != "x509-certificate":
                invalid_lines.append(i)
                continue

            host, _, encoded_cert = fields
            try:
                # validate=True rejects non-base64 characters instead of silently dropping them.
                cert_bytes = base64.b64decode(encoded_cert, validate=True)
            except (binascii.Error, ValueError):
                invalid_lines.append(i)
                continue
            known_hosts.setdefault(host, []).append(ssl.DER_cert_to_PEM_cert(cert_bytes))

    return known_hosts, invalid_lines, None


def append_known_host(filename: str, host: str, cert) -> Optional[Exception]:
    """Append ``<host> x509-certificate <base64 DER>`` to the known-hosts file.

    :param cert: the certificate as DER bytes, as a PEM string (the format returned by
        :func:`parse_known_hosts`), or an object exposing ``public_bytes(encoding)``
        such as ``cryptography.x509.Certificate``.
    :returns: the exception raised while writing the file, or ``None`` on success.
        Errors while encoding *cert* are raised, not returned.
    """
    if isinstance(cert, (bytes, bytearray, memoryview)):
        der = bytes(cert)
    elif isinstance(cert, str):
        der = ssl.PEM_cert_to_DER_cert(cert)
    else:
        # public_bytes() requires an explicit encoding; cryptography is already a
        # project dependency (util.util), imported lazily so this module stays light.
        from cryptography.hazmat.primitives.serialization import Encoding
        der = cert.public_bytes(Encoding.DER)
    encoded_cert = base64.b64encode(der).decode("ascii")

    try:
        with open(filename, "a", encoding="utf-8") as known_hosts:
            known_hosts.write(f"{host} x509-certificate {encoded_cert}\n")
    except Exception as e:
        return e

    return None
# ---- /py-ssh3/ssh3/resources_manager.py ----
import asyncio

# Assuming util and http3 modules are available in your Python environment
# or you have equivalent functionality
# implemented in Python
import util
import http3

# type ControlStreamID = uint64


class ConversationsManager:
    """Registry of the SSH3 conversations that share one connection.

    Conversations are keyed by ``conversation.control_stream.stream_id``; every
    access is serialized through an :class:`asyncio.Lock`.
    """

    def __init__(self, connection):
        # Stored as-is; ssh3_server passes its stream creator here.
        self.connection = connection
        self.conversations = {}
        self.lock = asyncio.Lock()

    async def add_conversation(self, conversation):
        """Register *conversation* under its control-stream ID, replacing any previous entry."""
        async with self.lock:
            self.conversations[conversation.control_stream.stream_id] = conversation

    async def get_conversation(self, id):
        """Return the conversation registered under control-stream ID *id*, or ``None``."""
        async with self.lock:
            return self.conversations.get(id, None)

    async def remove_conversation(self, conversation):
        """Forget *conversation*; a no-op when it is not registered.

        Reads ``conversation.control_stream``, so call it before the conversation is
        closed (``Conversation.close`` in ssh3/conversation.py resets it to ``None``).
        """
        async with self.lock:
            self.conversations.pop(conversation.control_stream.stream_id, None)


class ChannelsManager:
    """Registry of channels keyed by ``channel_id``, plus datagram queues that arrived early.

    A datagram queue received for a channel ID that is not registered yet is parked in
    ``dangling_dgram_queues`` and handed to the channel once it is added.
    """

    def __init__(self):
        self.channels = {}
        self.dangling_dgram_queues = {}
        self.lock = asyncio.Lock()

    async def add_channel(self, channel):
        """Register *channel*, attaching any datagram queue parked for its ID."""
        async with self.lock:
            dgrams_queue = self.dangling_dgram_queues.pop(channel.channel_id, None)
            # Explicit None check: a queue must be handed over even if it would test falsy.
            if dgrams_queue is not None:
                channel.set_dgram_queue(dgrams_queue)
            self.channels[channel.channel_id] = channel

    async def add_dangling_datagrams_queue(self, id, queue):
        """Deliver *queue*'s datagrams to channel *id*, or park the queue until that channel exists."""
        async with self.lock:
            channel = self.channels.get(id)
            if channel is None:
                self.dangling_dgram_queues[id] = queue
                return
            # Drain the queue; queue.next() is expected to return None once it is empty.
            while True:
                dgram = queue.next()
                if dgram is None:
                    break
                channel.add_datagram(dgram)

    async def get_channel(self, id):
        """Return the channel registered under *id*, or ``None``."""
        async with self.lock:
            return self.channels.get(id, None)

    async def remove_channel(self, channel):
        """Forget *channel*; a no-op when it is not registered."""
        async with self.lock:
            self.channels.pop(channel.channel_id, None)
# ---- /py-ssh3/ssh3/ssh3_client.py ----
import os
import base64
import jwt
import paramiko
from typing import Tuple, List
from ssh3.identity import Identity
import time
from util.util import jwt_signing_method_from_crypto_pubkey
import logging

logger = logging.getLogger(__name__)


class OIDCAuthMethod:
    """Authentication method backed by an OpenID Connect configuration."""

    def __init__(self, do_pkce: bool, config):
        self.do_pkce = do_pkce
        self.config = config

    def oidc_config(self):
        return self.config

    def into_identity(self, bearer_token: str) -> 'Identity':
        """Wrap an already-obtained bearer token into an identity."""
        return RawBearerTokenIdentity(bearer_token)


class PasswordAuthMethod:
    """Password authentication method."""

    def __init__(self):
        pass

    def into_identity(self, password: str) -> 'Identity':
        return PasswordBasedIdentity(password)


class PrivkeyFileAuthMethod:
    """Public-key authentication using a private key read from a file."""

    def __init__(self, filename: str):
        # NOTE(review): this instance attribute shadows the filename() method below,
        # so on instances ``filename`` is the path string, not a callable.
        self.filename: str = filename

    def filename(self) -> str:
        return self.filename

    def into_identity(self, password: str) -> 'Identity':
        """Load the private key file and wrap it in a :class:`PrivkeyFileIdentity`.

        ``~`` in the path is expanded. OpenSSH-format keys and PEM keys are accepted.

        :param password: passphrase of the key, or ``None`` for an unencrypted key.
        :raises OSError: if the key file cannot be read.
        :raises ValueError: if the key data cannot be decoded (e.g. wrong passphrase).
        :raises TypeError: if a passphrase is missing for an encrypted key (or given for a plain one).
        :raises Exception: the error reported by ``jwt_signing_method_from_crypto_pubkey``
            when the key type has no JWT signing method.
        """
        # cryptography is already a dependency of util.util; imported lazily here.
        from cryptography.hazmat.primitives import serialization

        path = os.path.expanduser(self.filename)
        with open(path, "rb") as key_file:
            key_data = key_file.read()
        passphrase = password.encode("utf-8") if isinstance(password, str) else password

        try:
            # OpenSSH-format keys need the dedicated loader; anything else is tried as PEM.
            if key_data.lstrip().startswith(b"-----BEGIN OPENSSH PRIVATE KEY-----"):
                private_key = serialization.load_ssh_private_key(key_data, password=passphrase)
            else:
                private_key = serialization.load_pem_private_key(key_data, password=passphrase)
        except (ValueError, TypeError) as e:
            logger.error(f"Failed to load private key file {path}: {e}")
            raise

        signing_method, err = jwt_signing_method_from_crypto_pubkey(private_key.public_key())
        if err:
            logger.error(f"Failed to load private key file {path}: {err}")
            raise err
        return PrivkeyFileIdentity(private_key, signing_method)

    def into_identity_without_passphrase(self) -> 'Identity':
        """Load the key file assuming it is not passphrase-protected."""
        return self.into_identity_with_passphrase(None)

    def into_identity_with_passphrase(self, passphrase: str) -> 'Identity':
        """Load the key file, decrypting it with *passphrase*."""
        return self.into_identity(passphrase)


class AgentAuthMethod:
    """Public-key authentication delegated to an SSH agent."""

    def __init__(self, pubkey):
        self.pubkey = pubkey

    def into_identity(self, agent):
        return AgentBasedIdentity(self.pubkey, agent)


class AgentBasedIdentity(Identity):
    def __init__(self, pubkey, agent: paramiko.Agent):
        self.pubkey = pubkey
        self.agent = agent

    def set_authorization_header(self, req, username: str, conversation):
        # Not implemented yet: the request is left unchanged.
        # TODO: sign with the SSH agent and set the Authorization header.
        pass

    def auth_hint(self) -> str:
        return "pubkey"

    def __str__(self):
        return f"agent-identity: {self.pubkey.get_name()}"


class PasswordBasedIdentity(Identity):
    def __init__(self, password: str):
        self.password = password

    def set_authorization_header(self, req, username: str, conversation):
        """Set HTTP Basic credentials ``base64(username:password)`` on *req*."""
        req.headers['authorization'] = f"Basic {base64.b64encode(f'{username}:{self.password}'.encode('utf-8')).decode('utf-8')}"

    def auth_hint(self):
        return "password"

    def __str__(self):
        return "password-identity"


class PrivkeyFileIdentity(Identity):
    def __init__(self, private_key, algorithm):
        self.private_key = private_key
        self.algorithm = algorithm

    def set_authorization_header(self, req, username: str, conversation):
        """Set ``Authorization: Bearer <JWT>`` signed with the private key.

        :raises Exception: whatever :func:`build_jwt_bearer_token` raises.
        """
        bearer_token = build_jwt_bearer_token(self.algorithm, self.private_key, username, conversation)
        req.headers['authorization'] = f'Bearer {bearer_token}'

    def auth_hint(self):
        return "pubkey"

    def __str__(self):
        return f"pubkey-identity: ALG={self.algorithm}"


class RawBearerTokenIdentity(Identity):
    def __init__(self, bearer_token: str):
        self.bearer_token = bearer_token

    def set_authorization_header(self, req, username: str, conversation):
        req.headers['authorization'] = f"Bearer {self.bearer_token}"

    def auth_hint(self) -> str:
        return "jwt"

    def __str__(self):
        return "raw-bearer-identity"


def build_jwt_bearer_token(signing_method, key, username: str, conversation) -> str:
    """Build a short-lived (10 s) JWT binding *username* to *conversation*.

    The base64 conversation ID is used as ``jti``.

    :raises Exception: any error from computing the conversation ID or signing is
        logged and re-raised (previously an ``(None, message)`` tuple was returned
        and ended up in the Authorization header).
    """
    try:
        conv_id = conversation.conversation_id()
        b64_conv_id = base64.b64encode(conv_id).decode('utf-8')

        now = int(time.time())
        claims = {
            "iss": username,
            "iat": now,
            "exp": now + 10,  # Token expiration 10 seconds from now
            "sub": "ssh3",
            "aud": "unused",
            "client_id": f"ssh3-{username}",
            "jti": b64_conv_id,
        }

        token = jwt.encode(claims, key, algorithm=signing_method)
        # PyJWT < 2 returns bytes; normalise to str for the header.
        return token.decode("utf-8") if isinstance(token, bytes) else token
    except Exception as e:
        logger.error(f"Failed to build JWT bearer token: {e}")
        raise


def get_config_for_host(host: str, config) -> Tuple[str, int, str, List]:
    """Return ``(hostname, port, user, identity_file_paths)`` for *host* from an SSH config.

    *config* is expected to provide ``lookup(host)`` like ``paramiko.SSHConfig``. Existing
    IdentityFile paths are returned with ``~`` expanded. Without a config, the result is
    ``(None, -1, None, [])``; a missing port yields ``-1``.
    """
    if config is None:
        return None, -1, None, []

    options = config.lookup(host)
    hostname = options.get("hostname", host)
    port = int(options.get("port", -1))
    user = options.get("user")

    # paramiko lowercases option names, so "IdentityFile" is stored as "identityfile".
    identity_files = options.get("identityfile", options.get("IdentityFile", []))
    if isinstance(identity_files, str):
        identity_files = [identity_files]

    auth_methods_to_try = []
    for identity_file in identity_files:
        identity_file_path = os.path.expanduser(identity_file)
        if os.path.exists(identity_file_path):
            auth_methods_to_try.append(identity_file_path)

    return hostname, port, user, auth_methods_to_try
# ---- /py-ssh3/ssh3/ssh3_server.py ----
import asyncio
import io
import logging
import random
from typing import Callable, Tuple
from ssh3.conversation import Conversation, ConversationsManager
from http3.http3_server import *
from ssh3.resources_manager import *
import util.util as util
import util.quic_util as quic_util
from ssh3.version import parse_version
from ssh3.channel import *
from starlette.responses import PlainTextResponse, Response
from aioquic.quic.connection import NetworkAddress, QuicConnection
from util.type import ChannelNotFound
log = logging.getLogger(__name__)

# ChannelNotFound is declared in util/type.py; accept util.util's too if it re-exports one.
_CHANNEL_NOT_FOUND = (ChannelNotFound, getattr(util, "ChannelNotFound", ChannelNotFound))


class SSH3Server:
    """Accepts SSH3 conversations (HTTP/3 CONNECT with scheme ``ssh3``) and routes their datagrams."""

    def __init__(self, max_packet_size,
                 h3_server: HttpServerProtocol,
                 default_datagram_queue_size,
                 conversation_handler, *args, **kwargs):
        # super().__init__(*args, **kwargs)
        self.h3_server = h3_server
        self.max_packet_size = max_packet_size
        self.conversations = {}  # Map of StreamCreator to ConversationManager
        self.conversation_handler = conversation_handler
        self.lock = asyncio.Lock()
        self.new_conv = None
        # Strong references to background tasks: asyncio keeps only weak ones.
        self._background_tasks = set()
        log.debug("SSH3Server initialized")
        # self.h3_server._stream_handler = self.stream_hijacker

    def stream_hijacker(self, frame_type, stream_id, data, end_stream):
        """
        Process data received on a hijacked stream.

        NOTE(review): not registered anywhere in this module (the assignments are
        commented out). Before wiring it in: ``get_conversations_manager`` is a
        coroutine that also needs a ``stream_creator`` argument, and ``stream`` used
        for the forwarding headers is not defined in this scope.

        :param frame_type: The type of frame received (inferred from the data)
        :param stream_id: The ID of the stream
        :param data: The data received on the stream
        :param end_stream: Flag indicating if the stream has ended
        :returns: ``(handled, error)``
        """
        log.debug(f"Stream hijacker called with frame_type: {frame_type}, stream_id: {stream_id}, data: {data}, end_stream: {end_stream}")
        if frame_type != SSH_FRAME_TYPE:
            # If the frame type is not what we're interested in, ignore it
            return False, None

        try:
            conversation_control_stream_id, channel_type, max_packet_size = parse_header(stream_id, data)
            conversations_manager = self.get_conversations_manager()
            conversation = conversations_manager.get_conversation(conversation_control_stream_id)  # Implement this function
            if conversation is None:
                err = Exception(f"Could not find SSH3 conversation with control stream id {conversation_control_stream_id} for new channel {stream_id}")
                log.error(str(err))
                return False, err

            channel_info = ChannelInfo(
                conversation_id=conversation.conversation_id,
                conversation_stream_id=conversation_control_stream_id,
                channel_id=stream_id,
                channel_type=channel_type,
                max_packet_size=max_packet_size
            )

            new_channel = ChannelImpl(
                channel_info.conversation_stream_id,
                channel_info.conversation_id,
                channel_info.channel_id,
                channel_info.channel_type,
                channel_info.max_packet_size,
                stream_reader=None,  # Replace with the actual stream reader
                stream_writer=None,  # Replace with the actual stream writer
                channels_manager=conversation.channels_manager,
                default_datagrams_queue_size=conversation.default_datagrams_queue_size
            )

            if channel_type == "direct-udp":
                udp_addr = parse_udp_forwarding_header(channel_info.channel_id, stream)
                new_channel.set_datagram_sender(conversation.get_datagram_sender_for_channel(channel_info.channel_id))
                new_channel = UDPForwardingChannelImpl(new_channel, udp_addr)
            elif channel_type == "direct-tcp":
                tcp_addr = parse_tcp_forwarding_header(channel_info.channel_id, stream)
                new_channel = TCPForwardingChannelImpl(new_channel, tcp_addr)

            conversation.channels_accept_queue.add(new_channel)
            return True, None
        except Exception as e:
            log.error(f"Error in stream hijacker: {e}")
            return False, e

    async def get_conversations_manager(self, stream_creator):
        """Return the ConversationsManager of *stream_creator*, or ``None``."""
        async with self.lock:
            return self.conversations.get(stream_creator, None)

    async def get_or_create_conversations_manager(self, stream_creator):
        """Return the ConversationsManager of *stream_creator*, creating it on first use."""
        async with self.lock:
            if stream_creator not in self.conversations:
                self.conversations[stream_creator] = ConversationsManager(stream_creator)
            return self.conversations[stream_creator]

    async def remove_connection(self, stream_creator):
        """Drop the ConversationsManager of *stream_creator*, if any."""
        async with self.lock:
            self.conversations.pop(stream_creator, None)

    async def handle_datagrams(self, event: H3Event):
        """Route a received QUIC datagram to the current conversation (``self.new_conv``).

        The datagram must start with a varint holding the conversation's control-stream ID;
        the rest is passed to ``new_conv.add_datagram``. Events other than
        ``DatagramFrameReceived`` are ignored. Errors are logged, never raised.
        """
        log.debug(f"SSH3 server received datagram event: {event}")
        if not isinstance(event, DatagramFrameReceived):
            return
        conv = self.new_conv
        if conv is None:
            log.error("Discarding datagram received before any SSH3 conversation was set up")
            return
        try:
            dgram = event.data
            buf = io.BytesIO(dgram)
            try:
                # read_var_int returns the decoded integer itself, not an (int, err) pair.
                conv_id = quic_util.read_var_int(buf)
            except Exception as err:
                log.error(f"Could not read conv id from datagram: {err}")
                return

            if conv_id == conv.control_stream.stream_id:
                try:
                    # buf.tell() is the number of bytes taken by the conversation-ID prefix.
                    await conv.add_datagram(dgram[buf.tell():])
                except _CHANNEL_NOT_FOUND as e:
                    log.warning(f"Could not find channel {e.channel_id}, queuing datagram in the meantime")
                except Exception as e:
                    log.error(f"Could not add datagram to conv id {conv.control_stream.stream_id}: {e}")
                    return
            else:
                log.error(f"Discarding datagram with invalid conv id {conv_id}")

        except asyncio.CancelledError:
            # Handling cancellation of the datagram listener
            return
        except Exception as e:
            if not isinstance(e, ConnectionError):
                log.error(f"Could not receive message from connection: {e}")
            return

    async def manage_conversation(self, server, authenticated_username, new_conv, conversations_manager, stream_creator):
        """Run the conversation handler, then unregister and close the conversation.

        Afterwards the whole ConversationsManager of *stream_creator* is dropped.
        NOTE(review): this happens even if other conversations still share that stream creator.
        """
        try:
            log.debug(f"Managing conversation: {new_conv.conversation_id}, user {authenticated_username} and stream creator {stream_creator}")
            await self.conversation_handler(authenticated_username, new_conv)

        except asyncio.CancelledError:
            log.info(f"Conversation canceled for conversation id {new_conv.conversation_id}, user {authenticated_username}")
        except Exception as e:
            log.error(f"Error while handling new conversation: {new_conv.conversation_id} for user {authenticated_username}: {e}")

        finally:
            # Unregister before closing: close() resets new_conv.control_stream.
            if conversations_manager:
                await conversations_manager.remove_conversation(new_conv)
            if new_conv:
                await new_conv.close()
            if stream_creator:
                await server.remove_connection(stream_creator)

    def get_http_handler_func(self):
        """
        Returns a handler function for authenticated HTTP requests.

        ``CONNECT`` requests whose ASGI scope scheme is ``ssh3`` are answered 200.
        The conversation is registered, datagram routing is installed, and the
        conversation handler runs in a background task. Any other request gets 404.
        """
        async def handler(authenticated_username, new_conv, request):
            log.info(f"Got auth request: {request}")
            log.debug(f"request: {dir(request)}")
            log.debug(f"request.url: {request.url}")
            log.debug(f"request.header: {request.headers}")
            if request.method == "CONNECT" and request.scope.get("scheme", None) == "ssh3":  # request.url.scheme == "ssh3": TODO
                # NOTE(review): `glob` is expected to come from `from http3.http3_server import *`;
                # the most recently registered QUIC protocol is assumed to serve this request.
                protocols_keys = list(glob.QUIC_SERVER._protocols.keys())
                prot = glob.QUIC_SERVER._protocols[protocols_keys[-1]]
                hijacker = prot.hijacker  # self.h3_server.hijacker
                stream_creator = hijacker.stream_creator()
                self.new_conv = new_conv
                conversations_manager = await self.get_or_create_conversations_manager(stream_creator)
                await conversations_manager.add_conversation(new_conv)

                # NOTE(review): aioquic's QuicConnectionProtocol calls quic_event_received without
                # awaiting it; if HttpServerProtocol does the same, this coroutine function never runs
                # and the original event handling is replaced. Confirm against http3_server.
                self.h3_server.quic_event_received = self.handle_datagrams

                task = asyncio.create_task(self.manage_conversation(server=self,
                                                                    authenticated_username=authenticated_username,
                                                                    new_conv=new_conv,
                                                                    conversations_manager=conversations_manager,
                                                                    stream_creator=stream_creator))
                self._background_tasks.add(task)
                task.add_done_callback(self._background_tasks.discard)

                return Response(status_code=200)
            log.error(f"Invalid request: {request.headers}, {request.scope}")
            return Response(status_code=404)

        return handler
# ---- /py-ssh3/ssh3/version.py ----

MAJOR = 0
MINOR = 1
PATCH = 0


class InvalidSSHVersion(Exception):
    """Raised when a version string is not ``SSH 3.0 <software> <major>.<minor>.<patch>``."""

    def __init__(self, version_string):
        super().__init__(version_string)
        self.version_string = version_string

    def __str__(self):
        return f"invalid ssh version string: {self.version_string}"


class UnsupportedSSHVersion(Exception):
    """Signals an SSH version that is not supported (not raised in this module)."""

    def __init__(self, version_string):
        super().__init__(version_string)
        self.version_string = version_string

    def __str__(self):
        return f"unsupported ssh version: {self.version_string}"


def get_current_version():
    """Return this implementation's version string, e.g. ``SSH 3.0 ElNiak/py-ssh3 0.1.0``."""
    return f"SSH 3.0 ElNiak/py-ssh3 {MAJOR}.{MINOR}.{PATCH}"


def parse_version(version):
    """Parse ``SSH 3.0 <software> <major>.<minor>.<patch>`` into ``(major, minor, patch)``.

    :raises InvalidSSHVersion: if the string does not have exactly that shape, or any
        version component is not a plain non-negative decimal number.
    """
    fields = version.split()
    if len(fields) != 4 or fields[0] != "SSH" or fields[1] != "3.0":
        raise InvalidSSHVersion(version_string=version)

    parts = fields[3].split(".")
    # Plain ASCII digits only: int() alone would also accept "-1", "+1" or "1_0".
    if len(parts) != 3 or not all(part.isascii() and part.isdigit() for part in parts):
        raise InvalidSSHVersion(version_string=version)

    major, minor, patch = (int(part) for part in parts)
    return major, minor, patch
# ---- /py-ssh3/test/__init__.py ----
# https://raw.githubusercontent.com/ElNiak/PySSH3/842429c479551afe8b7f6e219dd515015d53987f/py-ssh3/test/__init__.py
# ---- /py-ssh3/test/integration_test/__init__.py ----
# https://raw.githubusercontent.com/ElNiak/PySSH3/842429c479551afe8b7f6e219dd515015d53987f/py-ssh3/test/integration_test/__init__.py
# ---- /py-ssh3/test/integration_test/ssh3_test.py ----
# https://raw.githubusercontent.com/ElNiak/PySSH3/842429c479551afe8b7f6e219dd515015d53987f/py-ssh3/test/integration_test/ssh3_test.py
# ---- /py-ssh3/test/unit_test/__init__.py ----
# https://raw.githubusercontent.com/ElNiak/PySSH3/842429c479551afe8b7f6e219dd515015d53987f/py-ssh3/test/unit_test/__init__.py
# ---- /py-ssh3/util/__init__.py ----
# https://raw.githubusercontent.com/ElNiak/PySSH3/842429c479551afe8b7f6e219dd515015d53987f/py-ssh3/util/__init__.py
# ---- /py-ssh3/util/globals.py ----
# Module-level settings and singletons; all start unset and are assigned elsewhere.
from typing import Any, Callable, Optional

APPLICATION = None
ENABLE_PASSWORD_LOGIN = False
DEFAULT_MAX_PACKET_SIZE = 30000
# A callable handler (not set in this module). Annotated with typing.Callable;
# the builtin ``callable`` is a function, not a type.
HANDLER_FUNC: Optional[Callable[..., Any]] = None
QUIC_SERVER = None
CONFIGURATION = None
SESSION_TICKET_HANDLER = None
# ---- /py-ssh3/util/linux_util/__init__.py ----
# https://raw.githubusercontent.com/ElNiak/PySSH3/842429c479551afe8b7f6e219dd515015d53987f/py-ssh3/util/linux_util/__init__.py
# ---- /py-ssh3/util/linux_util/agent.py ----
import os
import tempfile


def new_unix_socket_path(base_dir: str = "/tmp") -> str:
    """Return a path ``<fresh dir>/agent.<pid>`` for an agent UNIX socket.

    A new private directory is created under *base_dir* with ``tempfile.mkdtemp``;
    the socket file itself is not created.

    :param base_dir: parent directory of the fresh directory (defaults to ``/tmp`` as before).
    """
    socket_dir = tempfile.mkdtemp(prefix="", dir=base_dir)
    return os.path.join(socket_dir, f"agent.{os.getpid()}")
# ---- /py-ssh3/util/linux_util/cmd.py ----
import pty
import subprocess
import os


def start_with_size_and_pty(command, args, ws, login_shell=False):
    """Start ``args`` in a new session attached to a fresh pseudo-terminal.

    :param command: not used directly; the program executed is ``args[0]``.
    :param args: full argv of the program; the caller's list is not modified.
    :param ws: optional window size indexed as ``ws[0]``/``ws[1]``
        (assumed rows/cols — TODO confirm against the winsize module).
    :param login_shell: if true, ``args[0]`` is still executed but argv[0] becomes
        ``-<basename>`` (the login-shell convention).
    :returns: ``(process, master_fd)``; the caller owns ``master_fd``.
    :raises OSError: from ``pty.openpty``/``subprocess.Popen``; both pty descriptors
        are closed before re-raising a failure after ``openpty``.
    """
    import fcntl
    import struct
    import termios

    master, slave = pty.openpty()
    try:
        if ws:
            # The pty module has no set_winsize(); set it with TIOCSWINSZ
            # (struct winsize = rows, cols, xpixel, ypixel).
            fcntl.ioctl(master, termios.TIOCSWINSZ, struct.pack("HHHH", ws[0], ws[1], 0, 0))

        argv = list(args)
        executable = None
        if login_shell:
            # Execute the real program; only the argv[0] seen by it gets the "-" prefix.
            executable = argv[0]
            argv[0] = f"-{os.path.basename(argv[0])}"

        env = os.environ.copy()
        process = subprocess.Popen(argv, executable=executable, start_new_session=True,
                                   stdin=slave, stdout=slave, stderr=slave, env=env)
    except BaseException:
        os.close(master)
        os.close(slave)
        raise
    os.close(slave)
    return process, master
# ---- /py-ssh3/util/linux_util/linux_user.py ----
import ctypes
from pwd import getpwnam
import hmac
import subprocess
import os

# crypt and spwd were removed from the standard library in Python 3.13 (PEP 594);
# import them optionally so the rest of this module keeps working there.
try:
    from crypt import crypt
except ImportError:
    crypt = None
try:
    import spwd
except ImportError:
    spwd = None


class User:
    """Subset of a passwd entry: name, uid, gid, home directory and login shell."""

    def __init__(self, username, uid, gid, dir, shell):
        self.username = username
        self.uid = uid
        self.gid = gid
        self.dir = dir
        self.shell = shell


class ShadowEntry:
    """Subset of a shadow entry: user name and password hash."""

    def __init__(self, username, password):
        self.username = username
        self.password = password


def getspnam(name):
    """Return the :class:`ShadowEntry` of *name* from the shadow database.

    :raises RuntimeError: if the ``spwd`` module is unavailable on this Python.
    :raises KeyError: if *name* has no shadow entry (as ``spwd.getspnam``).
    """
    if spwd is None:
        raise RuntimeError("spwd module unavailable: cannot read the shadow password database")
    entry = spwd.getspnam(name)
    return ShadowEntry(entry.sp_namp, entry.sp_pwdp)


def user_password_authentication(username, password):
    """Return True if *password* matches the shadow password hash of *username*.

    The comparison is constant-time.

    :raises RuntimeError: if ``crypt``/``spwd`` are unavailable on this Python.
    """
    if crypt is None:
        raise RuntimeError("crypt module unavailable: password authentication is not supported")
    shadow_entry = getspnam(username)
    hashed = crypt(password, shadow_entry.password)
    if hashed is None:
        return False
    return hmac.compare_digest(hashed.encode("utf-8"), shadow_entry.password.encode("utf-8"))


def get_user(username):
    """Return a :class:`User` built from the passwd entry of *username* (``KeyError`` if unknown)."""
    pw = getpwnam(username)
    return User(pw.pw_name, pw.pw_uid, pw.pw_gid, pw.pw_dir, pw.pw_shell)


def create_command(user, command, args, login_shell=False):
    """Start ``command args...`` in *user*'s home directory with piped stdio, in a new session.

    :param user: a :class:`User` (or a mapping with a ``"dir"`` key, as previously required).
    :param login_shell: run *command* with argv[0] set to ``-<basename>``.
    :returns: the started :class:`subprocess.Popen`.
    """
    cmd = [command, *args]
    if login_shell:
        cmd[0] = f"-{os.path.basename(command)}"
    cwd = user["dir"] if isinstance(user, dict) else user.dir
    # executable=command so a "-name" argv[0] does not become the program to run.
    return subprocess.Popen(cmd, executable=command, start_new_session=True,
                            stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE, cwd=cwd)
# ---- /py-ssh3/util/quic_util.py ----
# QUIC variable-length integers (RFC 9000, Section 16): the two most significant
# bits of the first byte encode the total length (00=1, 01=2, 10=4, 11=8 bytes).
_VAR_INT_MAX = (1 << 62) - 1


def var_int_len(value):
    """Return the number of bytes needed to QUIC-varint-encode *value*.

    :raises ValueError: if *value* is outside ``[0, 2**62 - 1]``.
    """
    if value < 0 or value > _VAR_INT_MAX:
        raise ValueError(f"value out of range for a QUIC varint: {value}")
    if value <= 0x3F:
        return 1
    if value <= 0x3FFF:
        return 2
    if value <= 0x3FFFFFFF:
        return 4
    return 8


def var_int_to_bytes(value):
    """Encode *value* as a QUIC varint (big-endian, length prefix in the top two bits)."""
    length = var_int_len(value)
    prefix = length.bit_length() - 1  # 1->0, 2->1, 4->2, 8->3
    return (value | (prefix << (8 * length - 2))).to_bytes(length, byteorder='big')


def read_var_int(buf):
    """Read one QUIC varint from *buf* (anything with ``read(n)``) and return it.

    :raises EOFError: if *buf* ends before the integer is complete.
    """
    first = buf.read(1)
    if len(first) != 1:
        raise EOFError("no data to read a QUIC varint from")
    length = 1 << (first[0] >> 6)
    value = first[0] & 0x3F
    if length > 1:
        rest = buf.read(length - 1)
        if len(rest) != length - 1:
            raise EOFError(f"truncated QUIC varint: expected {length} bytes")
        value = (value << (8 * (length - 1))) | int.from_bytes(rest, byteorder='big', signed=False)
    return value
# ---- /py-ssh3/util/type.py ----
import io


# A JWT bearer token, encoded following the JWT specification
class JWTTokenString:
    def __init__(self, token: str):
        self.token = token


class SSHForwardingProtocol:
    def __init__(self, value: int):
        self.value = value


class SSHForwardingAddressFamily:
    def __init__(self, value: int):
        self.value = value


class ChannelID:
    def __init__(self, value: int):
        self.value = value


# SSH forwarding protocols
SSHProtocolUDP = SSHForwardingProtocol(0)
SSHForwardingProtocolTCP = SSHForwardingProtocol(1)

# SSH forwarding address families
SSHAFIpv4 = SSHForwardingAddressFamily(4)
SSHAFIpv6 = SSHForwardingAddressFamily(6)


class UserNotFound(Exception):
    def __init__(self, username):
        # str() so non-string identifiers do not raise TypeError while building the message.
        super().__init__("User not found: " + str(username))
        self.username = username


class ChannelNotFound(Exception):
    def __init__(self, channelID):
        super().__init__("Channel not found: " + str(channelID))
        self.channel_id = channelID


class InvalidSSHString(Exception):
    def __init__(self, reason):
        super().__init__("Invalid SSH string: " + str(reason))
        self.reason = reason


class Unauthorized(Exception):
    def __init__(self):
        super().__init__("Unauthorized")

# class BytesReadCloser(io.BufferedReader):
#     def __init__(self, reader):
#         super().__init__(reader)
#         self.reader = reader

#     def read(self):
#         return None

# Sends an SSH3 datagram. The function must know the ID of the channel.
class SSH3DatagramSenderFunc:
    """Holds the callable (``Func``) that sends an SSH3 datagram.

    The callable must know the ID of the channel the datagram belongs to.
    """

    def __init__(self, func):
        self.Func = func


# MessageSender interface
class MessageSender:
    """Holds the callable (``SendMessage``) used to send SSH3 messages."""

    def __init__(self, send_func):
        self.SendMessage = send_func
# ---- /py-ssh3/util/util.py ----
import datetime
import hashlib
from typing import Tuple, List
import contextlib
import base64
from typing import Any, Callable, Optional
import threading
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey
from cryptography import x509
from cryptography.x509.oid import NameOID
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import rsa
import logging
import jwt
from cryptography.hazmat.backends import default_backend

log = logging.getLogger(__name__)


class UnknownSSHPubkeyType(Exception):
    """Raised when no JWT signing method is known for a public key's type."""

    def __init__(self, pubkey):
        super().__init__(pubkey)
        self.pubkey = pubkey

    def __str__(self):
        return f"unknown signing method: {type(self.pubkey)}"


# copied from "net/http/internal/ascii"
def equal_fold(s: str, t: str) -> bool:
    """Report whether *s* and *t* are equal ignoring ASCII case (Go's ``strings.EqualFold``, ASCII only).

    Each pair of characters is compared through ``lower``.
    """
    log.debug(f"equal_fold: s={s}, t={t}")
    if len(s) != len(t):
        return False
    return all(lower(a) == lower(b) for a, b in zip(s, t))


# lower returns the ASCII lowercase version of b.
def lower(b: bytes) -> bytes:
    """Return *b* with ASCII ``A``-``Z`` mapped to ``a``-``z``; every other character is unchanged.

    Accepts ``str`` or ``bytes`` and returns the same kind. Each character is mapped on
    its own: the old version added 32 to every byte (``b"A1"`` became ``b"aQ"``) and
    raised ``TypeError`` for uppercase ``str`` input.
    """
    log.debug(f"lower: b={b}")
    if isinstance(b, str):
        return "".join(chr(ord(ch) + 32) if "A" <= ch <= "Z" else ch for ch in b)
    return bytes(code + 32 if 65 <= code <= 90 else code for code in b)


def configure_logger(log_level: str) -> None:
    """Configure root logging at *log_level* ("debug", "info", "warning" or "error", any case).

    Unrecognised names fall back to WARNING. The chosen level is also stored in
    ``logging.log_level`` as before (NOTE(review): other modules may read it).
    ``logging.basicConfig`` does nothing if the root logger already has handlers.
    """
    log.debug(f"configure_logger: log_level={log_level}")
    levels = {
        "debug": logging.DEBUG,
        "info": logging.INFO,
        "warning": logging.WARN,
        "error": logging.ERROR,
    }
    logging.log_level = levels.get(log_level.lower(), logging.WARN)
    logging.basicConfig(level=logging.log_level, format="%(asctime)s %(levelname)s %(name)s %(message)s")


class AcceptQueue:
    """Unbounded thread-safe FIFO whose ``next()`` blocks until an item is available."""

    def __init__(self) -> None:
        self.lock = threading.Lock()
        self.c = threading.Condition(self.lock)
        self.queue: List = []

    def add(self, item) -> None:
        """Append *item* and wake one waiting ``next()`` caller."""
        log.debug(f"AcceptQueue.add: item={item}")
        with self.lock:
            self.queue.append(item)
            self.c.notify()

    def next(self):
        """Remove and return the oldest item, blocking the calling thread while the queue is empty."""
        log.debug("AcceptQueue.next")
        with self.lock:
            while not self.queue:
                self.c.wait()
            return self.queue.pop(0)

    def chan(self) -> threading.Condition:
        """Return the condition variable notified on every ``add()``."""
        log.debug("AcceptQueue.chan")
        return self.c


class DatagramsQueue:
    """Bounded, thread-safe FIFO of datagrams holding at most ``maxlen`` entries."""

    def __init__(self, maxlen: int) -> None:
        self.c = threading.Condition()
        self.queue: List[bytes] = []
        self.maxlen = maxlen

    def add(self, datagram: bytes) -> bool:
        """Append *datagram*; return False (dropping it) if the queue is full."""
        log.debug(f"DatagramsQueue.add: datagram={datagram}")
        with self.c:
            if len(self.queue) >= self.maxlen:
                return False
            self.queue.append(datagram)
            self.c.notify()
            return True

    def wait_add(self, ctx: contextlib.AbstractContextManager, datagram: bytes) -> Optional[Exception]:
        """Append *datagram*, returning ``Exception("queue full")`` when full.

        Despite the name this does not wait for room, and *ctx* is unused.
        """
        log.debug(f"DatagramsQueue.wait_add: datagram={datagram}")
        with self.c:
            if len(self.queue) >= self.maxlen:
                return Exception("queue full")
            self.queue.append(datagram)
            self.c.notify()
            return None

    def next(self) -> Optional[bytes]:
        """Remove and return the oldest datagram, or None if the queue is empty."""
        log.debug("DatagramsQueue.next")
        with self.c:
            if not self.queue:
                return None
            return self.queue.pop(0)

    def wait_next(self, ctx: contextlib.AbstractContextManager) -> Optional[bytes]:
        """Block until a datagram is available and return it.

        Returns None if ``ctx.exception()`` is not None when checked.
        NOTE(review): *ctx* is used like a future (``exception()``), despite the annotation,
        and nothing notifies the condition when *ctx* fails, so a waiter only re-checks it
        after the next ``add()``.
        """
        log.debug("DatagramsQueue.wait_next")
        with self.c:
            while not self.queue:
                if ctx.exception() is not None:
                    return None
                self.c.wait()
            return self.queue.pop(0)


def jwt_signing_method_from_crypto_pubkey(pubkey) -> Tuple[str, Exception]:
    """Map a public key to its JWT algorithm name.

    :returns: ``("RS256", None)`` for RSA keys, ``("EdDSA", None)`` for Ed25519 keys,
        otherwise ``(None, UnknownSSHPubkeyType(pubkey))``.
    """
    log.debug(f"jwt_signing_method_from_crypto_pubkey: pubkey={pubkey}")
    try:
        if isinstance(pubkey, rsa.RSAPublicKey):
            return "RS256", None
        elif isinstance(pubkey, Ed25519PublicKey):
            return "EdDSA", None
        else:
            return None, UnknownSSHPubkeyType(pubkey)
    except Exception as e:
        return None, e


def sha256_fingerprint(in_bytes: bytes) -> str:
    """Return the standard base64 encoding of the SHA-256 digest of *in_bytes*."""
    log.debug(f"sha256_fingerprint: in_bytes={in_bytes}")
    return base64.b64encode(hashlib.sha256(in_bytes).digest()).decode('utf-8')

# def get_san_extension(cert_pem):
#     # Load the certificate from PEM format
#     cert = x509.load_pem_x509_certificate(cert_pem.encode(), default_backend())

#     # Look for the Subject Alternative Name extension
#     try:
#         san_extension = cert.extensions.get_extension_for_oid(x509.OID_SUBJECT_ALTERNATIVE_NAME)
#         return san_extension.value
#     except x509.ExtensionNotFound:
#         return None

# def for_each_san(der: bytes, callback: Callable[[int, bytes], Exception]) -> Optional[Exception]:
#     try:
#         idx = 0
#         der_len = len(der)
#         while idx < der_len:
#             tag =
#                 int(der[idx])
#             idx += 1
#             length = int(der[idx:idx + 2])
#             idx += 2
#             if idx + length > der_len:
#                 raise ValueError("x509: invalid subject alternative name")
#             data = der[idx:idx + length]
#             idx += length
#             err = callback(tag, data)
#             if err is not None:
#                 return err
#         return None
#     except Exception as e:
#         return e

# def cert_has_ip_sans(cert: x509.Certificate) -> Tuple[bool, Optional[Exception]]:
#     SANExtension = get_san_extension(cert)
#     if SANExtension is None:
#         return False, None

#     name_type_ip = 7
#     ip_addresses = []

#     def callback(tag: int, data: bytes) -> Optional[Exception]:
#         if tag == name_type_ip:
#             if len(data) == 4 or len(data) == 16:
#                 ip_addresses.append(data)
#             else:
#                 return ValueError(f"x509: cannot parse IP address of length {len(data)}")
#         return None

#     err = for_each_san(SANExtension, callback)
#     if err is not None:
#         return False, err

#     return len(ip_addresses) > 0, None


def cert_has_ip_sans(cert_pem):
    """Report whether a PEM certificate lists at least one IP-address SAN.

    Args:
        cert_pem: The certificate in PEM form, as ``str`` or bytes.

    Returns:
        ``(has_ip_san, None)`` when the SubjectAlternativeName extension is
        present; ``(False, ExtensionNotFound)`` when it is missing. Errors from
        ``load_pem_x509_certificate`` (malformed PEM) propagate.
    """
    log.debug(f"cert_has_ip_sans: cert_pem={cert_pem}")
    pem_bytes = cert_pem.encode() if isinstance(cert_pem, str) else bytes(cert_pem)
    cert = x509.load_pem_x509_certificate(pem_bytes, default_backend())

    try:
        # Look the extension up by class rather than via the legacy
        # module-level OID alias.
        san_extension = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName)
    except x509.ExtensionNotFound as e:
        log.error(f"could not find SAN extension in certificate: {e}")
        return False, e

    has_ip = any(isinstance(name, x509.IPAddress) for name in san_extension.value)
    return has_ip, None


def generate_key() -> Tuple[Ed25519PublicKey, Ed25519PrivateKey, Optional[Exception]]:
    """Generate a fresh Ed25519 key pair.

    A throw-away message is signed and verified as a sanity check of the pair.

    Returns:
        ``(public_key, private_key, None)`` on success, ``(None, None, err)``
        if generation or the self-check raises.
    """
    log.debug("generate_key()")
    try:
        private_key = Ed25519PrivateKey.generate()
        signature = private_key.sign(b"my authenticated message")  # TODO
        public_key = private_key.public_key()
        public_key.verify(signature, b"my authenticated message")
        log.debug(f"public_key={public_key}")
        log.debug(f"private_key={private_key}")
        log.debug(f"signature={signature}")
        return public_key, private_key, None
    except Exception as e:
        log.error(f"could not generate key: {e}")
        return None, None, e


def generate_cert(priv: Ed25519PrivateKey, pub: Ed25519PublicKey) -> Tuple[x509.Certificate, Optional[Exception]]:
    """Build a self-signed certificate for *pub*, signed with *priv*.

    The certificate is valid from now for 10 days. It carries SANs ``*`` and
    ``elniak.selfsigned.ssh3``, critical CA basic constraints, key usage and a
    server-auth extended key usage.

    Args:
        priv: Signing key (Ed25519 as produced by ``generate_key``; other key
            types are signed with SHA-256).
        pub: Public key placed in the certificate.

    Returns:
        ``(certificate, None)`` on success, ``(None, err)`` on failure.
    """
    log.info(f"generate_cert: priv={priv}")
    try:
        # Local import: only needed to recognise Ed448 keys below.
        from cryptography.hazmat.primitives.asymmetric import ed448

        # EdDSA keys sign the message itself, so CertificateBuilder.sign needs
        # algorithm=None for them (passing SHA256 raises ValueError).
        if isinstance(priv, (Ed25519PrivateKey, ed448.Ed448PrivateKey)):
            signature_hash = None
        else:
            signature_hash = hashes.SHA256()

        subject = issuer = x509.Name([
            x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
            x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "California"),
            x509.NameAttribute(NameOID.LOCALITY_NAME, "San Francisco"),
            x509.NameAttribute(NameOID.ORGANIZATION_NAME, "SSH3Organization"),
            x509.NameAttribute(NameOID.COMMON_NAME, "elniak.selfsigned.ssh3"),  # TODO maybe change for interop
        ])
        now = datetime.datetime.now(datetime.timezone.utc)
        cert = (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(issuer)
            .public_key(pub)
            .serial_number(x509.random_serial_number())
            .not_valid_before(now)
            # Our certificate will be valid for 10 days
            .not_valid_after(now + datetime.timedelta(days=10))
            .add_extension(
                x509.SubjectAlternativeName([x509.DNSName("*"), x509.DNSName("elniak.selfsigned.ssh3")]),
                critical=False,
            )
            .add_extension(
                x509.BasicConstraints(ca=True, path_length=None),
                critical=True,
            )
            .add_extension(
                x509.KeyUsage(digital_signature=True, key_encipherment=True, key_cert_sign=True,
                              key_agreement=False, content_commitment=False, data_encipherment=False,
                              crl_sign=False, encipher_only=False, decipher_only=False),
                critical=True,
            )
            .add_extension(
                x509.ExtendedKeyUsage([x509.oid.ExtendedKeyUsageOID.SERVER_AUTH]),
                critical=True,
            )
            # Sign our certificate with our private key
            .sign(priv, signature_hash, default_backend())
        )
        log.debug(f"cert={cert}")
        return cert, None
    except Exception as e:
        log.error(f"could not generate cert: {e}")
        return None, e


def dump_cert_and_key_to_files(cert: x509.Certificate, priv: Ed25519PrivateKey, cert_file: str, key_file: str) -> Optional[Exception]:
    """Write *cert* (PEM) to *cert_file* and *priv* (unencrypted PKCS#8 PEM) to *key_file*.

    Returns:
        None on success, or the exception raised by the first failing step
        (the key file is not written if the certificate write fails).
    """
    log.info(f"dump_cert_to_file: cert_file={cert_file}, key_file={key_file}")
    try:
        pem = cert.public_bytes(encoding=serialization.Encoding.PEM)
        with open(cert_file, "wb") as f:
            f.write(pem)
    except Exception as e:
        log.error(f"could not dump cert to file: {e}")
        return e

    try:
        # TODO check if this is correct
        key_byte = priv.private_bytes(encoding=serialization.Encoding.PEM,
                                      format=serialization.PrivateFormat.PKCS8,
                                      encryption_algorithm=serialization.NoEncryption())
        with open(key_file, "wb") as f:
            f.write(key_byte)
    except Exception as e:
        log.error(f"could not dump key to file: {e}")
        return e
    return None


def parse_ssh_string(buf):
    """Parse an SSH string (4-byte big-endian length, then UTF-8 bytes) from *buf*.

    *buf* must provide ``read(n)`` (e.g. ``io.BytesIO``).
    """
    log.debug(f"parse_ssh_string: buf={buf}")
    length = int.from_bytes(buf.read(4), byteorder='big')
    return buf.read(length).decode('utf-8')


def write_ssh_string(buf, value):
    """Append *value* to *buf* as an SSH string (4-byte big-endian length + UTF-8).

    *buf* must support ``extend`` (e.g. ``bytearray``).

    Returns:
        Number of bytes appended.
    """
    log.debug(f"write_ssh_string: value={value}")
    encoded_value = value.encode('utf-8')
    buf.extend(len(encoded_value).to_bytes(4, byteorder='big'))
    buf.extend(encoded_value)
    return len(encoded_value) + 4


def read_boolean(buf):
    """Read one byte from *buf* and return True if it is non-zero."""
    log.debug("read_boolean")
    return buf.read(1)[0] != 0


def ssh_string_len(s):
    """Return the size of *s* as written by ``write_ssh_string``.

    That is 4 bytes of length prefix plus the UTF-8 byte length for ``str``
    (or ``len(s)`` for bytes-like input).
    """
    log.debug(f"ssh_string_len: s={s}")
    # Count encoded bytes, not characters, so non-ASCII text matches what
    # write_ssh_string actually emits.
    payload = s.encode('utf-8') if isinstance(s, str) else s
    return 4 + len(payload)

# --------------------------------------------------------------------------------
# /py-ssh3/util/waitgroup.py:
# --------------------------------------------------------------------------------
import threading  # :(


class WaitGroup(object):
    """WaitGroup is like Go sync.WaitGroup.

    Without all the useful corner cases.
    """

    def __init__(self):
        self.count = 0
        self.cv = threading.Condition()

    def add(self, n):
        """Add *n* (possibly negative) to the counter; wake waiters once it reaches zero."""
        with self.cv:
            self.count += n
            # A negative delta can release waiters too (as in Go's Add).
            if self.count <= 0:
                self.cv.notify_all()

    def done(self):
        """Decrement the counter by one."""
        self.add(-1)

    def wait(self):
        """Block until the counter is zero (or below)."""
        with self.cv:
            self.cv.wait_for(lambda: self.count <= 0)

# --------------------------------------------------------------------------------
# /py-ssh3/util/wire.py:
# --------------------------------------------------------------------------------
import io
from message.channel_request import *
from util.type import *

# taken from the QUIC draft
Min = 0
Max = 4611686018427387903
maxVarInt1 = 63
maxVarInt2 = 16383
maxVarInt4 = 1073741823
maxVarInt8 = 4611686018427387903

# class Reader(io.ByteReader, io.Reader):
#     pass

# def NewReader(r):
#     if
#        isinstance(r, Reader):
#         return r
#     return byteReader(r)

# class byteReader(Reader):
#     def ReadByte(self):
#         b = self.Reader.read(1)
#         if len(b) == 1:
#             return b[0], None
#         return None, io.EOF

# class Writer(io.ByteWriter, io.Writer):
#     pass

# def NewWriter(w):
#     if isinstance(w, Writer):
#         return w
#     return byteWriter(w)

# class byteWriter(Writer):
#     def WriteByte(self, c):
#         return self.Writer.write(bytes([c]))

# First-byte length prefix (two most significant bits) for each encoded size.
_VARINT_LENGTH_PREFIX = {1: 0x00, 2: 0x40, 4: 0x80, 8: 0xC0}


def read_varint(r):
    """Decode a QUIC variable-length integer from *r*.

    *r* must expose a Go-style ``ReadByte()`` returning ``(byte, err)``.

    Returns:
        ``(value, None)`` on success, or ``(0, err)`` as soon as a read fails.
    """
    first_byte, err = r.ReadByte()
    if err is not None:
        return 0, err
    # The two high bits of the first byte select a total length of 1, 2, 4 or 8.
    length = 1 << ((first_byte & 0xC0) >> 6)
    value = first_byte & 0x3F
    for _ in range(length - 1):
        next_byte, err = r.ReadByte()
        if err is not None:
            return 0, err
        value = (value << 8) | next_byte
    return value, None


def append_varint(b, i):
    """Return *b* followed by the minimal QUIC varint encoding of *i*.

    Args:
        b: ``bytes``/``bytearray`` to append to (a new object is returned).
        i: Integer in ``[0, 2**62 - 1]``.

    Raises:
        Exception: If *i* is negative or does not fit into 62 bits.
    """
    if i < 0 or i > maxVarInt8:
        raise Exception(f"{i:x} doesn't fit into 62 bits")
    length = varint_len(i)
    # Big-endian value; the minimal size guarantees the top two bits are free
    # for the length prefix.
    encoded = bytearray(i.to_bytes(length, "big"))
    encoded[0] |= _VARINT_LENGTH_PREFIX[length]
    return b + bytes(encoded)


def append_varintWithLen(b, i, length):
    """Return *b* followed by *i* encoded as a QUIC varint of exactly *length* bytes.

    Longer-than-minimal encodings are zero-padded, which QUIC decoders accept.

    Raises:
        Exception: If *length* is not 1, 2, 4 or 8, or *i* needs more bytes.
    """
    if length not in (1, 2, 4, 8):
        raise Exception("invalid varint length")
    minimal = varint_len(i)
    if minimal == length:
        return append_varint(b, i)
    if minimal > length:
        raise Exception(f"cannot encode {i} in {length} bytes")
    # i fits in `minimal` < `length` bytes, so its big-endian form over
    # `length` bytes starts with zero bytes; the prefix goes in the first one.
    encoded = bytearray(i.to_bytes(length, "big"))
    encoded[0] |= _VARINT_LENGTH_PREFIX[length]
    return b + bytes(encoded)


def varint_len(i):
    """Return the number of bytes (1, 2, 4 or 8) of the minimal varint encoding of *i*.

    Raises:
        Exception: If *i* is negative or does not fit into 62 bits.
    """
    if i < 0:
        raise Exception(f"value doesn't fit into 62 bits: {i:x}")
    if i <= maxVarInt1:
        return 1
    if i <= maxVarInt2:
        return 2
    if i <= maxVarInt4:
        return 4
    if i <= maxVarInt8:
        return 8
    raise Exception(f"value doesn't fit into 62 bits: {i:x}")


def ParseSSHString(buf):
    """Read a varint-length-prefixed UTF-8 string from *buf*.

    *buf* must expose Go-style ``ReadByte() -> (byte, err)``, the same
    interface ``read_varint`` uses.

    Returns:
        ``(string, None)`` on success, or ``("", error)``.
    """
    length, err = read_varint(buf)
    if err is not None:
        return "", InvalidSSHString(err)
    out = bytearray()
    # Read the payload through the same ReadByte interface.
    for _ in range(length):
        byte, err = buf.ReadByte()
        if err is not None:
            # NOTE(review): InvalidSSHString comes from message.channel_request
            # via star import; assumed to accept a single message argument.
            return "", InvalidSSHString(f"expected length {length}, read length {len(out)}")
        out.append(byte)
    return out.decode('utf-8'), None


def WriteSSHString(out, s):
    """Write *s* as a varint-length-prefixed UTF-8 string at the start of *out*.

    Mirrors Go's ``copy`` semantics. *out* must be a mutable buffer
    (``bytearray`` or writable ``memoryview``) of at least ``SSHStringLen(s)``
    bytes, and is filled in place.

    Returns:
        ``(bytes_written, None)``.

    Raises:
        Exception: If *out* is too small.
    """
    encoded = s.encode('utf-8')
    needed = SSHStringLen(s)
    if len(out) < needed:
        raise Exception(f"buffer too small to write varint: {len(out)} < {needed}")
    # The prefix encodes the UTF-8 byte length, not the character count.
    prefix = append_varint(b"", len(encoded))
    out[:len(prefix)] = prefix
    out[len(prefix):needed] = encoded
    return needed, None


def SSHStringLen(s):
    """Return the encoded size of *s*: varint length prefix plus UTF-8 byte count."""
    size = len(s.encode('utf-8')) if isinstance(s, str) else len(s)
    return varint_len(size) + size


def MinUint64(a, b):
    """Return the smaller of *a* and *b* (*a* on ties)."""
    return min(a, b)

# import struct
# import io

# # Constants for QUIC varints
#
MAX_VAR_INT1 = 63 154 | # MAX_VAR_INT2 = 16383 155 | # MAX_VAR_INT4 = 1073741823 156 | # MAX_VAR_INT8 = 4611686018427387903 157 | 158 | # class ByteReader: 159 | # # A ByteReader class implementing io.ByteReader and io.Reader interfaces 160 | # def __init__(self, reader): 161 | # self.reader = reader 162 | 163 | # def read_byte(self): 164 | # return self.reader.read(1) 165 | 166 | # def read(self, n=-1): 167 | # return self.reader.read(n) 168 | 169 | # def new_reader(reader): 170 | # # Returns a ByteReader for the given reader 171 | # return ByteReader(reader) 172 | 173 | # def read_varint(reader): 174 | # # Read a QUIC varint from the given reader 175 | # first_byte = ord(reader.read_byte()) 176 | # length = 1 << ((first_byte & 0xc0) >> 6) 177 | # b1 = first_byte & 0x3f 178 | # if length == 1: 179 | # return b1 180 | # b2 = ord(reader.read_byte()) 181 | # if length == 2: 182 | # return (b2 << 8) | b1 183 | # b3 = ord(reader.read_byte()) 184 | # b4 = ord(reader.read_byte()) 185 | # if length == 4: 186 | # return (b4 << 24) | (b3 << 16) | (b2 << 8) | b1 187 | # b5 = ord(reader.read_byte()) 188 | # b6 = ord(reader.read_byte()) 189 | # b7 = ord(reader.read_byte()) 190 | # b8 = ord(reader.read_byte()) 191 | # return (b8 << 56) | (b7 << 48) | (b6 << 40) | (b5 << 32) | (b4 << 24) | (b3 << 16) | (b2 << 8) | b1 192 | 193 | # def append_varint(b, i): 194 | # # Append a QUIC varint to the given byte array 195 | # if i <= MAX_VAR_INT1: 196 | # return b + struct.pack('B', i) 197 | # if i <= MAX_VAR_INT2: 198 | # return b + struct.pack('>H', i | 0x4000) 199 | # if i <= MAX_VAR_INT4: 200 | # return b + struct.pack('>I', i | 0x80000000) 201 | # if i <= MAX_VAR_INT8: 202 | # return b + struct.pack('>Q', i | 0xC000000000000000) 203 | # raise ValueError(f"{i} doesn't fit into 62 bits") 204 | 205 | # def varint_len(i): 206 | # # Determine the number of bytes needed to write the number i 207 | # if i <= MAX_VAR_INT1: 208 | # return 1 209 | # if i <= MAX_VAR_INT2: 210 | # return 2 211 
#     if i <= MAX_VAR_INT4:
#         return 4
#     if i <= MAX_VAR_INT8:
#         return 8
#     raise ValueError(f"value doesn't fit into 62 bits: {i}")

# def parse_ssh_string(buf):
#     # Parse an SSH string from the given buffer
#     length, _ = read_varint(buf)
#     return buf.read(length).decode('utf-8')

# def write_ssh_string(out, s):
#     # Write an SSH string to the given output buffer
#     length = len(s)
#     out.write(append_varint(b'', length))
#     out.write(s.encode('utf-8'))
#     return len(out.getvalue())

# def ssh_string_len(s):
#     # Calculate the length of an SSH string
#     return varint_len(len(s)) + len(s)

# def min_uint64(a, b):
#     # Return the minimum of two uint64 values
#     return min(a, b)

# --------------------------------------------------------------------------------
# /py-ssh3/winsize/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/ElNiak/PySSH3/842429c479551afe8b7f6e219dd515015d53987f/py-ssh3/winsize/__init__.py

# --------------------------------------------------------------------------------
# /py-ssh3/winsize/common.py:
# --------------------------------------------------------------------------------
class WindowSize:
    """Terminal window size: rows/columns plus pixel dimensions (0 = unknown)."""

    def __init__(self, nrows=0, ncols=0, pixel_width=0, pixel_height=0):
        self.nrows = nrows
        self.ncols = ncols
        self.pixel_width = pixel_width
        self.pixel_height = pixel_height

# --------------------------------------------------------------------------------
# /py-ssh3/winsize/winsize.py:
# --------------------------------------------------------------------------------
import os
import fcntl
import struct
import sys
import termios

from winsize.common import WindowSize


def get_winsize_unix():
    """Return the terminal size of file descriptor 0 as a WindowSize.

    Uses the TIOCGWINSZ ioctl. On failure (e.g. fd 0 is not a TTY) the error
    is reported on stderr and an all-zero WindowSize is returned.
    """
    ws = WindowSize()
    try:
        packed = fcntl.ioctl(0, termios.TIOCGWINSZ, struct.pack('HHHH', 0, 0, 0, 0))
        rows, cols, xpix, ypix = struct.unpack('HHHH', packed)
        ws = WindowSize(nrows=rows, ncols=cols, pixel_width=xpix, pixel_height=ypix)
    except Exception as e:
        # Best effort: keep the zero size. Report on stderr so the message
        # does not mix with session data written to stdout.
        print(f"Error getting window size: {e}", file=sys.stderr)
    return ws

# --------------------------------------------------------------------------------
# /py-ssh3/winsize/winsize_windows.py:
# --------------------------------------------------------------------------------
import ctypes
from ctypes import wintypes

from winsize.common import WindowSize


class CONSOLE_SCREEN_BUFFER_INFO(ctypes.Structure):
    """ctypes mirror of the Win32 CONSOLE_SCREEN_BUFFER_INFO structure."""

    _fields_ = [("dwSize", wintypes._COORD),
                ("dwCursorPosition", wintypes._COORD),
                ("wAttributes", wintypes.WORD),
                ("srWindow", wintypes.SMALL_RECT),
                ("dwMaximumWindowSize", wintypes._COORD)]


def get_winsize_windows():
    """Return the visible console window size as a WindowSize (columns/rows only).

    If the console information cannot be queried (e.g. stdout is not a
    console), an all-zero WindowSize is returned.
    """
    ws = WindowSize()
    h_std_out = ctypes.windll.kernel32.GetStdHandle(-11)  # STD_OUTPUT_HANDLE
    csbi = CONSOLE_SCREEN_BUFFER_INFO()
    if not ctypes.windll.kernel32.GetConsoleScreenBufferInfo(h_std_out, ctypes.byref(csbi)):
        # The call failed and csbi is still zeroed; computing from it would
        # report a bogus 1x1 window.
        return ws
    ws.ncols = csbi.srWindow.Right - csbi.srWindow.Left + 1
    ws.nrows = csbi.srWindow.Bottom - csbi.srWindow.Top + 1
    return ws

# --------------------------------------------------------------------------------
# /qlogs/.gitkeep:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/ElNiak/PySSH3/842429c479551afe8b7f6e219dd515015d53987f/qlogs/.gitkeep

# --------------------------------------------------------------------------------
# /qlogs_client/.gitkeep:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/ElNiak/PySSH3/842429c479551afe8b7f6e219dd515015d53987f/qlogs_client/.gitkeep

# --------------------------------------------------------------------------------
# /setup.py:
# --------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
    name="py-ssh3",
    version="0.1",
    description="Python SSH3 version",
    author="ElNiak",
    author_email="elniak@email.com",
    packages=find_packages(),
    install_requires=[
        # "aiohttp",
        # "pyOpenSSL",
        "cryptography",
        # "aioquic==0.9.24",
        "pyjwt",
        # "http3",
        "authlib",
        # "PyCryptodome",
        # "sanic",
        "paramiko",
        # "h11==0.9.0",
        "wsproto",
        "jinja2",
        "starlette",
    ],
)

# --------------------------------------------------------------------------------